Create LanceDB index after table is created in import (AI-Northstar-Tech/vector-io#80)

✓ Completed in 11 minutes, 2 months ago using GPT-4


Progress

Modify src/vdf_io/import_vdf/lancedb_import.py f168003
from typing import Dict, List
from dotenv import load_dotenv
import pandas as pd
from tqdm import tqdm
import pyarrow.parquet as pq

import lancedb

from vdf_io.constants import DEFAULT_BATCH_SIZE, INT_MAX
from vdf_io.meta_types import NamespaceMeta
from vdf_io.names import DBNames
from vdf_io.util import (
    cleanup_df,
    divide_into_batches,
    set_arg_from_input,
    set_arg_from_password,
)
from vdf_io.import_vdf.vdf_import_cls import ImportVDB


load_dotenv()


class ImportLanceDB(ImportVDB):
    DB_NAME_SLUG = DBNames.LANCEDB

    @classmethod
    def import_vdb(cls, args):
        """
        Import data to LanceDB
        """
        set_arg_from_input(
            args,
            "endpoint",
            "Enter the URL of LanceDB instance (default: '~/.lancedb'): ",
            str,
            default_value="~/.lancedb",
        )
        set_arg_from_password(
            args,
            "lancedb_api_key",
            "Enter the LanceDB API key (default: value of os.environ['LANCEDB_API_KEY']): ",
            "LANCEDB_API_KEY",
        )
        lancedb_import = ImportLanceDB(args)
        lancedb_import.upsert_data()
        return lancedb_import

    @classmethod
    def make_parser(cls, subparsers):
        parser_lancedb = subparsers.add_parser(
            cls.DB_NAME_SLUG, help="Import data to lancedb"
        )
        parser_lancedb.add_argument(
            "--endpoint", type=str, help="Location of LanceDB instance"
        )
        parser_lancedb.add_argument(
            "--lancedb_api_key", type=str, help="LanceDB API key"
        )
        parser_lancedb.add_argument(
            "--tables", type=str, help="LanceDB tables to export (comma-separated)"
        )

    def __init__(self, args):
        # call super class constructor
        super().__init__(args)
        self.db = lancedb.connect(
            self.args["endpoint"], api_key=self.args.get("lancedb_api_key") or None
        )

    def upsert_data(self):
        max_hit = False
        self.total_imported_count = 0
        indexes_content: Dict[str, List[NamespaceMeta]] = self.vdf_meta["indexes"]
        index_names: List[str] = list(indexes_content.keys())
        if len(index_names) == 0:
            raise ValueError("No indexes found in VDF_META.json")
        tables = self.db.table_names()
        # Load Parquet file
        # print(indexes_content[index_names[0]]):List[NamespaceMeta]
        for index_name, index_meta in tqdm(
            indexes_content.items(), desc="Importing indexes"
        ):
            for namespace_meta in tqdm(index_meta, desc="Importing namespaces"):
                self.set_dims(namespace_meta, index_name)
                data_path = namespace_meta["data_path"]
                final_data_path = self.get_final_data_path(data_path)
                # Load the data from the parquet files
                parquet_files = self.get_parquet_files(final_data_path)

                new_index_name = index_name + (
                    f'_{namespace_meta["namespace"]}'
                    if namespace_meta["namespace"]
                    else ""
                )
                new_index_name = self.create_new_name(new_index_name, tables)
                if new_index_name not in tables:
                    table = self.db.create_table(
                        new_index_name, schema=pq.read_schema(parquet_files[0])
                    )
                    tqdm.write(f"Created table {new_index_name}")
                else:
                    table = self.db.open_table(new_index_name)
                    tqdm.write(f"Opened table {new_index_name}")

                for file in tqdm(parquet_files, desc="Iterating parquet files"):
                    file_path = self.get_file_path(final_data_path, file)
                    df = self.read_parquet_progress(
                        file_path,
                        max_num_rows=(
                            (self.args.get("max_num_rows") or INT_MAX)
                            - self.total_imported_count
                        ),
                    )
                    df = cleanup_df(df)
                    # if there are additional columns in the parquet file, add them to the table
                    for col in df.columns:
                        if col not in [field.name for field in table.schema]:
                            col_type = df[col].dtype
                            tqdm.write(f"Adding column {col} of type {col_type} to {new_index_name}")
                            table.add_columns(
                                {
                                    col: get_default_value(col_type),
                                }
                            )
                    # split in batches
                    BATCH_SIZE = self.args.get("batch_size") or DEFAULT_BATCH_SIZE
                    for batch in tqdm(
                        divide_into_batches(df, BATCH_SIZE),
                        desc="Importing batches",
                        total=len(df) // BATCH_SIZE,
                    ):
                        if self.total_imported_count + len(batch) >= (
                            self.args.get("max_num_rows") or INT_MAX
                        ):
                            batch = batch[
                                : (self.args.get("max_num_rows") or INT_MAX)
                                - self.total_imported_count
                            ]
                            max_hit = True
                        # convert df into list of dicts
                        table.add(batch)
                        self.total_imported_count += len(batch)
                        if max_hit:
                            break
                tqdm.write(f"Imported {self.total_imported_count} rows")
                tqdm.write(f"New table size: {table.count_rows()}")
                if max_hit:
                    break
        print("Data imported successfully")


def get_default_value(data_type):
    # Define default values for common data types
    default_values = {
        "object": "",
        "int64": 0,
        "float64": 0.0,
        "bool": False,
        "datetime64[ns]": pd.Timestamp("NaT"),
        "timedelta64[ns]": pd.Timedelta("NaT"),
    }
    # Return the default value for the specified data type, or None if not specified
    return default_values.get(data_type.name, None)
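For a quick, standalone illustration of the get_default_value mapping above (the example Series values are hypothetical; the outputs simply follow the dtype table in the function, assuming get_default_value from this module is in scope):

import pandas as pd

# int64 -> 0, float64 -> 0.0, object -> "", bool -> False (per the mapping above)
print(get_default_value(pd.Series([1, 2, 3]).dtype))   # 0
print(get_default_value(pd.Series([1.5]).dtype))       # 0.0
print(get_default_value(pd.Series(["a", "b"]).dtype))  # ""
print(get_default_value(pd.Series([True]).dtype))      # False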

LanceDB exposes index creation as methods on the table object (for example table.create_scalar_index), so no additional lancedb import is needed. The index-creation snippet below does use the ID_COLUMN constant, so extend the constants import at the top of the file:

from vdf_io.constants import DEFAULT_BATCH_SIZE, ID_COLUMN, INT_MAX
Modify src/vdf_io/import_vdf/lancedb_import.py f168003

In the upsert_data method of the ImportLanceDB class, after the code block that creates a new table or opens an existing one, add the following to create a scalar index on the table's ID column:

# Determine the ID column from the parquet file schema
parquet_schema = pq.read_schema(parquet_files[0])
id_column = "id"  # default if ID_COLUMN is not present
for field in parquet_schema:
    if field.name == ID_COLUMN:
        id_column = field.name
        break

# Create a scalar index on the ID column
table.create_scalar_index(id_column)
tqdm.write(f"Created index on {id_column} for table {new_index_name}")

This code reads the schema of the first parquet file to determine the name of the ID column, falling back to "id" if the ID_COLUMN constant is not found among the fields. It then calls table.create_scalar_index with that column name to build the index.
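A scalar index covers exact lookups on the ID column; a vector (ANN) index is a separate table.create_index call, normally made only after the rows have been imported, because the index is trained on the data already in the table. A hedged sketch of that optional follow-up, placed after the parquet-file loop (the vector column name "vector", the metric and the partition/sub-vector counts are illustrative assumptions, not values taken from the repository):

# Optional follow-up: build an ANN index once the data has been imported.
# The column name, metric and sizing parameters below are assumptions.
try:
    table.create_index(
        metric="cosine",
        vector_column_name="vector",
        num_partitions=256,
        num_sub_vectors=96,
    )
    tqdm.write(f"Created ANN index for table {new_index_name}")
except Exception as e:
    tqdm.write(f"Skipped ANN index for {new_index_name}: {e}")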

Plan

This is based on the results of the Planning step. The plan may expand from failed GitHub Actions runs.

Code Snippets Found

This is based on the results of the Searching step.

src/vdf_io/import_vdf/lancedb_import.py:0-165 (file shown in full above)
src/vdf_io/util.py:0-505 
1from pathlib import Path
2from collections import OrderedDict
3from getpass import getpass
4import hashlib
5import json
6import os
7import time
8from uuid import UUID
9import numpy as np
10import pandas as pd
11from io import StringIO
12import sys
13from tqdm import tqdm
14from PIL import Image
15from halo import Halo
16
17from qdrant_client.http.models import Distance
18
19from vdf_io.constants import ID_COLUMN, INT_MAX
20from vdf_io.names import DBNames
21
22
23def sort_recursive(d):
24    """
25    Recursively sort the nested dictionary by its keys.
26    """
27    # if isinstance(d, list):
28    #     return [sort_recursive(v) for v in d]
29    # if isinstance(d, tuple):
30    #     return tuple(sort_recursive(v) for v in d)
31    # if isinstance(d, set):
32    #     return list({sort_recursive(v) for v in d}).sort()
33    if (
34        isinstance(d, str)
35        or isinstance(d, int)
36        or isinstance(d, float)
37        or isinstance(d, bool)
38        or d is None
39        or isinstance(d, OrderedDict)
40    ):
41        return d
42    if hasattr(d, "attribute_map"):
43        return sort_recursive(d.attribute_map)
44    if not isinstance(d, dict):
45        try:
46            d = dict(d)
47        except Exception:
48            d = {"": str(d)}
49
50    sorted_dict = OrderedDict()
51    for key, value in sorted(d.items()):
52        sorted_dict[key] = sort_recursive(value)
53
54    return sorted_dict
55
56
57def convert_to_consistent_value(d):
58    """
59    Convert a nested dictionary to a consistent string regardless of key order.
60    """
61    sorted_dict = sort_recursive(d)
62    return json.dumps(sorted_dict, sort_keys=True)
63
64
65def extract_data_hash(arg_dict_combined):
66    arg_dict_combined_copy = arg_dict_combined.copy()
67    data_hash = hashlib.md5(
68        convert_to_consistent_value(arg_dict_combined_copy).encode("utf-8")
69    )
70    # make it 5 characters long
71    data_hash = data_hash.hexdigest()[:5]
72    return data_hash
73
74
75def extract_numerical_hash(string_value):
76    """
77    Extract a numerical hash from a string
78    """
79    return int(hashlib.md5(string_value.encode("utf-8")).hexdigest(), 16)
80
81
82def set_arg_from_input(
83    args,
84    arg_name,
85    prompt,
86    type_name=str,
87    default_value=None,
88    choices=None,
89    env_var=None,
90):
91    """
92    Set the value of an argument from user input if it is not already present
93    """
94    if (
95        (default_value is None)
96        and (env_var is not None)
97        and (os.getenv(env_var) is not None)
98    ):
99        default_value = os.getenv(env_var)
100    if arg_name not in args or (
101        args[arg_name] is None and default_value != "DO_NOT_PROMPT"
102    ):
103        while True:
104            inp = input(
105                prompt
106                + (" " + str(list(choices)) + ": " if choices is not None else "")
107            )
108            if len(inp) >= 2:
109                if inp[0] == '"' and inp[-1] == '"':
110                    inp = inp[1:-1]
111                elif inp[0] == "'" and inp[-1] == "'":
112                    inp = inp[1:-1]
113            if inp == "":
114                args[arg_name] = (
115                    None if default_value is None else type_name(default_value)
116                )
117                break
118            elif choices is not None and not all(
119                choice in choices for choice in inp.split(",")
120            ):
121                print(f"Invalid input. Please choose from {choices}")
122                continue
123            else:
124                args[arg_name] = type_name(inp)
125                break
126    return
127
128
129def set_arg_from_password(args, arg_name, prompt, env_var_name):
130    """
131    Set the value of an argument from user input if it is not already present
132    """
133    if os.getenv(env_var_name) is not None:
134        args[arg_name] = os.getenv(env_var_name)
135    elif arg_name not in args or args[arg_name] is None:
136        args[arg_name] = getpass(prompt)
137    return
138
139
140def expand_shorthand_path(shorthand_path):
141    """
142    Expand shorthand notations in a file path to a full path-like object.
143
144    :param shorthand_path: A string representing the shorthand path.
145    :return: A Path object representing the full path.
146    """
147    if shorthand_path is None:
148        return None
149    # Expand '~' to the user's home directory
150    expanded_path = os.path.expanduser(shorthand_path)
151
152    # Resolve '.' and '..' to get the absolute path
153    full_path = Path(expanded_path).resolve()
154
155    return str(full_path)
156
157
158db_metric_to_standard_metric = {
159    DBNames.PINECONE: {
160        "cosine": Distance.COSINE,
161        "euclidean": Distance.EUCLID,
162        "dotproduct": Distance.DOT,
163    },
164    DBNames.QDRANT: {
165        Distance.COSINE: Distance.COSINE,
166        Distance.EUCLID: Distance.EUCLID,
167        Distance.DOT: Distance.DOT,
168        Distance.MANHATTAN: Distance.MANHATTAN,
169    },
170    DBNames.MILVUS: {
171        "COSINE": Distance.COSINE,
172        "IP": Distance.DOT,
173        "L2": Distance.EUCLID,
174    },
175    DBNames.KDBAI: {
176        "L2": Distance.EUCLID,
177        "CS": Distance.COSINE,
178        "IP": Distance.DOT,
179    },
180    DBNames.VERTEXAI: {
181        "DOT_PRODUCT_DISTANCE": Distance.DOT,
182        "SQUARED_L2_DISTANCE": Distance.EUCLID,
183        "COSINE_DISTANCE": Distance.COSINE,
184        "L1_DISTANCE": Distance.MANHATTAN,
185    },
186    DBNames.LANCEDB: {
187        "L2": Distance.EUCLID,
188        "Cosine": Distance.COSINE,
189        "Dot": Distance.DOT,
190    },
191    DBNames.CHROMA: {
192        "l2": Distance.EUCLID,
193        "cosine": Distance.COSINE,
194        "ip": Distance.DOT,
195    },
196    DBNames.ASTRADB: {
197        "cosine": Distance.COSINE,
198        "euclidean": Distance.EUCLID,
199        "dot_product": Distance.DOT,
200    },
201    DBNames.WEAVIATE: {
202        "cosine": Distance.COSINE,
203        "l2-squared": Distance.EUCLID,
204        "dot": Distance.DOT,
205        "manhattan": Distance.MANHATTAN,
206    },
207    DBNames.VESPA: {
208        "angular": Distance.COSINE,
209        "euclidean": Distance.EUCLID,
210        "dotproduct": Distance.DOT,
211    },
212}
213
214
215def standardize_metric(metric, db):
216    """
217    Standardize the metric name to the one used in the standard library.
218    """
219    if (
220        db in db_metric_to_standard_metric
221        and metric in db_metric_to_standard_metric[db]
222    ):
223        return db_metric_to_standard_metric[db][metric]
224    else:
225        raise Exception(f"Invalid metric '{metric}' for database '{db}'")
226
227
228def standardize_metric_reverse(metric, db):
229    """
230    Standardize the metric name to the one used in the standard library.
231    """
232    if (
233        db in db_metric_to_standard_metric
234        and metric in db_metric_to_standard_metric[db].values()
235    ):
236        for key, value in db_metric_to_standard_metric[db].items():
237            if value == metric:
238                return key
239    else:
240        tqdm.write(f"Invalid metric '{metric}' for database '{db}'. Using cosine")
241        return standardize_metric_reverse(Distance.COSINE, db)
242
243
244def get_final_data_path(cwd, dir, data_path, args):
245    if args.get("hf_dataset", None):
246        return data_path
247    final_data_path = os.path.join(cwd, dir, data_path)
248    if not os.path.isdir(final_data_path):
249        raise Exception(
250            f"Invalid data path\n"
251            f"data_path: {data_path},\n"
252            f"Joined path: {final_data_path}\n"
253            f"Current working directory: {cwd}\n"
254            f"Command line arg (dir): {dir}"
255        )
256    return final_data_path
257
258
259def list_configs_and_splits(name):
260    if "HUGGING_FACE_TOKEN" not in os.environ:
261        yield "train", None
262    import requests
263
264    headers = {"Authorization": f"Bearer {os.environ['HUGGING_FACE_TOKEN']}"}
265    API_URL = f"https://datasets-server.huggingface.co/splits?dataset={name}"
266
267    def query():
268        response = requests.get(API_URL, headers=headers)
269        return response.json()
270
271    data = query()
272    if "splits" in data:
273        for split in data["splits"]:
274            if "config" in split:
275                yield split["split"], split["config"]
276            else:
277                yield split["split"], None
278    else:
279        yield "train", None
280
281
282def get_parquet_files(data_path, args, temp_file_paths=[], id_column=ID_COLUMN):
283    # Load the data from the parquet files
284    if args.get("hf_dataset", None):
285        if args.get("max_num_rows", None):
286            from datasets import load_dataset
287
288            total_rows_loaded = 0
289            for i, (split, config) in enumerate(
290                list_configs_and_splits(args.get("hf_dataset"))
291            ):
292                tqdm.write(f"Split: {split}, Config: {config}")
293                ds = load_dataset(
294                    args.get("hf_dataset"), name=config, split=split, streaming=True
295                )
296                with Halo(text="Taking a subset of the dataset", spinner="dots"):
297                    it_ds = ds.take(args.get("max_num_rows") - total_rows_loaded)
298                start_time = time.time()
299                with Halo(
300                    text="Converting to pandas dataframe (this may take a while)",
301                    spinner="dots",
302                ):
303                    df = pd.DataFrame(it_ds)
304                end_time = time.time()
305                tqdm.write(
306                    f"Time taken to convert to pandas dataframe: {end_time - start_time:.2f} seconds"
307                )
308                df = cleanup_df(df)
309                if id_column not in df.columns:
310                    # remove all rows
311                    tqdm.write(
312                        (
313                            f"ID column '{id_column}' not found in parquet file '{data_path}'."
314                            f" Skipping split '{split}', config '{config}'."
315                        )
316                    )
317                    continue
318                total_rows_loaded += len(df)
319                temp_file_path = f"{os.getcwd()}/temp_{args['hash_value']}_{i}.parquet"
320                with Halo(text="Saving to parquet", spinner="dots"):
321                    df.to_parquet(temp_file_path)
322                temp_file_paths.append(temp_file_path)
323                if total_rows_loaded >= args.get("max_num_rows"):
324                    break
325            return temp_file_paths
326        from huggingface_hub import HfFileSystem
327
328        fs = HfFileSystem()
329        return [
330            "hf://" + x
331            for x in fs.glob(
332                f"datasets/{args.get('hf_dataset')}/{data_path if data_path!='.' else ''}/**.parquet"
333            )
334        ]
335    if not os.path.isdir(data_path):
336        if data_path.endswith(".parquet"):
337            return [data_path]
338        else:
339            raise Exception(f"Invalid data path '{data_path}'")
340    else:
341        # recursively find all parquet files (it should be a file acc to OS)
342        parquet_files = []
343        for root, _, files in os.walk(data_path):
344            for file in files:
345                if file.endswith(".parquet"):
346                    parquet_files.append(os.path.join(root, file))
347        return parquet_files
348
349
350def cleanup_df(df):
351    for col in df.columns:
352        if df[col].dtype == "object":
353            first_el = df[col].iloc[0]
354            # if isinstance(first_el, bytes):
355            #     df[col] = df[col].apply(lambda x: x.decode("utf-8"))
356            if isinstance(first_el, Image.Image):
357                # delete the image column
358                df = df.drop(columns=[col])
359                tqdm.write(
360                    f"Warning: Image column '{col}' detected. Image columns are not supported in parquet files. The column has been removed."
361                )
362        # replace NaT with start of epoch
363        if df[col].dtype == "datetime64[ns]":
364            df[col] = df[col].fillna(pd.Timestamp(0))
365
366    # for float columns, replace inf with nan
367    numeric_cols = df.select_dtypes(include=[np.number])
368    df[numeric_cols.columns] = numeric_cols.map(lambda x: np.nan if np.isinf(x) else x)
369
370    return df
371
372
373# Function to recursively print help messages
374def print_help_recursively(parser, level=0):
375    # Temporarily redirect stdout to capture the help message
376    old_stdout = sys.stdout
377    sys.stdout = StringIO()
378
379    # Print the current parser's help message
380    parser.print_help()
381
382    # Retrieve and print the help message from the StringIO object
383    help_message = sys.stdout.getvalue()
384    sys.stdout = old_stdout  # Restore stdout
385
386    # Print the captured help message with indentation for readability
387    print("\n" + "\t" * level + "Help message for level " + str(level) + ":")
388    for line in help_message.split("\n"):
389        print("\t" * level + line)
390
391    # Check if the current parser has subparsers
392    if hasattr(parser, "_subparsers"):
393        for _, subparser in parser._subparsers._group_actions[0].choices.items():
394            # Recursively print help for each subparser
395            print_help_recursively(subparser, level + 1)
396
397
398def is_str_uuid(id_str):
399    try:
400        uuid_obj = UUID(id_str)
401        return str(uuid_obj)
402    except ValueError:
403        return False
404
405
406def get_qdrant_id_from_id(idx):
407    if isinstance(idx, int) or idx.isdigit():
408        return int(idx)
409    elif not is_str_uuid(idx):
410        hex_string = hashlib.md5(idx.encode("UTF-8")).hexdigest()
411        return str(UUID(hex=hex_string))
412    else:
413        return str(UUID(idx))
414
415
416def read_parquet_progress(file_path, id_column, **kwargs):
417    if file_path.startswith("hf://"):
418        from huggingface_hub import HfFileSystem
419        from huggingface_hub import hf_hub_download
420
421        fs = HfFileSystem()
422        resolved_path = fs.resolve_path(file_path)
423        cache_path = hf_hub_download(
424            repo_id=resolved_path.repo_id,
425            filename=resolved_path.path_in_repo,
426            repo_type=resolved_path.repo_type,
427        )
428        file_path_to_be_read = cache_path
429    else:
430        file_path = os.path.abspath(file_path)
431        file_path_to_be_read = file_path
432    # read schema of the parquet file to check if columns are present
433    from pyarrow import parquet as pq
434
435    schema = pq.read_schema(file_path_to_be_read)
436    # list columns
437    columns = schema.names
438    # if kwargs has columns, check if all columns are present
439    cols = set()
440    cols.add(id_column)
441    return_empty = False
442    if "columns" in kwargs:
443        for col in kwargs["columns"]:
444            cols.add(col)
445            if col not in columns:
446                tqdm.write(
447                    f"Column '{col}' not found in parquet file '{file_path_to_be_read}'. Returning empty DataFrame."
448                )
449                return_empty = True
450    if return_empty:
451        return pd.DataFrame(columns=list(cols))
452    with Halo(text=f"Reading parquet file {file_path_to_be_read}", spinner="dots"):
453        if (
454            "max_num_rows" in kwargs
455            and (kwargs.get("max_num_rows", INT_MAX) or INT_MAX) < INT_MAX
456        ):
457            from pyarrow.parquet import ParquetFile
458            import pyarrow as pa
459
460            pf = ParquetFile(file_path_to_be_read)
461            first_ten_rows = next(pf.iter_batches(batch_size=kwargs["max_num_rows"]))
462            df = pa.Table.from_batches([first_ten_rows]).to_pandas()
463        else:
464            df = pd.read_parquet(file_path_to_be_read)
465    tqdm.write(f"{file_path_to_be_read} read successfully. {len(df)=} rows")
466    return df
467
468
469def get_author_name():
470    return (os.environ.get("USER", os.environ.get("USERNAME"))) or "unknown"
471
472
473def clean_value(v):
474    if hasattr(v, "__iter__") and not isinstance(v, str):
475        if any(pd.isna(x) for x in v):
476            return [None if pd.isna(x) else x for x in v]
477    if isinstance(v, float) and np.isnan(v):
478        return None
479    if isinstance(v, np.datetime64) and np.isnat(v):
480        return None
481    if not hasattr(v, "__iter__") and pd.isna(v):
482        return None
483    return v
484
485
486def clean_documents(documents):
487    for doc in documents:
488        to_be_replaced = []
489        for k, v in doc.items():
490            doc[k] = clean_value(v)
491            # if k doesn't conform to CQL standards, replace it
492            # like spaces
493            if " " in k:
494                to_be_replaced.append(k)
495        for k in to_be_replaced:
496            doc[k.replace(" ", "_")] = doc.pop(k)
497
498
499def divide_into_batches(df, batch_size):
500    """
501    Divide the dataframe into batches of size batch_size
502    """
503    for i in range(0, len(df), batch_size):
504        yield df[i : i + batch_size]
505
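As a usage note, divide_into_batches is the generator the LanceDB importer iterates over; a minimal standalone sketch with a toy DataFrame (the column names here are hypothetical):

import pandas as pd

df = pd.DataFrame({"id": range(10), "vector": [[0.0, 1.0]] * 10})
for batch in divide_into_batches(df, batch_size=4):
    print(len(batch))  # 4, 4, 2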