Milvus bm25 2.5.x #339

Closed
wants to merge 14 commits
119 changes: 22 additions & 97 deletions client/src/nv_ingest_client/util/milvus.py
@@ -4,15 +4,15 @@
DataType,
CollectionSchema,
connections,
Function,
FunctionType,
utility,
BulkInsertState,
AnnSearchRequest,
RRFRanker,
)
from pymilvus.milvus_client.index import IndexParams
from pymilvus.bulk_writer import RemoteBulkWriter, BulkFileType
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from scipy.sparse import csr_array
from typing import List
@@ -53,8 +53,6 @@ def __init__(
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
@@ -125,12 +123,24 @@ def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> Colle
"""
schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True, auto_id=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim)
schema.add_field(field_name="source", datatype=DataType.JSON)
schema.add_field(field_name="content_metadata", datatype=DataType.JSON)
if sparse:
schema.add_field(
field_name="text", datatype=DataType.VARCHAR, max_length=65535, enable_match=True, enable_analyzer=True
)
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
function = Function(
name="bm25",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names="sparse",
)

schema.add_function(function)
else:
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)

return schema
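
For readers following the diff: the substantive change here is that sparse/BM25 handling moves out of a client-side `BM25EmbeddingFunction` and into the collection schema itself, using the built-in BM25 `Function` available in Milvus 2.5. A minimal, self-contained sketch of the new schema path (assuming pymilvus 2.5.x and a hypothetical 1024-dim dense field) looks roughly like this:

```python
# Sketch only -- mirrors the create_nvingest_schema(sparse=True) branch above.
from pymilvus import MilvusClient, DataType, Function, FunctionType

schema = MilvusClient.create_schema(auto_id=True, enable_dynamic_field=True)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True, auto_id=True)
schema.add_field(
    field_name="text",
    datatype=DataType.VARCHAR,
    max_length=65535,
    enable_match=True,     # allows text match queries on the raw field
    enable_analyzer=True,  # required so Milvus can tokenize the field for BM25
)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=1024)
schema.add_field(field_name="source", datatype=DataType.JSON)
schema.add_field(field_name="content_metadata", datatype=DataType.JSON)
schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)

# Milvus computes the BM25 sparse vector server-side at insert and query time.
schema.add_function(
    Function(
        name="bm25",
        function_type=FunctionType.BM25,
        input_field_names=["text"],
        output_field_names="sparse",
    )
)
```

With the function attached to the schema, no `bm25_model.json` artifact has to be fitted, saved, or loaded by the client.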

@@ -195,8 +205,7 @@ def create_nvingest_index_params(
field_name="sparse",
index_name="sparse_index",
index_type="SPARSE_INVERTED_INDEX", # Index type for sparse vectors
metric_type="IP", # Currently, only IP (Inner Product) is supported for sparse vectors
params={"drop_ratio_build": 0.2}, # The ratio of small vector values to be dropped during indexing
metric_type="BM25",
)
return index_params
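
The index definition changes accordingly: the sparse field keeps the `SPARSE_INVERTED_INDEX` type but is now built with the `BM25` metric, and the old `drop_ratio_build` knob goes away. A small sketch (assuming the same pymilvus 2.5.x imports used above):

```python
# Sketch only -- mirrors the updated create_nvingest_index_params() above.
from pymilvus.milvus_client.index import IndexParams

index_params = IndexParams()
index_params.add_index(
    field_name="sparse",
    index_name="sparse_index",
    index_type="SPARSE_INVERTED_INDEX",
    metric_type="BM25",  # replaces metric_type="IP" + drop_ratio_build
)
```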

@@ -288,20 +297,13 @@ def create_nvingest_collection(
create_collection(client, collection_name, schema, index_params, recreate=recreate)


def _format_sparse_embedding(sparse_vector: csr_array):
sparse_embedding = {int(k[1]): float(v) for k, v in sparse_vector.todok()._dict.items()}
return sparse_embedding if len(sparse_embedding) > 0 else {int(0): float(0)}


def _record_dict(text, element, sparse_vector: csr_array = None):
record = {
"text": text,
"vector": element["metadata"]["embedding"],
"source": element["metadata"]["source_metadata"],
"content_metadata": element["metadata"]["content_metadata"],
}
if sparse_vector is not None:
record["sparse"] = _format_sparse_embedding(sparse_vector)
return record


@@ -321,7 +323,6 @@ def _pull_text(element, enable_text: bool, enable_charts: bool, enable_tables: b
def write_records_minio(
records,
writer: RemoteBulkWriter,
sparse_model=None,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
@@ -340,9 +341,6 @@ def write_records_minio(
writer : RemoteBulkWriter
The Milvus Remote BulkWriter instance that was created with necessary
params to access the minio instance corresponding to milvus.
sparse_model : model,
Sparse model used to generate sparse embedding in the form of
scipy.sparse.csr_array
enable_text : bool, optional
When true, ensure all text type records are used.
enable_charts : bool, optional
@@ -361,10 +359,7 @@
for element in result:
text = _pull_text(element, enable_text, enable_charts, enable_tables)
if text:
if sparse_model is not None:
writer.append_row(record_func(text, element, sparse_model.encode_documents([text])))
else:
writer.append_row(record_func(text, element))
writer.append_row(record_func(text, element))

writer.commit()
print(f"Wrote data to: {writer.batch_files}")
@@ -406,49 +401,10 @@ def bulk_insert_milvus(collection_name: str, writer: RemoteBulkWriter, milvus_ur
time.sleep(1)


def create_bm25_model(
records, enable_text: bool = True, enable_charts: bool = True, enable_tables: bool = True
) -> BM25EmbeddingFunction:
"""
This function takes the input records and creates a corpus,
factoring in filters (i.e. texts, charts, tables) and fits
a BM25 model with that information.

Parameters
----------
records : List
List of chunks with attached metadata
enable_text : bool, optional
When true, ensure all text type records are used.
enable_charts : bool, optional
When true, ensure all chart type records are used.
enable_tables : bool, optional
When true, ensure all table type records are used.

Returns
-------
BM25EmbeddingFunction
Returns the model fitted to the selected corpus.
"""
all_text = []
for result in records:
for element in result:
text = _pull_text(element, enable_text, enable_charts, enable_tables)
if text:
all_text.append(text)

analyzer = build_default_analyzer(language="en")
bm25_ef = BM25EmbeddingFunction(analyzer)

bm25_ef.fit(all_text)
return bm25_ef


def stream_insert_milvus(
records,
client: MilvusClient,
collection_name: str,
sparse_model=None,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
@@ -465,9 +421,6 @@ def stream_insert_milvus(
List of chunks with attached metadata
collection_name : str
Milvus Collection to search against
sparse_model : model,
Sparse model used to generate sparse embedding in the form of
scipy.sparse.csr_array
enable_text : bool, optional
When true, ensure all text type records are used.
enable_charts : bool, optional
@@ -483,10 +436,7 @@
for element in result:
text = _pull_text(element, enable_text, enable_charts, enable_tables)
if text:
if sparse_model is not None:
data.append(record_func(text, element, sparse_model.encode_documents([text])))
else:
data.append(record_func(text, element))
data.append(record_func(text, element))
client.insert(collection_name=collection_name, data=data)
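
Because the schema-level function produces the sparse vector, inserted rows now carry only the raw text, the dense embedding, and metadata; there is no `sparse` key and no client-side encoding step, which is why `sparse_model` and `_format_sparse_embedding` disappear from the write path. A sketch of what a streamed row looks like (hypothetical values; the collection is assumed to use the schema sketched above):

```python
from pymilvus import MilvusClient

client = MilvusClient("http://localhost:19530")  # assumed Milvus 2.5.x endpoint
row = {
    "text": "NV-Ingest extracts text, tables, and charts from documents.",
    "vector": [0.0] * 1024,                  # placeholder dense embedding (dim must match schema)
    "source": {"source_id": "example.pdf"},  # hypothetical source metadata
    "content_metadata": {"type": "text"},    # hypothetical content metadata
}
# Milvus derives the "sparse" field from "text" via the schema's BM25 function.
client.insert(collection_name="nv_ingest_collection", data=[row])
```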


@@ -499,8 +449,6 @@ def write_to_nvingest_collection(
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
@@ -529,8 +477,6 @@
When true, ensure all table type records are used.
sparse : bool, optional
When true, incorporates sparse embedding representations for records.
bm25_save_path : str, optional
The desired filepath for the sparse model if sparse is True.
access_key : str, optional
Minio access key.
secret_key : str, optional
@@ -546,23 +492,13 @@
stream = True
else:
stream = True
bm25_ef = None
if sparse and compute_bm25_stats:
bm25_ef = create_bm25_model(
records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables
)
bm25_ef.save(bm25_save_path)
elif sparse and not compute_bm25_stats:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(bm25_save_path)
client = MilvusClient(milvus_uri)
schema = Collection(collection_name).schema
if stream:
stream_insert_milvus(
records,
client,
collection_name,
bm25_ef,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
@@ -582,7 +518,6 @@
writer = write_records_minio(
records,
text_writer,
bm25_ef,
enable_text=enable_text,
enable_charts=enable_charts,
enable_tables=enable_tables,
@@ -647,7 +582,6 @@ def hybrid_retrieval(
collection_name: str,
client: MilvusClient,
dense_model,
sparse_model,
top_k: int,
dense_field: str = "vector",
sparse_field: str = "sparse",
@@ -670,9 +604,6 @@
Client connected to the Milvus instance.
dense_model : NVIDIAEmbedding
Dense model to generate dense embeddings for queries.
sparse_model : model,
Sparse model used to generate sparse embedding in the form of
scipy.sparse.csr_array
top_k : int
Number of search results to return per query.
dense_field : str
@@ -688,10 +619,10 @@
Nested list of top_k results per query.
"""
dense_embeddings = []
sparse_embeddings = []
sparse_queries = []
for query in queries:
dense_embeddings.append(dense_model.get_query_embedding(query))
sparse_embeddings.append(_format_sparse_embedding(sparse_model.encode_queries([query])))
sparse_queries.append(query)

s_param_1 = {
"metric_type": "L2",
@@ -710,9 +641,9 @@
dense_req = AnnSearchRequest(**search_param_1)

search_param_2 = {
"data": sparse_embeddings,
"data": sparse_queries,
"anns_field": sparse_field,
"param": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
"param": {"metric_type": "BM25"},
"limit": top_k * 2,
}
sparse_req = AnnSearchRequest(**search_param_2)
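
On the query side, the sparse leg of the hybrid search now submits the raw query strings and lets Milvus tokenize and score them with BM25, instead of encoding them into scipy sparse vectors first; the dense leg is unchanged and the two result sets are still fused with `RRFRanker`. A hedged sketch of the request pair (hypothetical query, limits, and collection name; `client` and `dense_model` as in the surrounding code):

```python
from pymilvus import AnnSearchRequest, RRFRanker

query = "How do I enable BM25 search?"  # hypothetical query

dense_req = AnnSearchRequest(
    data=[dense_model.get_query_embedding(query)],
    anns_field="vector",
    param={"metric_type": "L2"},
    limit=10,
)
sparse_req = AnnSearchRequest(
    data=[query],                    # raw text; Milvus analyzes and scores with BM25
    anns_field="sparse",
    param={"metric_type": "BM25"},   # replaces IP + drop_ratio_build
    limit=10,
)

results = client.hybrid_search(
    "nv_ingest_collection",          # hypothetical collection name
    [dense_req, sparse_req],
    RRFRanker(),
    limit=5,
    output_fields=["text", "source", "content_metadata"],
)
```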
@@ -732,7 +663,6 @@ def nvingest_retrieval(
dense_field: str = "vector",
sparse_field: str = "sparse",
embedding_endpoint="http://localhost:8000/v1",
sparse_model_filepath: str = "bm25_model.json",
model_name: str = "nvidia/nv-embedqa-e5-v5",
output_fields: List[str] = ["text", "source", "content_metadata"],
gpu_search: bool = False,
@@ -763,8 +693,6 @@
vector the collection.
embedding_endpoint : str, optional
The endpoint of the embedding NIM used to generate dense query embeddings.
sparse_model_filepath : str, optional
The path where the sparse model has been loaded.
model_name : str, optional
The name of the dense embedding model available in the NIM embedding endpoint.

@@ -779,14 +707,11 @@
if milvus_uri.endswith(".db"):
local_index = True
if hybrid:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(sparse_model_filepath)
results = hybrid_retrieval(
queries,
collection_name,
client,
embed_model,
bm25_ef,
top_k,
output_fields=output_fields,
gpu_search=gpu_search,
2 changes: 1 addition & 1 deletion docker-compose.yaml
@@ -296,7 +296,7 @@ services:
# Turn on to leverage the `vdb_upload` task
restart: always
container_name: milvus-standalone
image: milvusdb/milvus:v2.4.17-gpu
image: milvusdb/milvus:v2.5.3-gpu
command: [ "milvus", "run", "standalone" ]
hostname: milvus
security_opt: