diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 918e3925a..b5a5c385f 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -471,6 +471,8 @@ def __init__( round_decimal: Optional[int] = None, status: Optional[common_pb2.Status] = None, session_ts: Optional[int] = 0, + collection_id: Optional[int] = 0, + db_name: Optional[str] = None, ): self._nq = res.num_queries all_topks = res.topks @@ -505,8 +507,16 @@ def __init__( nq_thres += topk self._session_ts = session_ts self._search_iterator_v2_results = res.search_iterator_v2_results + self._collection_id = collection_id + self._db_name = db_name super().__init__(data) + def get_collection_id(self): + return self._collection_id + + def get_db_name(self): + return self._db_name + def get_session_ts(self): return self._session_ts diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index 0d31c7af1..f0929210c 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -16,6 +16,8 @@ STRICT_GROUP_SIZE = "strict_group_size" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +COLLECTION_ID = "collection_id" +DB_NAME = "db_name" ITER_SEARCH_V2_KEY = "search_iter_v2" ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size" ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound" diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 258711125..560c31e99 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -37,7 +37,7 @@ is_legal_host, is_legal_port, ) -from .constants import ITERATOR_SESSION_TS_FIELD +from .constants import COLLECTION_ID, DB_NAME, ITERATOR_SESSION_TS_FIELD from .prepare import Prepare from .types import ( BulkInsertState, @@ -770,6 +770,8 @@ def _execute_search( round_decimal, status=response.status, session_ts=response.session_ts, + collection_id=response.collection_id, + db_name=response.db_name, ) except Exception as e: if kwargs.get("_async", False): @@ -1625,6 +1627,8 @@ def query( extra_dict = get_cost_extra(response.status) extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts + extra_dict[COLLECTION_ID] = response.collection_id + extra_dict[DB_NAME] = response.db_name return ExtraList(results, extra=extra_dict) @retry_on_rpc_failure() diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index f0210b69d..0d870c0ac 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -15,6 +15,8 @@ from . import __version__, blob, check, entity_helper, ts_utils, utils from .check import check_pass_param, is_legal_collection_properties from .constants import ( + COLLECTION_ID, + DB_NAME, DEFAULT_CONSISTENCY_LEVEL, DYNAMIC_FIELD_NAME, GROUP_BY_FIELD, @@ -961,6 +963,14 @@ def search_requests_with_expr( if is_iterator is not None: search_params[ITERATOR_FIELD] = is_iterator + db_name = kwargs.get(DB_NAME) + if db_name is not None: + search_params[DB_NAME] = db_name + + collection_id = kwargs.get(COLLECTION_ID) + if collection_id is not None: + search_params[COLLECTION_ID] = collection_id + is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY) if is_search_iter_v2 is not None: search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2 @@ -1327,6 +1337,16 @@ def query_request( common_types.KeyValuePair(key=ITERATOR_FIELD, value=is_iterator) ) + collection_id = kwargs.get(COLLECTION_ID) + if collection_id is not None: + req.query_params.append( + common_types.KeyValuePair(key=COLLECTION_ID, value=str(collection_id)) + ) + + db_name = kwargs.get(DB_NAME) + if db_name is not None: + req.query_params.append(common_types.KeyValuePair(key=DB_NAME, value=str(db_name))) + req.query_params.append( common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing)) ) diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index cf6e2ebd4..d2c2116ed 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -49,6 +49,8 @@ REDUCE_STOP_FOR_BEST = "reduce_stop_for_best" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +COLLECTION_ID = "collection_id" +DB_NAME = "db_name" DEFAULT_MAX_L2_DISTANCE = 99999999.0 DEFAULT_MIN_IP_DISTANCE = -99999999.0 DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0 diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 6bd6115a5..73f0263f0 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -23,6 +23,8 @@ CALC_DIST_JACCARD, CALC_DIST_L2, CALC_DIST_TANIMOTO, + COLLECTION_ID, + DB_NAME, DEFAULT_SEARCH_EXTENSION_RATE, EF, FIELDS, @@ -101,6 +103,8 @@ def __init__( ) -> QueryIterator: self._conn = connection self._collection_name = collection_name + self._collection_id = 0 + self._db_name = "" self._output_fields = output_fields self._partition_names = partition_names self._schema = schema @@ -120,6 +124,40 @@ def __init__( self.__set_up_ts_cp() self.__seek_to_offset() + def __query_request( + self, + collection_name: str, + expr: Optional[str] = None, + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): + # set db name and collection_id is existed + if self._collection_id > 0: + kwargs[COLLECTION_ID] = self._collection_id + if self._db_name != "": + kwargs[DB_NAME] = self._db_name + + # query + res = self._conn.query( + collection_name=collection_name, + expr=expr, + output_field=output_fields, + partition_name=partition_names, + timeout=timeout, + **kwargs, + ) + + # reset db_name and collection_id if existed + collection_id = res.extra.get(COLLECTION_ID, 0) + if collection_id > 0 and self._collection_id == 0: + self._collection_id = collection_id + db_name = res.extra.get(DB_NAME, "") + if db_name is not None and db_name != "" and self._db_name == "": + self._db_name = db_name + return res + def __seek_to_offset(self): # read pk cursor from cp file, no need to seek offset if self._next_id is not None: @@ -134,11 +172,11 @@ def __seek_to_offset(self): def seek_offset_by_batch(batch: int, expr: str) -> int: seek_params[MILVUS_LIMIT] = batch - res = self._conn.query( - collection_name=self._collection_name, + res = self.__query_request( + self._collection_name, expr=expr, - output_field=[], - partition_name=self._partition_names, + output_fields=[], + partition_names=self._partition_names, timeout=self._timeout, **seek_params, ) @@ -234,14 +272,15 @@ def __setup_ts_by_request(self): init_ts_kwargs[OFFSET] = 0 init_ts_kwargs[MILVUS_LIMIT] = 1 # just to set up mvccTs for iterator, no need correct limit - res = self._conn.query( - collection_name=self._collection_name, + res = self.__query_request( + self._collection_name, expr=self._expr, - output_field=self._output_fields, - partition_name=self._partition_names, + output_fields=[], + partition_names=self._partition_names, timeout=self._timeout, **init_ts_kwargs, ) + if res is None: raise MilvusException( message="failed to connect to milvus for setting up " @@ -322,14 +361,15 @@ def next(self): iterator_cache.release_cache(self._cache_id_in_use) current_expr = self.__setup_next_expr() log.debug(f"query_iterator_next_expr:{current_expr}") - res = self._conn.query( - collection_name=self._collection_name, + res = self.__query_request( + self._collection_name, expr=current_expr, output_fields=self._output_fields, partition_names=self._partition_names, timeout=self._timeout, **self._kwargs, ) + self.__maybe_cache(res) ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))] @@ -406,13 +446,27 @@ class SearchPage(LoopBase): """Since we only support nq=1 in search iteration, so search iteration response should be different from raw response of search operation""" - def __init__(self, res: Hits, session_ts: Optional[int] = 0): + def __init__( + self, + res: Hits, + session_ts: Optional[int] = 0, + collection_id: Optional[int] = 0, + db_name: Optional[str] = None, + ): super().__init__() self._session_ts = session_ts + self._collection_id = collection_id + self._db_name = db_name self._results = [] if res is not None: self._results.append(res) + def get_collection_id(self): + return self._collection_id + + def get_db_name(self): + return self._db_name + def get_session_ts(self): return self._session_ts @@ -515,6 +569,10 @@ def __init__( def __init_search_iterator(self): init_page = self.__execute_next_search(self._param, self._expr, False) + self._db_name = init_page.get_db_name() + self._collection_id = init_page.get_collection_id() + self._kwargs[COLLECTION_ID] = self._collection_id + self._kwargs[DB_NAME] = self._db_name self._session_ts = init_page.get_session_ts() if self._session_ts <= 0: log.warning("failed to set up mvccTs from milvus server, use client-side ts instead") @@ -736,7 +794,7 @@ def __execute_next_search( schema=self._schema, **self._kwargs, ) - return SearchPage(res[0], res.get_session_ts()) + return SearchPage(res[0], res.get_session_ts(), res.get_collection_id(), res.get_db_name()) # at present, the range_filter parameter means 'larger/less and equal', # so there would be vectors with same distances returned multiple times in different pages