From e977f2a70c9e83a38241533c7375e14749401b45 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Wed, 15 Jan 2025 11:52:59 -0500 Subject: [PATCH 1/9] add milvus logic to ingestor as psuedo task --- .../src/nv_ingest_client/client/interface.py | 15 +++- client/src/nv_ingest_client/util/milvus.py | 68 +++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/client/src/nv_ingest_client/client/interface.py b/client/src/nv_ingest_client/client/interface.py index e5746651..7f64104f 100644 --- a/client/src/nv_ingest_client/client/interface.py +++ b/client/src/nv_ingest_client/client/interface.py @@ -29,6 +29,7 @@ from nv_ingest_client.primitives.tasks import StoreTask from nv_ingest_client.primitives.tasks import VdbUploadTask from nv_ingest_client.util.util import filter_function_kwargs +from nv_ingest_client.util.milvus import MilvusOperator DEFAULT_JOB_QUEUE_ID = "morpheus_task_queue" @@ -223,7 +224,10 @@ def ingest(self, **kwargs: Any) -> List[Dict[str, Any]]: fetch_kwargs = filter_function_kwargs(self._client.fetch_job_result, **kwargs) result = self._client.fetch_job_result(self._job_ids, **fetch_kwargs) - + if self._vdb_bulk_upload: + self._vdb_bulk_upload.run(result) + # only upload as part of jobs user specified this action + self._vdb_bulk_upload = None return result def ingest_async(self, **kwargs: Any) -> Future: @@ -271,6 +275,11 @@ def _done_callback(future): for future in future_to_job_id: future.add_done_callback(_done_callback) + if self._vdb_bulk_upload: + self._vdb_bulk_upload.run(combined_future) + # only upload as part of jobs user specified this action + self._vdb_bulk_upload = None + return combined_future @ensure_job_specs @@ -494,6 +503,10 @@ def caption(self, **kwargs: Any) -> "Ingestor": return self + def vdb_bulk_upload(self, **kwargs): + self._vdb_bulk_upload = MilvusOperator(**kwargs) + return self + def _count_job_states(self, job_states: set[JobStateEnum]) -> int: """ Counts the jobs in specified states. diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index 744bef63..b303bccb 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -18,6 +18,74 @@ from typing import List import time from urllib.parse import urlparse +from typing import Union, Dict + + +class MilvusOperator: + def __init__( + self, + collection_name: Union[str, Dict] = "nv_ingest_collection", + milvus_uri: str = "http://localhost:19530", + sparse: bool = False, + recreate: bool = True, + gpu_index: bool = True, + gpu_search: bool = False, + dense_dim: int = 1024, + minio_endpoint: str = "localhost:9000", + enable_text: bool = True, + enable_charts: bool = True, + enable_tables: bool = True, + bm25_save_path: str = "bm25_model.json", + access_key: str = "minioadmin", + secret_key: str = "minioadmin", + bucket_name: str = "a-bucket", + ): + self.milvus_kwargs = locals() + self.collection_name = self.milvus_kwargs.pop("collection_name") + self.milvus_kwargs.pop("self") + + def get_connection_params(self): + conn_dict = { + "milvus_uri": self.milvus_kwargs["milvus_uri"], + "sparse": self.milvus_kwargs["sparse"], + "recreate": self.milvus_kwargs["recreate"], + "gpu_index": self.milvus_kwargs["gpu_index"], + "gpu_search": self.milvus_kwargs["gpu_search"], + "dense_dim": self.milvus_kwargs["dense_dim"], + } + return (self.collection_name, conn_dict) + + def get_write_params(self): + write_params = self.milvus_kwargs.copy() + del write_params["recreate"] + del write_params["gpu_index"] + del write_params["gpu_search"] + del write_params["dense_dim"] + + return (self.collection_name, write_params) + + def run(self, records): + collection_name, create_params = self.get_connection_params() + _, write_params = self.get_write_params() + if isinstance(collection_name, str): + create_nvingest_collection(collection_name, **create_params) + write_to_nvingest_collection(records, collection_name, **write_params) + elif isinstance(collection_name, dict): + for coll_name, data_type in collection_name.items(): + create_nvingest_collection(collection_name, **create_params) + enabled_dtypes = { + "enable_text": False, + "enable_charts": False, + "enable_tables": False, + } + if not isinstance(data_type, list): + data_type = [data_type] + for d_type in data_type: + enabled_dtypes[f"enable_{d_type}"] = True + write_params.update(enabled_dtypes) + write_to_nvingest_collection(records, collection_name, **write_params) + else: + raise ValueError(f"Unsupported type for collection_name detected: {type(collection_name)}") def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> CollectionSchema: From afc237c22425da42bf410f1151c09d7ebeaa0c07 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Wed, 15 Jan 2025 13:01:36 -0500 Subject: [PATCH 2/9] move dict split logic out of class --- client/src/nv_ingest_client/util/milvus.py | 36 ++++++++++++++-------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index b303bccb..f2c506f2 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -21,6 +21,24 @@ from typing import Union, Dict +def _dict_to_params(collections_dict: dict, write_params: dict): + params_tuple_list = [] + for coll_name, data_type in collections_dict.items(): + cp_write_params = write_params.copy() + enabled_dtypes = { + "enable_text": False, + "enable_charts": False, + "enable_tables": False, + } + if not isinstance(data_type, list): + data_type = [data_type] + for d_type in data_type: + enabled_dtypes[f"enable_{d_type}"] = True + cp_write_params.update(enabled_dtypes) + params_tuple_list.append((coll_name, cp_write_params)) + return params_tuple_list + + class MilvusOperator: def __init__( self, @@ -71,19 +89,11 @@ def run(self, records): create_nvingest_collection(collection_name, **create_params) write_to_nvingest_collection(records, collection_name, **write_params) elif isinstance(collection_name, dict): - for coll_name, data_type in collection_name.items(): - create_nvingest_collection(collection_name, **create_params) - enabled_dtypes = { - "enable_text": False, - "enable_charts": False, - "enable_tables": False, - } - if not isinstance(data_type, list): - data_type = [data_type] - for d_type in data_type: - enabled_dtypes[f"enable_{d_type}"] = True - write_params.update(enabled_dtypes) - write_to_nvingest_collection(records, collection_name, **write_params) + split_params_list = _dict_to_params(collection_name, write_params) + for sub_params in split_params_list: + coll_name, sub_write_params = sub_params + create_nvingest_collection(coll_name, **create_params) + write_to_nvingest_collection(records, coll_name, **sub_write_params) else: raise ValueError(f"Unsupported type for collection_name detected: {type(collection_name)}") From 88bddb87445b2e1830259e48af3cbda1e84194d6 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Wed, 15 Jan 2025 23:49:29 -0500 Subject: [PATCH 3/9] adding test cases and swap out vdb_upload logic --- .../src/nv_ingest_client/client/interface.py | 9 +-- client/src/nv_ingest_client/util/milvus.py | 15 +++-- .../nv_ingest_client/util/test_milvus_util.py | 61 +++++++++++++++++++ 3 files changed, 72 insertions(+), 13 deletions(-) create mode 100644 tests/nv_ingest_client/util/test_milvus_util.py diff --git a/client/src/nv_ingest_client/client/interface.py b/client/src/nv_ingest_client/client/interface.py index 7f64104f..8591c806 100644 --- a/client/src/nv_ingest_client/client/interface.py +++ b/client/src/nv_ingest_client/client/interface.py @@ -27,7 +27,6 @@ from nv_ingest_client.primitives.tasks import SplitTask from nv_ingest_client.primitives.tasks import StoreEmbedTask from nv_ingest_client.primitives.tasks import StoreTask -from nv_ingest_client.primitives.tasks import VdbUploadTask from nv_ingest_client.util.util import filter_function_kwargs from nv_ingest_client.util.milvus import MilvusOperator @@ -463,7 +462,6 @@ def store_embed(self, **kwargs: Any) -> "Ingestor": return self - @ensure_job_specs def vdb_upload(self, **kwargs: Any) -> "Ingestor": """ Adds a VdbUploadTask to the batch job specification. @@ -478,8 +476,7 @@ def vdb_upload(self, **kwargs: Any) -> "Ingestor": Ingestor Returns self for chaining. """ - vdb_upload_task = VdbUploadTask(**kwargs) - self._job_specs.add_task(vdb_upload_task) + self._vdb_bulk_upload = MilvusOperator(**kwargs) return self @@ -503,10 +500,6 @@ def caption(self, **kwargs: Any) -> "Ingestor": return self - def vdb_bulk_upload(self, **kwargs): - self._vdb_bulk_upload = MilvusOperator(**kwargs) - return self - def _count_job_states(self, job_states: set[JobStateEnum]) -> int: """ Counts the jobs in specified states. diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index f2c506f2..d5aeaeb6 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -44,7 +44,7 @@ def __init__( self, collection_name: Union[str, Dict] = "nv_ingest_collection", milvus_uri: str = "http://localhost:19530", - sparse: bool = False, + sparse: bool = True, recreate: bool = True, gpu_index: bool = True, gpu_search: bool = False, @@ -54,6 +54,7 @@ def __init__( enable_charts: bool = True, enable_tables: bool = True, bm25_save_path: str = "bm25_model.json", + compute_bm25_stats: bool = True, access_key: str = "minioadmin", secret_key: str = "minioadmin", bucket_name: str = "a-bucket", @@ -492,11 +493,12 @@ def write_to_nvingest_collection( collection_name: str, milvus_uri: str = "http://localhost:19530", minio_endpoint: str = "localhost:9000", - sparse: bool = False, + sparse: bool = True, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, bm25_save_path: str = "bm25_model.json", + compute_bm25_stats: bool = True, access_key: str = "minioadmin", secret_key: str = "minioadmin", bucket_name: str = "a-bucket", @@ -543,11 +545,14 @@ def write_to_nvingest_collection( else: stream = True bm25_ef = None - if sparse: + if sparse and compute_bm25_stats: bm25_ef = create_bm25_model( records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables ) bm25_ef.save(bm25_save_path) + elif sparse and not compute_bm25_stats: + bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en")) + bm25_ef.load(bm25_save_path) client = MilvusClient(milvus_uri) schema = Collection(collection_name).schema if stream: @@ -697,7 +702,7 @@ def hybrid_retrieval( "data": dense_embeddings, "anns_field": dense_field, "param": s_param_1, - "limit": top_k, + "limit": top_k * 2, } dense_req = AnnSearchRequest(**search_param_1) @@ -706,7 +711,7 @@ def hybrid_retrieval( "data": sparse_embeddings, "anns_field": sparse_field, "param": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}}, - "limit": top_k, + "limit": top_k * 2, } sparse_req = AnnSearchRequest(**search_param_2) diff --git a/tests/nv_ingest_client/util/test_milvus_util.py b/tests/nv_ingest_client/util/test_milvus_util.py new file mode 100644 index 00000000..04e7bb68 --- /dev/null +++ b/tests/nv_ingest_client/util/test_milvus_util.py @@ -0,0 +1,61 @@ +import pytest +from nv_ingest_client.util.milvus import MilvusOperator, _dict_to_params + + +@pytest.fixture +def milvus_test_dict(): + mil_op = MilvusOperator() + kwargs = mil_op.milvus_kwargs + kwargs["collection_name"] = mil_op.collection_name + return kwargs + + +@pytest.mark.parametrize("collection_name", [None, "name"]) +def test_op_collection_name(collection_name): + if collection_name: + mo = MilvusOperator(collection_name=collection_name) + else: + # default + collection_name = "nv_ingest_collection" + mo = MilvusOperator() + cr_collection_name, conn_params = mo.get_connection_params() + wr_collection_name, write_params = mo.get_write_params() + assert cr_collection_name == wr_collection_name == collection_name + + +def test_op_connection_params(milvus_test_dict): + mo = MilvusOperator() + cr_collection_name, conn_params = mo.get_connection_params() + assert cr_collection_name == milvus_test_dict["collection_name"] + for k, v in conn_params.items(): + assert milvus_test_dict[k] == v + + +def test_op_write_params(milvus_test_dict): + mo = MilvusOperator() + collection_name, wr_params = mo.get_write_params() + assert collection_name == milvus_test_dict["collection_name"] + for k, v in wr_params.items(): + assert milvus_test_dict[k] == v + + +@pytest.mark.parametrize( + "collection_name, expected_results", + [ + ({"text": ["text", "charts", "tables"]}, {"enable_text": True, "enable_charts": True, "enable_tables": True}), + ({"text": ["text", "tables"]}, {"enable_text": True, "enable_charts": False, "enable_tables": True}), + ({"text": ["text", "charts"]}, {"enable_text": True, "enable_charts": True, "enable_tables": False}), + ({"text": ["text"]}, {"enable_text": True, "enable_charts": False, "enable_tables": False}), + ], +) +def test_op_dict_to_params(collection_name, expected_results): + mo = MilvusOperator() + _, wr_params = mo.get_write_params() + response = _dict_to_params(collection_name, wr_params) + if isinstance(collection_name, str): + collection_name = {collection_name: None} + for res in response: + coll_name, write_params = res + for k, v in expected_results.items(): + assert write_params[k] == v + coll_name in collection_name.keys() From 1df859e586a6be94547850b585ed25a8f344f244 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 16 Jan 2025 10:02:49 -0500 Subject: [PATCH 4/9] and vdb_bulk_upload switch to init for ingestor --- client/src/nv_ingest_client/client/interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/client/src/nv_ingest_client/client/interface.py b/client/src/nv_ingest_client/client/interface.py index 8591c806..0d4e3b0d 100644 --- a/client/src/nv_ingest_client/client/interface.py +++ b/client/src/nv_ingest_client/client/interface.py @@ -74,6 +74,7 @@ def __init__( self._documents = documents or [] self._client = client self._job_queue_id = job_queue_id + self._vdb_bulk_upload = None if self._client is None: client_kwargs = filter_function_kwargs(NvIngestClient, **kwargs) From df1bbbbf4bcf72342101a9c71700b8a3de050e8b Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 16 Jan 2025 10:30:57 -0500 Subject: [PATCH 5/9] ensure ability to catch unknown kwargs and remove internal --- client/src/nv_ingest_client/util/milvus.py | 3 +++ tests/nv_ingest_client/util/test_milvus_util.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index d5aeaeb6..4a481ecb 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -58,9 +58,12 @@ def __init__( access_key: str = "minioadmin", secret_key: str = "minioadmin", bucket_name: str = "a-bucket", + **kwargs, ): self.milvus_kwargs = locals() self.collection_name = self.milvus_kwargs.pop("collection_name") + for k, v in kwargs.items(): + self.milvus_kwargs.pop(k) self.milvus_kwargs.pop("self") def get_connection_params(self): diff --git a/tests/nv_ingest_client/util/test_milvus_util.py b/tests/nv_ingest_client/util/test_milvus_util.py index 04e7bb68..525ca288 100644 --- a/tests/nv_ingest_client/util/test_milvus_util.py +++ b/tests/nv_ingest_client/util/test_milvus_util.py @@ -10,6 +10,12 @@ def milvus_test_dict(): return kwargs +def test_extra_kwargs(milvus_test_dict): + mil_op = MilvusOperator(filter_errors=True) + milvus_test_dict.pop("collection_name") + assert mil_op.milvus_kwargs == milvus_test_dict + + @pytest.mark.parametrize("collection_name", [None, "name"]) def test_op_collection_name(collection_name): if collection_name: From 320522c1f3a337b7683347005c5e4a6df990f798 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 16 Jan 2025 10:40:29 -0500 Subject: [PATCH 6/9] fix pop of unnecessary kwargs --- client/src/nv_ingest_client/util/milvus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index 4a481ecb..554b8489 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -63,7 +63,7 @@ def __init__( self.milvus_kwargs = locals() self.collection_name = self.milvus_kwargs.pop("collection_name") for k, v in kwargs.items(): - self.milvus_kwargs.pop(k) + self.milvus_kwargs.pop(k, None) self.milvus_kwargs.pop("self") def get_connection_params(self): From cd75d61971622e3f96c2949aa4eeaeab24e7a332 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 16 Jan 2025 10:51:47 -0500 Subject: [PATCH 7/9] remove unwanted kwargs correctly --- client/src/nv_ingest_client/util/milvus.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index 554b8489..8c8a9b6f 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -61,10 +61,9 @@ def __init__( **kwargs, ): self.milvus_kwargs = locals() - self.collection_name = self.milvus_kwargs.pop("collection_name") - for k, v in kwargs.items(): - self.milvus_kwargs.pop(k, None) self.milvus_kwargs.pop("self") + self.collection_name = self.milvus_kwargs.pop("collection_name") + self.milvus_kwargs.pop("kwargs", None) def get_connection_params(self): conn_dict = { From 6ea3f12a07debb509dba75a9aad642923f42ab78 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 16 Jan 2025 11:22:30 -0500 Subject: [PATCH 8/9] fix ingestor tests --- tests/nv_ingest_client/client/test_interface.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/nv_ingest_client/client/test_interface.py b/tests/nv_ingest_client/client/test_interface.py index 41c72bc2..38e10c24 100644 --- a/tests/nv_ingest_client/client/test_interface.py +++ b/tests/nv_ingest_client/client/test_interface.py @@ -24,7 +24,7 @@ from nv_ingest_client.primitives.tasks import StoreEmbedTask from nv_ingest_client.primitives.tasks import StoreTask from nv_ingest_client.primitives.tasks import TableExtractionTask -from nv_ingest_client.primitives.tasks import VdbUploadTask +from nv_ingest_client.util.milvus import MilvusOperator MODULE_UNDER_TEST = "nv_ingest_client.client.interface" @@ -193,15 +193,13 @@ def test_store_task_some_args_extra_param(ingestor): def test_vdb_upload_task_no_args(ingestor): ingestor.vdb_upload() - assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], VdbUploadTask) + assert isinstance(ingestor._vdb_bulk_upload, MilvusOperator) def test_vdb_upload_task_some_args(ingestor): ingestor.vdb_upload(filter_errors=True) - task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0] - assert isinstance(task, VdbUploadTask) - assert task._filter_errors is True + assert isinstance(ingestor._vdb_bulk_upload, MilvusOperator) def test_caption_task_no_args(ingestor): @@ -228,8 +226,8 @@ def test_chain(ingestor): assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[5], FilterTask) assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[6], SplitTask) assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[7], StoreTask) - assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[8], VdbUploadTask) - assert len(ingestor._job_specs.job_specs["pdf"][0]._tasks) == 9 + assert isinstance(ingestor._vdb_bulk_upload, MilvusOperator) + assert len(ingestor._job_specs.job_specs["pdf"][0]._tasks) == 8 def test_ingest(ingestor, mock_client): From 73a9e753b2e191a22f0718900ba853339545a12e Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Fri, 17 Jan 2025 13:46:18 -0500 Subject: [PATCH 9/9] update to milvus 2.5.3 and use new bm25 setup --- client/src/nv_ingest_client/util/milvus.py | 119 ++++----------------- docker-compose.yaml | 2 +- 2 files changed, 23 insertions(+), 98 deletions(-) diff --git a/client/src/nv_ingest_client/util/milvus.py b/client/src/nv_ingest_client/util/milvus.py index 8c8a9b6f..359f0d7d 100644 --- a/client/src/nv_ingest_client/util/milvus.py +++ b/client/src/nv_ingest_client/util/milvus.py @@ -4,6 +4,8 @@ DataType, CollectionSchema, connections, + Function, + FunctionType, utility, BulkInsertState, AnnSearchRequest, @@ -11,8 +13,6 @@ ) from pymilvus.milvus_client.index import IndexParams from pymilvus.bulk_writer import RemoteBulkWriter, BulkFileType -from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer -from pymilvus.model.sparse import BM25EmbeddingFunction from llama_index.embeddings.nvidia import NVIDIAEmbedding from scipy.sparse import csr_array from typing import List @@ -53,8 +53,6 @@ def __init__( enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, - bm25_save_path: str = "bm25_model.json", - compute_bm25_stats: bool = True, access_key: str = "minioadmin", secret_key: str = "minioadmin", bucket_name: str = "a-bucket", @@ -125,12 +123,24 @@ def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> Colle """ schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True) schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True, auto_id=True) - schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535) schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim) schema.add_field(field_name="source", datatype=DataType.JSON) schema.add_field(field_name="content_metadata", datatype=DataType.JSON) if sparse: + schema.add_field( + field_name="text", datatype=DataType.VARCHAR, max_length=65535, enable_match=True, enable_analyzer=True + ) schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR) + function = Function( + name="bm25", + function_type=FunctionType.BM25, + input_field_names=["text"], + output_field_names="sparse", + ) + + schema.add_function(function) + else: + schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535) return schema @@ -195,8 +205,7 @@ def create_nvingest_index_params( field_name="sparse", index_name="sparse_index", index_type="SPARSE_INVERTED_INDEX", # Index type for sparse vectors - metric_type="IP", # Currently, only IP (Inner Product) is supported for sparse vectors - params={"drop_ratio_build": 0.2}, # The ratio of small vector values to be dropped during indexing + metric_type="BM25", ) return index_params @@ -288,11 +297,6 @@ def create_nvingest_collection( create_collection(client, collection_name, schema, index_params, recreate=recreate) -def _format_sparse_embedding(sparse_vector: csr_array): - sparse_embedding = {int(k[1]): float(v) for k, v in sparse_vector.todok()._dict.items()} - return sparse_embedding if len(sparse_embedding) > 0 else {int(0): float(0)} - - def _record_dict(text, element, sparse_vector: csr_array = None): record = { "text": text, @@ -300,8 +304,6 @@ def _record_dict(text, element, sparse_vector: csr_array = None): "source": element["metadata"]["source_metadata"], "content_metadata": element["metadata"]["content_metadata"], } - if sparse_vector is not None: - record["sparse"] = _format_sparse_embedding(sparse_vector) return record @@ -321,7 +323,6 @@ def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: b def write_records_minio( records, writer: RemoteBulkWriter, - sparse_model=None, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, @@ -340,9 +341,6 @@ def write_records_minio( writer : RemoteBulkWriter The Milvus Remote BulkWriter instance that was created with necessary params to access the minio instance corresponding to milvus. - sparse_model : model, - Sparse model used to generate sparse embedding in the form of - scipy.sparse.csr_array enable_text : bool, optional When true, ensure all text type records are used. enable_charts : bool, optional @@ -361,10 +359,7 @@ def write_records_minio( for element in result: text = _pull_text(element, enable_text, enable_charts, enable_tables) if text: - if sparse_model is not None: - writer.append_row(record_func(text, element, sparse_model.encode_documents([text]))) - else: - writer.append_row(record_func(text, element)) + writer.append_row(record_func(text, element)) writer.commit() print(f"Wrote data to: {writer.batch_files}") @@ -406,49 +401,10 @@ def bulk_insert_milvus(collection_name: str, writer: RemoteBulkWriter, milvus_ur time.sleep(1) -def create_bm25_model( - records, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True -) -> BM25EmbeddingFunction: - """ - This function takes the input records and creates a corpus, - factoring in filters (i.e. texts, charts, tables) and fits - a BM25 model with that information. - - Parameters - ---------- - records : List - List of chunks with attached metadata - enable_text : bool, optional - When true, ensure all text type records are used. - enable_charts : bool, optional - When true, ensure all chart type records are used. - enable_tables : bool, optional - When true, ensure all table type records are used. - - Returns - ------- - BM25EmbeddingFunction - Returns the model fitted to the selected corpus. - """ - all_text = [] - for result in records: - for element in result: - text = _pull_text(element, enable_text, enable_charts, enable_tables) - if text: - all_text.append(text) - - analyzer = build_default_analyzer(language="en") - bm25_ef = BM25EmbeddingFunction(analyzer) - - bm25_ef.fit(all_text) - return bm25_ef - - def stream_insert_milvus( records, client: MilvusClient, collection_name: str, - sparse_model=None, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, @@ -465,9 +421,6 @@ def stream_insert_milvus( List of chunks with attached metadata collection_name : str Milvus Collection to search against - sparse_model : model, - Sparse model used to generate sparse embedding in the form of - scipy.sparse.csr_array enable_text : bool, optional When true, ensure all text type records are used. enable_charts : bool, optional @@ -483,10 +436,7 @@ def stream_insert_milvus( for element in result: text = _pull_text(element, enable_text, enable_charts, enable_tables) if text: - if sparse_model is not None: - data.append(record_func(text, element, sparse_model.encode_documents([text]))) - else: - data.append(record_func(text, element)) + data.append(record_func(text, element)) client.insert(collection_name=collection_name, data=data) @@ -499,8 +449,6 @@ def write_to_nvingest_collection( enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True, - bm25_save_path: str = "bm25_model.json", - compute_bm25_stats: bool = True, access_key: str = "minioadmin", secret_key: str = "minioadmin", bucket_name: str = "a-bucket", @@ -529,8 +477,6 @@ def write_to_nvingest_collection( When true, ensure all table type records are used. sparse : bool, optional When true, incorporates sparse embedding representations for records. - bm25_save_path : str, optional - The desired filepath for the sparse model if sparse is True. access_key : str, optional Minio access key. secret_key : str, optional @@ -546,15 +492,6 @@ def write_to_nvingest_collection( stream = True else: stream = True - bm25_ef = None - if sparse and compute_bm25_stats: - bm25_ef = create_bm25_model( - records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables - ) - bm25_ef.save(bm25_save_path) - elif sparse and not compute_bm25_stats: - bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en")) - bm25_ef.load(bm25_save_path) client = MilvusClient(milvus_uri) schema = Collection(collection_name).schema if stream: @@ -562,7 +499,6 @@ def write_to_nvingest_collection( records, client, collection_name, - bm25_ef, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables, @@ -582,7 +518,6 @@ def write_to_nvingest_collection( writer = write_records_minio( records, text_writer, - bm25_ef, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables, @@ -647,7 +582,6 @@ def hybrid_retrieval( collection_name: str, client: MilvusClient, dense_model, - sparse_model, top_k: int, dense_field: str = "vector", sparse_field: str = "sparse", @@ -670,9 +604,6 @@ def hybrid_retrieval( Client connected to mivlus instance. dense_model : NVIDIAEmbedding Dense model to generate dense embeddings for queries. - sparse_model : model, - Sparse model used to generate sparse embedding in the form of - scipy.sparse.csr_array top_k : int Number of search results to return per query. dense_field : str @@ -688,10 +619,10 @@ def hybrid_retrieval( Nested list of top_k results per query. """ dense_embeddings = [] - sparse_embeddings = [] + sparse_queries = [] for query in queries: dense_embeddings.append(dense_model.get_query_embedding(query)) - sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query]))) + sparse_queries.append(query) s_param_1 = { "metric_type": "L2", @@ -710,9 +641,9 @@ def hybrid_retrieval( dense_req = AnnSearchRequest(**search_param_1) search_param_2 = { - "data": sparse_embeddings, + "data": sparse_queries, "anns_field": sparse_field, - "param": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}}, + "param": {"metric_type": "BM25"}, "limit": top_k * 2, } sparse_req = AnnSearchRequest(**search_param_2) @@ -732,7 +663,6 @@ def nvingest_retrieval( dense_field: str = "vector", sparse_field: str = "sparse", embedding_endpoint="http://localhost:8000/v1", - sparse_model_filepath: str = "bm25_model.json", model_name: str = "nvidia/nv-embedqa-e5-v5", output_fields: List[str] = ["text", "source", "content_metadata"], gpu_search: bool = False, @@ -763,8 +693,6 @@ def nvingest_retrieval( vector the collection. embedding_endpoint : str, optional Number of search results to return per query. - sparse_model_filepath : str, optional - The path where the sparse model has been loaded. model_name : str, optional The name of the dense embedding model available in the NIM embedding endpoint. @@ -779,14 +707,11 @@ def nvingest_retrieval( if milvus_uri.endswith(".db"): local_index = True if hybrid: - bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en")) - bm25_ef.load(sparse_model_filepath) results = hybrid_retrieval( queries, collection_name, client, embed_model, - bm25_ef, top_k, output_fields=output_fields, gpu_search=gpu_search, diff --git a/docker-compose.yaml b/docker-compose.yaml index 8d9c307a..240baa07 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -296,7 +296,7 @@ services: # Turn on to leverage the `vdb_upload` task restart: always container_name: milvus-standalone - image: milvusdb/milvus:v2.4.17-gpu + image: milvusdb/milvus:v2.5.3-gpu command: [ "milvus", "run", "standalone" ] hostname: milvus security_opt: