Create LanceDB index after table is created in import
AI-Northstar-Tech/vector-io#80
✓ Completed in 11 minutes, 2 months ago using GPT-4
Progress
Modify
src/vdf_io/import_vdf/lancedb_import.py
f168003
1from typing import Dict, List
2from dotenv import load_dotenv
3import pandas as pd
4from tqdm import tqdm
5import pyarrow.parquet as pq
6
7import lancedb
8
9from vdf_io.constants import DEFAULT_BATCH_SIZE, INT_MAX
10from vdf_io.meta_types import NamespaceMeta
11from vdf_io.names import DBNames
12from vdf_io.util import (
13 cleanup_df,
14 divide_into_batches,
15 set_arg_from_input,
16 set_arg_from_password,
17)
18from vdf_io.import_vdf.vdf_import_cls import ImportVDB
19
20
21load_dotenv()
22
23
24class ImportLanceDB(ImportVDB):
25 DB_NAME_SLUG = DBNames.LANCEDB
26
27 @classmethod
28 def import_vdb(cls, args):
29 """
30 Import data to LanceDB
31 """
32 set_arg_from_input(
33 args,
34 "endpoint",
35 "Enter the URL of LanceDB instance (default: '~/.lancedb'): ",
36 str,
37 default_value="~/.lancedb",
38 )
39 set_arg_from_password(
40 args,
41 "lancedb_api_key",
42 "Enter the LanceDB API key (default: value of os.environ['LANCEDB_API_KEY']): ",
43 "LANCEDB_API_KEY",
44 )
45 lancedb_import = ImportLanceDB(args)
46 lancedb_import.upsert_data()
47 return lancedb_import
48
49 @classmethod
50 def make_parser(cls, subparsers):
51 parser_lancedb = subparsers.add_parser(
52 cls.DB_NAME_SLUG, help="Import data to lancedb"
53 )
54 parser_lancedb.add_argument(
55 "--endpoint", type=str, help="Location of LanceDB instance"
56 )
57 parser_lancedb.add_argument(
58 "--lancedb_api_key", type=str, help="LanceDB API key"
59 )
60 parser_lancedb.add_argument(
61 "--tables", type=str, help="LanceDB tables to export (comma-separated)"
62 )
63
64 def __init__(self, args):
65 # call super class constructor
66 super().__init__(args)
67 self.db = lancedb.connect(
68 self.args["endpoint"], api_key=self.args.get("lancedb_api_key") or None
69 )
70
71 def upsert_data(self):
72 max_hit = False
73 self.total_imported_count = 0
74 indexes_content: Dict[str, List[NamespaceMeta]] = self.vdf_meta["indexes"]
75 index_names: List[str] = list(indexes_content.keys())
76 if len(index_names) == 0:
77 raise ValueError("No indexes found in VDF_META.json")
78 tables = self.db.table_names()
79 # Load Parquet file
80 # print(indexes_content[index_names[0]]):List[NamespaceMeta]
81 for index_name, index_meta in tqdm(
82 indexes_content.items(), desc="Importing indexes"
83 ):
84 for namespace_meta in tqdm(index_meta, desc="Importing namespaces"):
85 self.set_dims(namespace_meta, index_name)
86 data_path = namespace_meta["data_path"]
87 final_data_path = self.get_final_data_path(data_path)
88 # Load the data from the parquet files
89 parquet_files = self.get_parquet_files(final_data_path)
90
91 new_index_name = index_name + (
92 f'_{namespace_meta["namespace"]}'
93 if namespace_meta["namespace"]
94 else ""
95 )
96 new_index_name = self.create_new_name(new_index_name, tables)
97 if new_index_name not in tables:
98 table = self.db.create_table(
99 new_index_name, schema=pq.read_schema(parquet_files[0])
100 )
101 tqdm.write(f"Created table {new_index_name}")
102 else:
103 table = self.db.open_table(new_index_name)
104 tqdm.write(f"Opened table {new_index_name}")
105
106 for file in tqdm(parquet_files, desc="Iterating parquet files"):
107 file_path = self.get_file_path(final_data_path, file)
108 df = self.read_parquet_progress(
109 file_path,
110 max_num_rows=(
111 (self.args.get("max_num_rows") or INT_MAX)
112 - self.total_imported_count
113 ),
114 )
115 df = cleanup_df(df)
116 # if there are additional columns in the parquet file, add them to the table
117 for col in df.columns:
118 if col not in [field.name for field in table.schema]:
119 col_type = df[col].dtype
120 tqdm.write(f"Adding column {col} of type {col_type} to {new_index_name}")
121 table.add_columns(
122 {
123 col: get_default_value(col_type),
124 }
125 )
126 # split in batches
127 BATCH_SIZE = self.args.get("batch_size") or DEFAULT_BATCH_SIZE
128 for batch in tqdm(
129 divide_into_batches(df, BATCH_SIZE),
130 desc="Importing batches",
131 total=len(df) // BATCH_SIZE,
132 ):
133 if self.total_imported_count + len(batch) >= (
134 self.args.get("max_num_rows") or INT_MAX
135 ):
136 batch = batch[
137 : (self.args.get("max_num_rows") or INT_MAX)
138 - self.total_imported_count
139 ]
140 max_hit = True
141 # convert df into list of dicts
142 table.add(batch)
143 self.total_imported_count += len(batch)
144 if max_hit:
145 break
146 tqdm.write(f"Imported {self.total_imported_count} rows")
147 tqdm.write(f"New table size: {table.count_rows()}")
148 if max_hit:
149 break
150 print("Data imported successfully")
151
152
153def get_default_value(data_type):
154 # Define default values for common data types
155 default_values = {
156 "object": "",
157 "int64": 0,
158 "float64": 0.0,
159 "bool": False,
160 "datetime64[ns]": pd.Timestamp("NaT"),
161 "timedelta64[ns]": pd.Timedelta("NaT"),
162 }
163 # Return the default value for the specified data type, or None if not specified
164 return default_values.get(data_type.name, None)
165
Import the ID_COLUMN constant at the top of the file so the index step below can identify the ID field (indexes in LanceDB are created through methods on the table object rather than a module-level create_index, so no additional lancedb import is needed):
from vdf_io.constants import ID_COLUMN
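For reference, a minimal sketch of the top-of-file import block after this change (only ID_COLUMN is new; the other imports already exist in the file):
import lancedb

# ID_COLUMN added to the existing constants import (assumption: it is exported by
# vdf_io.constants, as it is in vdf_io/util.py)
from vdf_io.constants import DEFAULT_BATCH_SIZE, ID_COLUMN, INT_MAX
from vdf_io.meta_types import NamespaceMeta
from vdf_io.names import DBNames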
Modify
src/vdf_io/import_vdf/lancedb_import.py
f168003
In the upsert_data method of the ImportLanceDB class, after the code block that creates a new table or opens an existing one, add the following to create an index on the table:
# Get the ID column from the parquet file schema
parquet_schema = pq.read_schema(parquet_files[0])
id_column = "id"  # Default
for field in parquet_schema:
    if field.name == ID_COLUMN:
        id_column = field.name
        break
# Create an index on the ID column of the table
table.create_scalar_index(id_column)
tqdm.write(f"Created index on {id_column} for table {new_index_name}")
This code reads the schema of the first parquet file to determine the name of the ID column (falling back to "id" if the ID_COLUMN constant is not found in the schema). It then calls table.create_scalar_index with the ID column name to create an index on that column.
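For placement, here is a minimal sketch of the create/open block in upsert_data with the index step added, assuming the installed lancedb version provides the Table.create_scalar_index method (older releases expose only vector indexing via Table.create_index):
# Inside upsert_data, within the per-namespace loop (sketch, not the full method)
new_index_name = self.create_new_name(new_index_name, tables)
if new_index_name not in tables:
    table = self.db.create_table(
        new_index_name, schema=pq.read_schema(parquet_files[0])
    )
    tqdm.write(f"Created table {new_index_name}")
else:
    table = self.db.open_table(new_index_name)
    tqdm.write(f"Opened table {new_index_name}")

# New: index the ID column as soon as the table handle is available
parquet_schema = pq.read_schema(parquet_files[0])
id_column = ID_COLUMN if ID_COLUMN in parquet_schema.names else "id"
table.create_scalar_index(id_column)
tqdm.write(f"Created index on {id_column} for table {new_index_name}")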
Plan
This is based on the results of the Planning step. The plan may expand from failed GitHub Actions runs.
Code Snippets Found
This is based on the results of the Searching step.
src/vdf_io/import_vdf/lancedb_import.py:0-165
src/vdf_io/util.py:0-505
1from pathlib import Path
2from collections import OrderedDict
3from getpass import getpass
4import hashlib
5import json
6import os
7import time
8from uuid import UUID
9import numpy as np
10import pandas as pd
11from io import StringIO
12import sys
13from tqdm import tqdm
14from PIL import Image
15from halo import Halo
16
17from qdrant_client.http.models import Distance
18
19from vdf_io.constants import ID_COLUMN, INT_MAX
20from vdf_io.names import DBNames
21
22
23def sort_recursive(d):
24 """
25 Recursively sort the nested dictionary by its keys.
26 """
27 # if isinstance(d, list):
28 # return [sort_recursive(v) for v in d]
29 # if isinstance(d, tuple):
30 # return tuple(sort_recursive(v) for v in d)
31 # if isinstance(d, set):
32 # return list({sort_recursive(v) for v in d}).sort()
33 if (
34 isinstance(d, str)
35 or isinstance(d, int)
36 or isinstance(d, float)
37 or isinstance(d, bool)
38 or d is None
39 or isinstance(d, OrderedDict)
40 ):
41 return d
42 if hasattr(d, "attribute_map"):
43 return sort_recursive(d.attribute_map)
44 if not isinstance(d, dict):
45 try:
46 d = dict(d)
47 except Exception:
48 d = {"": str(d)}
49
50 sorted_dict = OrderedDict()
51 for key, value in sorted(d.items()):
52 sorted_dict[key] = sort_recursive(value)
53
54 return sorted_dict
55
56
57def convert_to_consistent_value(d):
58 """
59 Convert a nested dictionary to a consistent string regardless of key order.
60 """
61 sorted_dict = sort_recursive(d)
62 return json.dumps(sorted_dict, sort_keys=True)
63
64
65def extract_data_hash(arg_dict_combined):
66 arg_dict_combined_copy = arg_dict_combined.copy()
67 data_hash = hashlib.md5(
68 convert_to_consistent_value(arg_dict_combined_copy).encode("utf-8")
69 )
70 # make it 5 characters long
71 data_hash = data_hash.hexdigest()[:5]
72 return data_hash
73
74
75def extract_numerical_hash(string_value):
76 """
77 Extract a numerical hash from a string
78 """
79 return int(hashlib.md5(string_value.encode("utf-8")).hexdigest(), 16)
80
81
82def set_arg_from_input(
83 args,
84 arg_name,
85 prompt,
86 type_name=str,
87 default_value=None,
88 choices=None,
89 env_var=None,
90):
91 """
92 Set the value of an argument from user input if it is not already present
93 """
94 if (
95 (default_value is None)
96 and (env_var is not None)
97 and (os.getenv(env_var) is not None)
98 ):
99 default_value = os.getenv(env_var)
100 if arg_name not in args or (
101 args[arg_name] is None and default_value != "DO_NOT_PROMPT"
102 ):
103 while True:
104 inp = input(
105 prompt
106 + (" " + str(list(choices)) + ": " if choices is not None else "")
107 )
108 if len(inp) >= 2:
109 if inp[0] == '"' and inp[-1] == '"':
110 inp = inp[1:-1]
111 elif inp[0] == "'" and inp[-1] == "'":
112 inp = inp[1:-1]
113 if inp == "":
114 args[arg_name] = (
115 None if default_value is None else type_name(default_value)
116 )
117 break
118 elif choices is not None and not all(
119 choice in choices for choice in inp.split(",")
120 ):
121 print(f"Invalid input. Please choose from {choices}")
122 continue
123 else:
124 args[arg_name] = type_name(inp)
125 break
126 return
127
128
129def set_arg_from_password(args, arg_name, prompt, env_var_name):
130 """
131 Set the value of an argument from user input if it is not already present
132 """
133 if os.getenv(env_var_name) is not None:
134 args[arg_name] = os.getenv(env_var_name)
135 elif arg_name not in args or args[arg_name] is None:
136 args[arg_name] = getpass(prompt)
137 return
138
139
140def expand_shorthand_path(shorthand_path):
141 """
142 Expand shorthand notations in a file path to a full path-like object.
143
144 :param shorthand_path: A string representing the shorthand path.
145 :return: A Path object representing the full path.
146 """
147 if shorthand_path is None:
148 return None
149 # Expand '~' to the user's home directory
150 expanded_path = os.path.expanduser(shorthand_path)
151
152 # Resolve '.' and '..' to get the absolute path
153 full_path = Path(expanded_path).resolve()
154
155 return str(full_path)
156
157
158db_metric_to_standard_metric = {
159 DBNames.PINECONE: {
160 "cosine": Distance.COSINE,
161 "euclidean": Distance.EUCLID,
162 "dotproduct": Distance.DOT,
163 },
164 DBNames.QDRANT: {
165 Distance.COSINE: Distance.COSINE,
166 Distance.EUCLID: Distance.EUCLID,
167 Distance.DOT: Distance.DOT,
168 Distance.MANHATTAN: Distance.MANHATTAN,
169 },
170 DBNames.MILVUS: {
171 "COSINE": Distance.COSINE,
172 "IP": Distance.DOT,
173 "L2": Distance.EUCLID,
174 },
175 DBNames.KDBAI: {
176 "L2": Distance.EUCLID,
177 "CS": Distance.COSINE,
178 "IP": Distance.DOT,
179 },
180 DBNames.VERTEXAI: {
181 "DOT_PRODUCT_DISTANCE": Distance.DOT,
182 "SQUARED_L2_DISTANCE": Distance.EUCLID,
183 "COSINE_DISTANCE": Distance.COSINE,
184 "L1_DISTANCE": Distance.MANHATTAN,
185 },
186 DBNames.LANCEDB: {
187 "L2": Distance.EUCLID,
188 "Cosine": Distance.COSINE,
189 "Dot": Distance.DOT,
190 },
191 DBNames.CHROMA: {
192 "l2": Distance.EUCLID,
193 "cosine": Distance.COSINE,
194 "ip": Distance.DOT,
195 },
196 DBNames.ASTRADB: {
197 "cosine": Distance.COSINE,
198 "euclidean": Distance.EUCLID,
199 "dot_product": Distance.DOT,
200 },
201 DBNames.WEAVIATE: {
202 "cosine": Distance.COSINE,
203 "l2-squared": Distance.EUCLID,
204 "dot": Distance.DOT,
205 "manhattan": Distance.MANHATTAN,
206 },
207 DBNames.VESPA: {
208 "angular": Distance.COSINE,
209 "euclidean": Distance.EUCLID,
210 "dotproduct": Distance.DOT,
211 },
212}
213
214
215def standardize_metric(metric, db):
216 """
217 Standardize the metric name to the one used in the standard library.
218 """
219 if (
220 db in db_metric_to_standard_metric
221 and metric in db_metric_to_standard_metric[db]
222 ):
223 return db_metric_to_standard_metric[db][metric]
224 else:
225 raise Exception(f"Invalid metric '{metric}' for database '{db}'")
226
227
228def standardize_metric_reverse(metric, db):
229 """
230 Standardize the metric name to the one used in the standard library.
231 """
232 if (
233 db in db_metric_to_standard_metric
234 and metric in db_metric_to_standard_metric[db].values()
235 ):
236 for key, value in db_metric_to_standard_metric[db].items():
237 if value == metric:
238 return key
239 else:
240 tqdm.write(f"Invalid metric '{metric}' for database '{db}'. Using cosine")
241 return standardize_metric_reverse(Distance.COSINE, db)
242
243
244def get_final_data_path(cwd, dir, data_path, args):
245 if args.get("hf_dataset", None):
246 return data_path
247 final_data_path = os.path.join(cwd, dir, data_path)
248 if not os.path.isdir(final_data_path):
249 raise Exception(
250 f"Invalid data path\n"
251 f"data_path: {data_path},\n"
252 f"Joined path: {final_data_path}\n"
253 f"Current working directory: {cwd}\n"
254 f"Command line arg (dir): {dir}"
255 )
256 return final_data_path
257
258
259def list_configs_and_splits(name):
260 if "HUGGING_FACE_TOKEN" not in os.environ:
261 yield "train", None
262 import requests
263
264 headers = {"Authorization": f"Bearer {os.environ['HUGGING_FACE_TOKEN']}"}
265 API_URL = f"https://datasets-server.huggingface.co/splits?dataset={name}"
266
267 def query():
268 response = requests.get(API_URL, headers=headers)
269 return response.json()
270
271 data = query()
272 if "splits" in data:
273 for split in data["splits"]:
274 if "config" in split:
275 yield split["split"], split["config"]
276 else:
277 yield split["split"], None
278 else:
279 yield "train", None
280
281
282def get_parquet_files(data_path, args, temp_file_paths=[], id_column=ID_COLUMN):
283 # Load the data from the parquet files
284 if args.get("hf_dataset", None):
285 if args.get("max_num_rows", None):
286 from datasets import load_dataset
287
288 total_rows_loaded = 0
289 for i, (split, config) in enumerate(
290 list_configs_and_splits(args.get("hf_dataset"))
291 ):
292 tqdm.write(f"Split: {split}, Config: {config}")
293 ds = load_dataset(
294 args.get("hf_dataset"), name=config, split=split, streaming=True
295 )
296 with Halo(text="Taking a subset of the dataset", spinner="dots"):
297 it_ds = ds.take(args.get("max_num_rows") - total_rows_loaded)
298 start_time = time.time()
299 with Halo(
300 text="Converting to pandas dataframe (this may take a while)",
301 spinner="dots",
302 ):
303 df = pd.DataFrame(it_ds)
304 end_time = time.time()
305 tqdm.write(
306 f"Time taken to convert to pandas dataframe: {end_time - start_time:.2f} seconds"
307 )
308 df = cleanup_df(df)
309 if id_column not in df.columns:
310 # remove all rows
311 tqdm.write(
312 (
313 f"ID column '{id_column}' not found in parquet file '{data_path}'."
314 f" Skipping split '{split}', config '{config}'."
315 )
316 )
317 continue
318 total_rows_loaded += len(df)
319 temp_file_path = f"{os.getcwd()}/temp_{args['hash_value']}_{i}.parquet"
320 with Halo(text="Saving to parquet", spinner="dots"):
321 df.to_parquet(temp_file_path)
322 temp_file_paths.append(temp_file_path)
323 if total_rows_loaded >= args.get("max_num_rows"):
324 break
325 return temp_file_paths
326 from huggingface_hub import HfFileSystem
327
328 fs = HfFileSystem()
329 return [
330 "hf://" + x
331 for x in fs.glob(
332 f"datasets/{args.get('hf_dataset')}/{data_path if data_path!='.' else ''}/**.parquet"
333 )
334 ]
335 if not os.path.isdir(data_path):
336 if data_path.endswith(".parquet"):
337 return [data_path]
338 else:
339 raise Exception(f"Invalid data path '{data_path}'")
340 else:
341 # recursively find all parquet files (it should be a file acc to OS)
342 parquet_files = []
343 for root, _, files in os.walk(data_path):
344 for file in files:
345 if file.endswith(".parquet"):
346 parquet_files.append(os.path.join(root, file))
347 return parquet_files
348
349
350def cleanup_df(df):
351 for col in df.columns:
352 if df[col].dtype == "object":
353 first_el = df[col].iloc[0]
354 # if isinstance(first_el, bytes):
355 # df[col] = df[col].apply(lambda x: x.decode("utf-8"))
356 if isinstance(first_el, Image.Image):
357 # delete the image column
358 df = df.drop(columns=[col])
359 tqdm.write(
360 f"Warning: Image column '{col}' detected. Image columns are not supported in parquet files. The column has been removed."
361 )
362 # replace NaT with start of epoch
363 if df[col].dtype == "datetime64[ns]":
364 df[col] = df[col].fillna(pd.Timestamp(0))
365
366 # for float columns, replace inf with nan
367 numeric_cols = df.select_dtypes(include=[np.number])
368 df[numeric_cols.columns] = numeric_cols.map(lambda x: np.nan if np.isinf(x) else x)
369
370 return df
371
372
373# Function to recursively print help messages
374def print_help_recursively(parser, level=0):
375 # Temporarily redirect stdout to capture the help message
376 old_stdout = sys.stdout
377 sys.stdout = StringIO()
378
379 # Print the current parser's help message
380 parser.print_help()
381
382 # Retrieve and print the help message from the StringIO object
383 help_message = sys.stdout.getvalue()
384 sys.stdout = old_stdout # Restore stdout
385
386 # Print the captured help message with indentation for readability
387 print("\n" + "\t" * level + "Help message for level " + str(level) + ":")
388 for line in help_message.split("\n"):
389 print("\t" * level + line)
390
391 # Check if the current parser has subparsers
392 if hasattr(parser, "_subparsers"):
393 for _, subparser in parser._subparsers._group_actions[0].choices.items():
394 # Recursively print help for each subparser
395 print_help_recursively(subparser, level + 1)
396
397
398def is_str_uuid(id_str):
399 try:
400 uuid_obj = UUID(id_str)
401 return str(uuid_obj)
402 except ValueError:
403 return False
404
405
406def get_qdrant_id_from_id(idx):
407 if isinstance(idx, int) or idx.isdigit():
408 return int(idx)
409 elif not is_str_uuid(idx):
410 hex_string = hashlib.md5(idx.encode("UTF-8")).hexdigest()
411 return str(UUID(hex=hex_string))
412 else:
413 return str(UUID(idx))
414
415
416def read_parquet_progress(file_path, id_column, **kwargs):
417 if file_path.startswith("hf://"):
418 from huggingface_hub import HfFileSystem
419 from huggingface_hub import hf_hub_download
420
421 fs = HfFileSystem()
422 resolved_path = fs.resolve_path(file_path)
423 cache_path = hf_hub_download(
424 repo_id=resolved_path.repo_id,
425 filename=resolved_path.path_in_repo,
426 repo_type=resolved_path.repo_type,
427 )
428 file_path_to_be_read = cache_path
429 else:
430 file_path = os.path.abspath(file_path)
431 file_path_to_be_read = file_path
432 # read schema of the parquet file to check if columns are present
433 from pyarrow import parquet as pq
434
435 schema = pq.read_schema(file_path_to_be_read)
436 # list columns
437 columns = schema.names
438 # if kwargs has columns, check if all columns are present
439 cols = set()
440 cols.add(id_column)
441 return_empty = False
442 if "columns" in kwargs:
443 for col in kwargs["columns"]:
444 cols.add(col)
445 if col not in columns:
446 tqdm.write(
447 f"Column '{col}' not found in parquet file '{file_path_to_be_read}'. Returning empty DataFrame."
448 )
449 return_empty = True
450 if return_empty:
451 return pd.DataFrame(columns=list(cols))
452 with Halo(text=f"Reading parquet file {file_path_to_be_read}", spinner="dots"):
453 if (
454 "max_num_rows" in kwargs
455 and (kwargs.get("max_num_rows", INT_MAX) or INT_MAX) < INT_MAX
456 ):
457 from pyarrow.parquet import ParquetFile
458 import pyarrow as pa
459
460 pf = ParquetFile(file_path_to_be_read)
461 first_ten_rows = next(pf.iter_batches(batch_size=kwargs["max_num_rows"]))
462 df = pa.Table.from_batches([first_ten_rows]).to_pandas()
463 else:
464 df = pd.read_parquet(file_path_to_be_read)
465 tqdm.write(f"{file_path_to_be_read} read successfully. {len(df)=} rows")
466 return df
467
468
469def get_author_name():
470 return (os.environ.get("USER", os.environ.get("USERNAME"))) or "unknown"
471
472
473def clean_value(v):
474 if hasattr(v, "__iter__") and not isinstance(v, str):
475 if any(pd.isna(x) for x in v):
476 return [None if pd.isna(x) else x for x in v]
477 if isinstance(v, float) and np.isnan(v):
478 return None
479 if isinstance(v, np.datetime64) and np.isnat(v):
480 return None
481 if not hasattr(v, "__iter__") and pd.isna(v):
482 return None
483 return v
484
485
486def clean_documents(documents):
487 for doc in documents:
488 to_be_replaced = []
489 for k, v in doc.items():
490 doc[k] = clean_value(v)
491 # if k doesn't conform to CQL standards, replace it
492 # like spaces
493 if " " in k:
494 to_be_replaced.append(k)
495 for k in to_be_replaced:
496 doc[k.replace(" ", "_")] = doc.pop(k)
497
498
499def divide_into_batches(df, batch_size):
500 """
501 Divide the dataframe into batches of size batch_size
502 """
503 for i in range(0, len(df), batch_size):
504 yield df[i : i + batch_size]
505