diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index 8c8a9b6f..65b85566 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -29,6 +29,7 @@ def _dict_to_params(collections_dict: dict, write_params: dict): "enable_text": False, "enable_charts": False, "enable_tables": False, + "enable_images": False, } if not isinstance(data_type, list): data_type = [data_type] @@ -53,6 +54,7 @@ def __init__( enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, + enable_images: bool = True, bm25_save_path: str = "bm25_model.json", compute_bm25_stats: bool = True, access_key: str = "minioadmin", @@ -305,7 +307,7 @@ def _record_dict(text, element, sparse_vector: csr_array = None): return record -def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: bool): +def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: bool, enable_images: bool): text = None if element["document_type"] == "text" and enable_text: text = element["metadata"]["content"] @@ -315,6 +317,8 @@ def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: b text = None elif element["metadata"]["content_metadata"]["subtype"] == "table" and not enable_tables: text = None + elif element["document_type"] == "image" and enable_images: + text = element["metadata"]["image_metadata"]["caption"] return text @@ -325,6 +329,7 @@ def write_records_minio( enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, + enable_images: bool = True, record_func=_record_dict, ) -> RemoteBulkWriter: """ @@ -349,6 +354,8 @@ def write_records_minio( When true, ensure all chart type records are used. enable_tables : bool, optional When true, ensure all table type records are used. + enable_images : bool, optional + When true, ensure all image type records are used. record_func : function, optional This function will be used to parse the records for necessary information. @@ -359,7 +366,7 @@ def write_records_minio( """ for result in records: for element in result: - text = _pull_text(element, enable_text, enable_charts, enable_tables) + text = _pull_text(element, enable_text, enable_charts, enable_tables, enable_images) if text: if sparse_model is not None: writer.append_row(record_func(text, element, sparse_model.encode_documents([text]))) @@ -407,7 +414,11 @@ def bulk_insert_milvus(collection_name: str, writer: RemoteBulkWriter, milvus_ur def create_bm25_model( - records, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True + records, + enable_text: bool = True, + enable_charts: bool = True, + enable_tables: bool = True, + enable_images: bool = True, ) -> BM25EmbeddingFunction: """ This function takes the input records and creates a corpus, @@ -424,6 +435,8 @@ def create_bm25_model( When true, ensure all chart type records are used. enable_tables : bool, optional When true, ensure all table type records are used. + enable_images : bool, optional + When true, ensure all image type records are used. Returns ------- @@ -433,7 +446,7 @@ def create_bm25_model( all_text = [] for result in records: for element in result: - text = _pull_text(element, enable_text, enable_charts, enable_tables) + text = _pull_text(element, enable_text, enable_charts, enable_tables, enable_images) if text: all_text.append(text) @@ -452,6 +465,7 @@ def stream_insert_milvus( enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, + enable_images: bool = True, record_func=_record_dict, ): """ @@ -474,6 +488,8 @@ def stream_insert_milvus( When true, ensure all chart type records are used. enable_tables : bool, optional When true, ensure all table type records are used. + enable_images : bool, optional + When true, ensure all image type records are used. record_func : function, optional This function will be used to parse the records for necessary information. @@ -481,7 +497,7 @@ def stream_insert_milvus( data = [] for result in records: for element in result: - text = _pull_text(element, enable_text, enable_charts, enable_tables) + text = _pull_text(element, enable_text, enable_charts, enable_tables, enable_images) if text: if sparse_model is not None: data.append(record_func(text, element, sparse_model.encode_documents([text]))) @@ -499,6 +515,7 @@ def write_to_nvingest_collection( enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, + enable_images: bool = True, bm25_save_path: str = "bm25_model.json", compute_bm25_stats: bool = True, access_key: str = "minioadmin", @@ -527,6 +544,8 @@ def write_to_nvingest_collection( When true, ensure all chart type records are used. enable_tables : bool, optional When true, ensure all table type records are used. + enable_images : bool, optional + When true, ensure all image type records are used. sparse : bool, optional When true, incorporates sparse embedding representations for records. bm25_save_path : str, optional @@ -549,7 +568,11 @@ def write_to_nvingest_collection( bm25_ef = None if sparse and compute_bm25_stats: bm25_ef = create_bm25_model( - records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables + records, + enable_text=enable_text, + enable_charts=enable_charts, + enable_tables=enable_tables, + enable_images=enable_images, ) bm25_ef.save(bm25_save_path) elif sparse and not compute_bm25_stats: @@ -566,6 +589,7 @@ def write_to_nvingest_collection( enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables, + enable_images=enable_images, ) else: # Connections parameters to access the remote bucket @@ -586,6 +610,7 @@ def write_to_nvingest_collection( enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables, + enable_images=enable_images, ) bulk_insert_milvus(collection_name, writer, milvus_uri) # this sleep is required, to ensure atleast this amount of time