From cd4aab79b059d06c59f161a754a5ad318ac07530 Mon Sep 17 00:00:00 2001 From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com> Date: Wed, 18 Dec 2024 03:50:39 -0500 Subject: [PATCH 1/2] enhance: support milvus-client iterator (#2461) related: #2464 Signed-off-by: MrPresent-Han Co-authored-by: MrPresent-Han --- examples/iterator/iterator.py | 99 +++++++++++++++++++ pymilvus/client/utils.py | 23 +++++ pymilvus/milvus_client/milvus_client.py | 124 ++++++++++++++++++++++++ pymilvus/orm/constants.py | 1 + 4 files changed, 247 insertions(+) create mode 100644 examples/iterator/iterator.py diff --git a/examples/iterator/iterator.py b/examples/iterator/iterator.py new file mode 100644 index 000000000..aa87815b5 --- /dev/null +++ b/examples/iterator/iterator.py @@ -0,0 +1,99 @@ +from pymilvus.milvus_client.milvus_client import MilvusClient +from pymilvus import ( + FieldSchema, CollectionSchema, DataType, +) +import numpy as np + +collection_name = "test_milvus_client_iterator" +prepare_new_data = True +clean_exist = True + +USER_ID = "id" +AGE = "age" +DEPOSIT = "deposit" +PICTURE = "picture" +DIM = 8 +NUM_ENTITIES = 10000 +rng = np.random.default_rng(seed=19530) + + +def test_query_iterator(milvus_client: MilvusClient): + # test query iterator + expr = f"10 <= {AGE} <= 25" + output_fields = [USER_ID, AGE] + queryIt = milvus_client.query_iterator(collection_name, filter=expr, batch_size=50, output_fields=output_fields) + page_idx = 0 + while True: + res = queryIt.next() + if len(res) == 0: + print("query iteration finished, close") + queryIt.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + +def test_search_iterator(milvus_client: MilvusClient): + vector_to_search = rng.random((1, DIM), np.float32) + search_iterator = milvus_client.search_iterator(collection_name, data=vector_to_search, batch_size=100, anns_field=PICTURE) + + page_idx = 0 + while True: + res = search_iterator.next() + if len(res) == 0: + print("query iteration finished, close") + search_iterator.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + + +def main(): + milvus_client = MilvusClient("http://localhost:19530") + if milvus_client.has_collection(collection_name) and clean_exist: + milvus_client.drop_collection(collection_name) + print(f"dropped existed collection{collection_name}") + + if not milvus_client.has_collection(collection_name): + fields = [ + FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name=AGE, dtype=DataType.INT64), + FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE), + FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM) + ] + schema = CollectionSchema(fields) + milvus_client.create_collection(collection_name, dimension=DIM, schema=schema) + + if prepare_new_data: + entities = [] + for i in range(NUM_ENTITIES): + entity = { + USER_ID: i, + AGE: (i % 100), + DEPOSIT: float(i), + PICTURE: rng.random((1, DIM))[0] + } + entities.append(entity) + milvus_client.insert(collection_name, entities) + milvus_client.flush(collection_name) + print(f"Finish flush collections:{collection_name}") + + index_params = milvus_client.prepare_index_params() + + index_params.add_index( + field_name=PICTURE, + index_type='IVF_FLAT', + metric_type='L2', + params={"nlist": 1024} + ) + milvus_client.create_index(collection_name, index_params) + milvus_client.load_collection(collection_name) + test_query_iterator(milvus_client=milvus_client) + test_search_iterator(milvus_client=milvus_client) + + +if __name__ == '__main__': + main() diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 46bc8173f..1ddee57bf 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -375,3 +375,26 @@ def is_scipy_sparse(cls, data: Any): "csr_array", "spmatrix", ] + + +def is_sparse_vector_type(data_type: DataType) -> bool: + return data_type == data_type.SPARSE_FLOAT_VECTOR + + +dense_vector_type_set = {DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR} + + +def is_dense_vector_type(data_type: DataType) -> bool: + return data_type in dense_vector_type_set + + +def is_float_vector_type(data_type: DataType): + return is_sparse_vector_type(data_type) or is_dense_vector_type(data_type) + + +def is_binary_vector_type(data_type: DataType): + return data_type == DataType.BINARY_VECTOR + + +def is_vector_type(data_type: DataType): + return is_float_vector_type(data_type) or is_binary_vector_type(data_type) diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index f80e2e396..85e4cfe86 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -13,8 +13,10 @@ OmitZeroDict, construct_cost_extra, ) +from pymilvus.client.utils import is_vector_type from pymilvus.exceptions import ( DataTypeNotMatchException, + ErrorCode, MilvusException, ParamError, PrimaryKeyException, @@ -22,6 +24,8 @@ from pymilvus.orm import utility from pymilvus.orm.collection import CollectionSchema from pymilvus.orm.connections import connections +from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED +from pymilvus.orm.iterator import QueryIterator, SearchIterator from pymilvus.orm.types import DataType from .index import IndexParams @@ -480,6 +484,126 @@ def query( return res + def query_iterator( + self, + collection_name: str, + batch_size: Optional[int] = 1000, + limit: Optional[int] = UNLIMITED, + filter: Optional[str] = "", + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): + if filter is not None and not isinstance(filter, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) + + conn = self._get_connection() + # set up schema for iterator + try: + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) + except Exception as ex: + logger.error("Failed to describe collection: %s", collection_name) + raise ex from ex + + return QueryIterator( + connection=conn, + collection_name=collection_name, + batch_size=batch_size, + limit=limit, + expr=filter, + output_fields=output_fields, + partition_names=partition_names, + schema=schema_dict, + timeout=timeout, + **kwargs, + ) + + def search_iterator( + self, + collection_name: str, + data: Union[List[list], list], + batch_size: Optional[int] = 1000, + filter: Optional[str] = None, + limit: Optional[int] = UNLIMITED, + output_fields: Optional[List[str]] = None, + search_params: Optional[dict] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + anns_field: Optional[str] = None, + round_decimal: int = -1, + **kwargs, + ): + if filter is not None and not isinstance(filter, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) + + conn = self._get_connection() + # set up schema for iterator + try: + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) + except Exception as ex: + logger.error("Failed to describe collection: %s", collection_name) + raise ex from ex + # if anns_field is not provided + # if only one vector field, use to search + # if multiple vector fields, raise exception and abort + if anns_field is None or anns_field == "": + vec_field = None + fields = schema_dict[FIELDS] + vec_field_count = 0 + for field in fields: + if is_vector_type(field[TYPE]): + vec_field_count += 1 + vec_field = field + if vec_field is None: + raise MilvusException( + code=ErrorCode.UNEXPECTED_ERROR, + message="there should be at least one vector field in milvus collection", + ) + if vec_field_count > 1: + raise MilvusException( + code=ErrorCode.UNEXPECTED_ERROR, + message="must specify anns_field when there are more than one vector field", + ) + anns_field = vec_field["name"] + if anns_field is None or anns_field == "": + raise MilvusException( + code=ErrorCode.UNEXPECTED_ERROR, + message=f"cannot get anns_field name for search iterator, got:{anns_field}", + ) + # set up metrics type for search_iterator which is mandatory + if search_params is None: + search_params = {} + if METRIC_TYPE not in search_params: + indexes = conn.list_indexes(collection_name) + for index in indexes: + if anns_field == index.index_name: + params = index.params + for param in params: + if param.key == METRIC_TYPE: + search_params[METRIC_TYPE] = param.value + if METRIC_TYPE not in search_params: + raise MilvusException( + ParamError, f"Cannot set up metrics type for anns_field:{anns_field}" + ) + + return SearchIterator( + connection=self._get_connection(), + collection_name=collection_name, + data=data, + ann_field=anns_field, + param=search_params, + batch_size=batch_size, + limit=limit, + expr=filter, + partition_names=partition_names, + output_fields=output_fields, + timeout=timeout, + round_decimal=round_decimal, + schema=schema_dict, + **kwargs, + ) + def get( self, collection_name: str, diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index d7b666501..cf6e2ebd4 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -37,6 +37,7 @@ MILVUS_LIMIT = "limit" BATCH_SIZE = "batch_size" ID = "id" +TYPE = "type" METRIC_TYPE = "metric_type" PARAMS = "params" DISTANCE = "distance" From 52c366c1cdad029f51db0e7509b9a2705f748d17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=B8=85=27=CF=89=27=E0=B8=85?= Date: Wed, 18 Dec 2024 16:54:39 +0800 Subject: [PATCH 2/2] fix: add `authorization_interceptor` and `db_interceptor` to async channel (#2467) Signed-off-by: Ruichen Bao --- examples/simple_async.py | 1 + pymilvus/client/async_grpc_handler.py | 46 ++++++++++++++++----------- pymilvus/client/async_interceptor.py | 2 +- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/examples/simple_async.py b/examples/simple_async.py index ee51f4389..91d7d681c 100644 --- a/examples/simple_async.py +++ b/examples/simple_async.py @@ -31,6 +31,7 @@ index_params.add_index(field_name = "embeddings", index_type = "HNSW", metric_type="L2", nlist=128) index_params.add_index(field_name = "embeddings2",index_type = "HNSW", metric_type="L2", nlist=128) +# Always use `await` when you want to guarantee the execution order of tasks. async def recreate_collection(): print(fmt.format("Start dropping collection")) await async_milvus_client.drop_collection(collection_name) diff --git a/pymilvus/client/async_grpc_handler.py b/pymilvus/client/async_grpc_handler.py index 794858dd7..1d8825be6 100644 --- a/pymilvus/client/async_grpc_handler.py +++ b/pymilvus/client/async_grpc_handler.py @@ -19,7 +19,7 @@ from pymilvus.grpc_gen import milvus_pb2_grpc from pymilvus.settings import Config -from . import entity_helper, interceptor, ts_utils, utils +from . import entity_helper, ts_utils, utils from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult from .async_interceptor import async_header_adder_interceptor from .check import ( @@ -62,8 +62,8 @@ def __init__( self._request_id = None self._user = kwargs.get("user") self._set_authorization(**kwargs) - self._setup_db_interceptor(kwargs.get("db_name")) - self._setup_grpc_channel() # init channel and stub + self._setup_db_name(kwargs.get("db_name")) + self._setup_grpc_channel(**kwargs) self.callbacks = [] def register_state_change_callback(self, callback: Callable): @@ -96,12 +96,7 @@ def _set_authorization(self, **kwargs): self._server_pem_path = kwargs.get("server_pem_path", "") self._server_name = kwargs.get("server_name", "") - self._authorization_interceptor = None - self._setup_authorization_interceptor( - kwargs.get("user"), - kwargs.get("password"), - kwargs.get("token"), - ) + self._async_authorization_interceptor = None def __enter__(self): return self @@ -132,7 +127,7 @@ def close(self): self._async_channel.close() def reset_db_name(self, db_name: str): - self._setup_db_interceptor(db_name) + self._setup_db_name(db_name) self._setup_grpc_channel() self._setup_identifier_interceptor(self._user) @@ -148,16 +143,19 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str) keys.append("authorization") values.append(authorization) if len(keys) > 0 and len(values) > 0: - self._authorization_interceptor = interceptor.header_adder_interceptor(keys, values) + self._async_authorization_interceptor = async_header_adder_interceptor(keys, values) + self._final_channel._unary_unary_interceptors.append( + self._async_authorization_interceptor + ) - def _setup_db_interceptor(self, db_name: str): + def _setup_db_name(self, db_name: str): if db_name is None: - self._db_interceptor = None + self._db_name = None else: check_pass_param(db_name=db_name) - self._db_interceptor = interceptor.header_adder_interceptor(["dbname"], [db_name]) + self._db_name = db_name - def _setup_grpc_channel(self): + def _setup_grpc_channel(self, **kwargs): if self._async_channel is None: opts = [ (cygrpc.ChannelArgKey.max_send_message_length, -1), @@ -203,21 +201,31 @@ def _setup_grpc_channel(self): # avoid to add duplicate headers. self._final_channel = self._async_channel - if self._log_level: + if self._async_authorization_interceptor: + self._final_channel._unary_unary_interceptors.append( + self._async_authorization_interceptor + ) + else: + self._setup_authorization_interceptor( + kwargs.get("user"), + kwargs.get("password"), + kwargs.get("token"), + ) + if self._db_name: + async_db_interceptor = async_header_adder_interceptor(["dbname"], [self._db_name]) + self._final_channel._unary_unary_interceptors.append(async_db_interceptor) + if self._log_level: async_log_level_interceptor = async_header_adder_interceptor( ["log_level"], [self._log_level] ) self._final_channel._unary_unary_interceptors.append(async_log_level_interceptor) - self._log_level = None if self._request_id: - async_request_id_interceptor = async_header_adder_interceptor( ["client_request_id"], [self._request_id] ) self._final_channel._unary_unary_interceptors.append(async_request_id_interceptor) - self._request_id = None self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) diff --git a/pymilvus/client/async_interceptor.py b/pymilvus/client/async_interceptor.py index c456a44f2..db96b416f 100644 --- a/pymilvus/client/async_interceptor.py +++ b/pymilvus/client/async_interceptor.py @@ -72,7 +72,7 @@ async def intercept_stream_stream( return await continuation(new_details, new_request_iterator) -def async_header_adder_interceptor(headers: List[str], values: List[str]): +def async_header_adder_interceptor(headers: List[str], values: Union[List[str], List[bytes]]): def intercept_call(client_call_details: ClientCallDetails, request: Any): metadata = [] if client_call_details.metadata: