Skip to content

Commit

Permalink
Add image types to ingestor vdb upload (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv authored Jan 23, 2025
1 parent e3d08fd commit de6983c
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions client/src/nv_ingest_client/util/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def _dict_to_params(collections_dict: dict, write_params: dict):
"enable_text": False,
"enable_charts": False,
"enable_tables": False,
"enable_images": False,
}
if not isinstance(data_type, list):
data_type = [data_type]
Expand All @@ -53,6 +54,7 @@ def __init__(
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
Expand Down Expand Up @@ -305,7 +307,7 @@ def _record_dict(text, element, sparse_vector: csr_array = None):
return record


def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: bool):
def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: bool, enable_images: bool):
text = None
if element["document_type"] == "text" and enable_text:
text = element["metadata"]["content"]
Expand All @@ -315,6 +317,8 @@ def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: b
text = None
elif element["metadata"]["content_metadata"]["subtype"] == "table" and not enable_tables:
text = None
elif element["document_type"] == "image" and enable_images:
text = element["metadata"]["image_metadata"]["caption"]
return text


Expand All @@ -325,6 +329,7 @@ def write_records_minio(
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
record_func=_record_dict,
) -> RemoteBulkWriter:
"""
Expand All @@ -349,6 +354,8 @@ def write_records_minio(
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.
enable_images : bool, optional
When true, ensure all image type records are used.
record_func : function, optional
This function will be used to parse the records for necessary information.
Expand All @@ -359,7 +366,7 @@ def write_records_minio(
"""
for result in records:
for element in result:
text = _pull_text(element, enable_text, enable_charts, enable_tables)
text = _pull_text(element, enable_text, enable_charts, enable_tables, enable_images)
if text:
if sparse_model is not None:
writer.append_row(record_func(text, element, sparse_model.encode_documents([text])))
Expand Down Expand Up @@ -407,7 +414,11 @@ def bulk_insert_milvus(collection_name: str, writer: RemoteBulkWriter, milvus_ur


def create_bm25_model(
records, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True
records,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
) -> BM25EmbeddingFunction:
"""
This function takes the input records and creates a corpus,
Expand All @@ -424,6 +435,8 @@ def create_bm25_model(
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.
enable_images : bool, optional
When true, ensure all image type records are used.
Returns
-------
Expand All @@ -433,7 +446,7 @@ def create_bm25_model(
all_text = []
for result in records:
for element in result:
text = _pull_text(element, enable_text, enable_charts, enable_tables)
text = _pull_text(element, enable_text, enable_charts, enable_tables, enable_images)
if text:
all_text.append(text)

Expand All @@ -452,6 +465,7 @@ def stream_insert_milvus(
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
record_func=_record_dict,
):
"""
Expand All @@ -474,14 +488,16 @@ def stream_insert_milvus(
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.
enable_images : bool, optional
When true, ensure all image type records are used.
record_func : function, optional
This function will be used to parse the records for necessary information.
"""
data = []
for result in records:
for element in result:
text = _pull_text(element, enable_text, enable_charts, enable_tables)
text = _pull_text(element, enable_text, enable_charts, enable_tables, enable_images)
if text:
if sparse_model is not None:
data.append(record_func(text, element, sparse_model.encode_documents([text])))
Expand All @@ -499,6 +515,7 @@ def write_to_nvingest_collection(
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
enable_images: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
Expand Down Expand Up @@ -527,6 +544,8 @@ def write_to_nvingest_collection(
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.
enable_images : bool, optional
When true, ensure all image type records are used.
sparse : bool, optional
When true, incorporates sparse embedding representations for records.
bm25_save_path : str, optional
Expand All @@ -549,7 +568,11 @@ def write_to_nvingest_collection(
bm25_ef = None
if sparse and compute_bm25_stats:
bm25_ef = create_bm25_model(
records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables
records,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
)
bm25_ef.save(bm25_save_path)
elif sparse and not compute_bm25_stats:
Expand All @@ -566,6 +589,7 @@ def write_to_nvingest_collection(
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
)
else:
# Connections parameters to access the remote bucket
Expand All @@ -586,6 +610,7 @@ def write_to_nvingest_collection(
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
enable_images=enable_images,
)
bulk_insert_milvus(collection_name, writer, milvus_uri)
# this sleep is required, to ensure atleast this amount of time
Expand Down

0 comments on commit de6983c

Please sign in to comment.