diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index 0d31c7af1..f0929210c 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -16,6 +16,8 @@ STRICT_GROUP_SIZE = "strict_group_size" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +COLLECTION_ID = "collection_id" +DB_NAME = "db_name" ITER_SEARCH_V2_KEY = "search_iter_v2" ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size" ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound" diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 258711125..15b1229e1 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -37,7 +37,7 @@ is_legal_host, is_legal_port, ) -from .constants import ITERATOR_SESSION_TS_FIELD +from .constants import ITERATOR_SESSION_TS_FIELD, COLLECTION_ID, DB_NAME from .prepare import Prepare from .types import ( BulkInsertState, @@ -1625,6 +1625,8 @@ def query( extra_dict = get_cost_extra(response.status) extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts + extra_dict[COLLECTION_ID] = response.collection_id + extra_dict[DB_NAME] = response.db_name return ExtraList(results, extra=extra_dict) @retry_on_rpc_failure() diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index f0210b69d..0b6a8213e 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -25,6 +25,7 @@ ITER_SEARCH_LAST_BOUND_KEY, ITER_SEARCH_V2_KEY, ITERATOR_FIELD, + ITERATOR_COLLECTION_ID, PAGE_RETAIN_ORDER_FIELD, RANK_GROUP_SCORER, REDUCE_STOP_FOR_BEST, @@ -1327,6 +1328,12 @@ def query_request( common_types.KeyValuePair(key=ITERATOR_FIELD, value=is_iterator) ) + iterator_collection_id = kwargs.get(ITERATOR_COLLECTION_ID) + if is_iterator is not None: + req.query_params.append( + common_types.KeyValuePair(key=ITERATOR_COLLECTION_ID, value=str(iterator_collection_id)) + ) + req.query_params.append( common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing)) ) diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index cf6e2ebd4..88577d5c8 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -49,6 +49,8 @@ REDUCE_STOP_FOR_BEST = "reduce_stop_for_best" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +COLLECTION_ID = "collection_id" +DB_NAME = "db_name" DEFAULT_MAX_L2_DISTANCE = 99999999.0 DEFAULT_MIN_IP_DISTANCE = -99999999.0 DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0 @@ -65,3 +67,4 @@ ITERATOR_SESSION_CP_FILE = "iterator_cp_file" BM25_k1 = "bm25_k1" BM25_b = "bm25_b" + diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 7609309b4..ba16cd80e 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -13,7 +13,7 @@ ) from pymilvus.grpc_gen import milvus_pb2 as milvus_types -from .connections import Connections +from .connections import Connections, connections from .constants import ( BATCH_SIZE, CALC_DIST_BM25, @@ -29,6 +29,8 @@ GUARANTEE_TIMESTAMP, INT64_MAX, IS_PRIMARY, + COLLECTION_ID, + DB_NAME, ITERATOR_FIELD, ITERATOR_SESSION_CP_FILE, ITERATOR_SESSION_TS_FIELD, @@ -101,7 +103,6 @@ def __init__( ) -> QueryIterator: self._conn = connection self._collection_name = collection_name - self.__set_up_collection_id() self._output_fields = output_fields self._partition_names = partition_names self._schema = schema @@ -121,9 +122,37 @@ def __init__( self.__set_up_ts_cp() self.__seek_to_offset() - def __set_up_collection_id(self): - col_desc = self._conn.describe_collection(self._collection_name, timeout=60.0) - self._collection_id = col_desc.get("collection_id") + def __query_request(self, + collection_name: str, + expr: Optional[str] = None, + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs): + # set db name and collection_id is existed + if self._collection_id is not None: + kwargs[COLLECTION_ID] = self._collection_id + if self._db_name is not None: + kwargs[DB_NAME] = self._db_name + + # query + res = self._conn.query( + collection_name=collection_name, + expr=expr, + output_field=output_fields, + partition_name=partition_names, + timeout=timeout, + **kwargs, + ) + + # reset db_name and collection_id if existed + collection_id = res.extra.get(COLLECTION_ID, 0) + if collection_id> 0 and self._collection_id is None: + self._collection_id = collection_id + db_name = res.extra.get(DB_NAME, "") + if db_name is not None and db_name != "" and self._db_name is None: + self._db_name = db_name + return res def __seek_to_offset(self): # read pk cursor from cp file, no need to seek offset @@ -147,7 +176,6 @@ def seek_offset_by_batch(batch: int, expr: str) -> int: timeout=self._timeout, **seek_params, ) - self.__check_collection_match(res) self.__update_cursor(res) return len(res) @@ -340,7 +368,6 @@ def next(self): ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))] ret = self.__check_reached_limit(ret) - self.__check_collection_match(res) self.__update_cursor(ret) io_operation(self.__save_pk_cursor, "failed to save pk cursor") self._returned_count += len(ret) @@ -381,19 +408,6 @@ def __setup_next_expr(self) -> str: return filtered_pk_str return "(" + current_expr + ")" + " and " + filtered_pk_str - def __check_collection_match(self, res: List): - res_collection_id = res[-1]["collection_id"] - if ( - res_collection_id is not None - and res_collection_id > 0 - and res_collection_id != self._collection_id - ): - raise MilvusException( - message="collection_id in the result is not the " - "same as the inited collection id, the alias may be changed, cut off" - "iterator connection" - ) - def __update_cursor(self, res: List) -> None: if len(res) == 0: return