From 2d915b62c598a9c61ade0bc6a2917f6d72c45c7e Mon Sep 17 00:00:00 2001 From: MrPresent-Han Date: Thu, 23 Nov 2023 16:25:50 +0800 Subject: [PATCH] enhance search iterator extension --- pymilvus/orm/constants.py | 2 +- pymilvus/orm/iterator.py | 60 ++++++++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index b69ae5294..0b85d1cff 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -49,4 +49,4 @@ MAX_BATCH_SIZE: int = 16384 DEFAULT_SEARCH_EXTENSION_RATE: int = 10 UNLIMITED: int = -1 -MAX_TRY_TIME: int = 10 +MAX_TRY_TIME: int = 20 diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index bdc395e44..b39c4f06d 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -42,12 +42,13 @@ SearchIterator = TypeVar("SearchIterator") -def extend_batch_size(batch_size: int, next_param: dict) -> int: +def extend_batch_size(batch_size: int, next_param: dict, to_extend_batch_size: bool) -> int: + extend_rate = 1 + if to_extend_batch_size: + extend_rate = DEFAULT_SEARCH_EXTENSION_RATE if EF in next_param[PARAMS]: - return min( - MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE, next_param[PARAMS][EF] - ) - return min(MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE) + return min(MAX_BATCH_SIZE, batch_size * extend_rate, next_param[PARAMS][EF]) + return min(MAX_BATCH_SIZE, batch_size * extend_rate) class QueryIterator: @@ -314,7 +315,7 @@ def __init__( self.__init_search_iterator() def __init_search_iterator(self): - init_page = self.__execute_next_search(self._param, self._expr) + init_page = self.__execute_next_search(self._param, self._expr, False) if len(init_page) == 0: message = ( "Cannot init search iterator because init page contains no matched rows, " @@ -329,17 +330,28 @@ def __init_search_iterator(self): self.__update_filtered_ids(init_page) self._init_success = True - def __set_up_range_parameters(self, page: SearchPage): + def __update_width(self, page: SearchPage): first_hit, last_hit = page[0], page[-1] if metrics_positive_related(self._param[METRIC_TYPE]): self._width = last_hit.distance - first_hit.distance else: self._width = first_hit.distance - last_hit.distance - self._tail_band = last_hit.distance + self._width *= 2 + # commonly, we extend width to twice to avoid one more time search + + def __set_up_range_parameters(self, page: SearchPage): + self.__update_width(page) + self._tail_band = page[-1].distance + LOGGER.debug( + f"set up init parameter for searchIterator width:{self._width} tail_band:{self._tail_band}" + ) def __check_reached_limit(self) -> bool: if self._limit == UNLIMITED or self._returned_count < self._limit: return False + LOGGER.debug( + f"reached search limit:{self._limit}, returned_count:{self._returned_count}, directly return" + ) return True def __check_set_params(self, param: Dict): @@ -473,6 +485,8 @@ def next(self): cached_page_len = self.__push_new_page_to_cache(new_page) ret_len = min(cached_page_len, ret_len) ret_page = self.__extract_page_from_cache(ret_len) + if len(ret_page) == self._iterator_params[BATCH_SIZE]: + self.__update_width(ret_page) # 3. update filter ids to avoid returning result repeatedly self._returned_count += ret_len @@ -485,29 +499,31 @@ def __try_search_fill(self) -> SearchPage: while True: next_params = self.__next_params(coefficient) next_expr = self.__filtered_duplicated_result_expr(self._expr) - new_page = self.__execute_next_search(next_params, next_expr) + new_page = self.__execute_next_search(next_params, next_expr, True) self.__update_filtered_ids(new_page) try_time += 1 if len(new_page) > 0: final_page.merge(new_page.get_res()) self._tail_band = new_page[-1].distance - # if the current ring contains vectors, we always set coefficient back to 1 - coefficient = 1 - else: - # if there's a ring containing no vectors matched, then we need to extend - # the ring continually to avoid empty ring problem - coefficient += 1 - if len(final_page) > self._iterator_params[BATCH_SIZE] or try_time > MAX_TRY_TIME: + if len(final_page) >= self._iterator_params[BATCH_SIZE]: + break + if try_time > MAX_TRY_TIME: + LOGGER.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break") break + # if there's a ring containing no vectors matched, then we need to extend + # the ring continually to avoid empty ring problem + coefficient += 1 return final_page - def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage: + def __execute_next_search( + self, next_params: dict, next_expr: str, to_extend_batch: bool + ) -> SearchPage: res = self._conn.search( self._iterator_params["collection_name"], self._iterator_params["data"], self._iterator_params["ann_field"], next_params, - extend_batch_size(self._iterator_params[BATCH_SIZE], next_params), + extend_batch_size(self._iterator_params[BATCH_SIZE], next_params, to_extend_batch), next_expr, self._iterator_params["partition_names"], self._iterator_params["output_fields"], @@ -542,7 +558,7 @@ def __filtered_duplicated_result_expr(self, expr: str): def __next_params(self, coefficient: int): coefficient = max(1, coefficient) - next_params = self._param.copy() + next_params = deepcopy(self._param) if metrics_positive_related(self._param[METRIC_TYPE]): next_radius = self._tail_band + self._width * coefficient if RADIUS in self._param[PARAMS] and next_radius > self._param[PARAMS][RADIUS]: @@ -556,6 +572,12 @@ def __next_params(self, coefficient: int): else: next_params[PARAMS][RADIUS] = next_radius next_params[PARAMS][RANGE_FILTER] = self._tail_band + LOGGER.debug( + f"next round search iteration radius:{next_params[PARAMS][RADIUS]}," + f"range_filter:{next_params[PARAMS][RANGE_FILTER]}," + f"coefficient:{coefficient}" + ) + return next_params def close(self):