diff --git a/pymilvus/client/search_iterator.py b/pymilvus/client/search_iterator.py index 0ed490950..aabedd9d1 100644 --- a/pymilvus/client/search_iterator.py +++ b/pymilvus/client/search_iterator.py @@ -22,9 +22,6 @@ class SearchIteratorV2: - # for compatibility, save the first result during init - _saved_first_res = None - _is_saved = False # for compatibility, track the number of total results left _left_res_cnt = None @@ -51,7 +48,7 @@ def __init__( self._left_res_cnt = limit self._conn = connection - self.__set_up_collection_id(collection_name) + self._set_up_collection_id(collection_name) kwargs[COLLECTION_ID] = self._collection_id self._params = { "collection_name": collection_name, @@ -70,35 +67,41 @@ def __init__( GUARANTEE_TIMESTAMP: 0, **kwargs, } - # this raises MilvusException if the server does not support V2 - self._saved_first_res = self.next() - self._is_saved = True + self._probe_for_compability(self._params) - def __set_up_collection_id(self, collection_name: str): + def _set_up_collection_id(self, collection_name: str): res = self._conn.describe_collection(collection_name) self._collection_id = res[COLLECTION_ID] + def _check_token_exists(self, token: Union[str, None]): + if token is None or token == "": + raise ServerVersionIncompatibleException( + message=ExceptionsMessage.SearchIteratorV2FallbackWarning + ) + + # this detects whether the server supports search_iterator_v2 and is for compatibility only + # if the server holds iterator states, this implementation needs to be reconsidered + def _probe_for_compability(self, params: Dict): + dummy_params = deepcopy(params) + dummy_batch_size = 1 + dummy_params["limit"] = dummy_batch_size + dummy_params[ITER_SEARCH_BATCH_SIZE_KEY] = dummy_batch_size + iter_info = self._conn.search(**dummy_params).get_search_iterator_v2_results_info() + self._check_token_exists(iter_info.token) + def next(self): - # for compatibility - if self._is_saved: - self._is_saved = False - return self._saved_first_res - self._saved_first_res = None if self._left_res_cnt is not None and self._left_res_cnt <= 0: return SearchPage(None) res = self._conn.search(**self._params) iter_info = res.get_search_iterator_v2_results_info() + self._check_token_exists(iter_info.token) self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound # patch token and guarantee timestamp for the first next() call if ITER_SEARCH_ID_KEY not in self._params: - if iter_info.token is not None and iter_info.token != "": - self._params[ITER_SEARCH_ID_KEY] = iter_info.token - else: - raise ServerVersionIncompatibleException( - message=ExceptionsMessage.SearchIteratorV2FallbackWarning - ) + # the token should not change during the lifetime of the iterator + self._params[ITER_SEARCH_ID_KEY] = iter_info.token if self._params[GUARANTEE_TIMESTAMP] <= 0: if res.get_session_ts() > 0: self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()