Skip to content

Commit

Permalink
enhance search iterator extension
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Nov 24, 2023
1 parent 9075512 commit 2d915b6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@
MAX_BATCH_SIZE: int = 16384
DEFAULT_SEARCH_EXTENSION_RATE: int = 10
UNLIMITED: int = -1
MAX_TRY_TIME: int = 10
MAX_TRY_TIME: int = 20
60 changes: 41 additions & 19 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@
SearchIterator = TypeVar("SearchIterator")


def extend_batch_size(batch_size: int, next_param: dict) -> int:
def extend_batch_size(batch_size: int, next_param: dict, to_extend_batch_size: bool) -> int:
extend_rate = 1
if to_extend_batch_size:
extend_rate = DEFAULT_SEARCH_EXTENSION_RATE
if EF in next_param[PARAMS]:
return min(
MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE, next_param[PARAMS][EF]
)
return min(MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE)
return min(MAX_BATCH_SIZE, batch_size * extend_rate, next_param[PARAMS][EF])
return min(MAX_BATCH_SIZE, batch_size * extend_rate)


class QueryIterator:
Expand Down Expand Up @@ -314,7 +315,7 @@ def __init__(
self.__init_search_iterator()

def __init_search_iterator(self):
init_page = self.__execute_next_search(self._param, self._expr)
init_page = self.__execute_next_search(self._param, self._expr, False)
if len(init_page) == 0:
message = (
"Cannot init search iterator because init page contains no matched rows, "
Expand All @@ -329,17 +330,28 @@ def __init_search_iterator(self):
self.__update_filtered_ids(init_page)
self._init_success = True

def __set_up_range_parameters(self, page: SearchPage):
def __update_width(self, page: SearchPage):
first_hit, last_hit = page[0], page[-1]
if metrics_positive_related(self._param[METRIC_TYPE]):
self._width = last_hit.distance - first_hit.distance
else:
self._width = first_hit.distance - last_hit.distance
self._tail_band = last_hit.distance
self._width *= 2
# commonly, we extend width to twice to avoid one more time search

def __set_up_range_parameters(self, page: SearchPage):
self.__update_width(page)
self._tail_band = page[-1].distance
LOGGER.debug(
f"set up init parameter for searchIterator width:{self._width} tail_band:{self._tail_band}"
)

def __check_reached_limit(self) -> bool:
if self._limit == UNLIMITED or self._returned_count < self._limit:
return False
LOGGER.debug(
f"reached search limit:{self._limit}, returned_count:{self._returned_count}, directly return"
)
return True

def __check_set_params(self, param: Dict):
Expand Down Expand Up @@ -473,6 +485,8 @@ def next(self):
cached_page_len = self.__push_new_page_to_cache(new_page)
ret_len = min(cached_page_len, ret_len)
ret_page = self.__extract_page_from_cache(ret_len)
if len(ret_page) == self._iterator_params[BATCH_SIZE]:
self.__update_width(ret_page)

# 3. update filter ids to avoid returning result repeatedly
self._returned_count += ret_len
Expand All @@ -485,29 +499,31 @@ def __try_search_fill(self) -> SearchPage:
while True:
next_params = self.__next_params(coefficient)
next_expr = self.__filtered_duplicated_result_expr(self._expr)
new_page = self.__execute_next_search(next_params, next_expr)
new_page = self.__execute_next_search(next_params, next_expr, True)
self.__update_filtered_ids(new_page)
try_time += 1
if len(new_page) > 0:
final_page.merge(new_page.get_res())
self._tail_band = new_page[-1].distance
# if the current ring contains vectors, we always set coefficient back to 1
coefficient = 1
else:
# if there's a ring containing no vectors matched, then we need to extend
# the ring continually to avoid empty ring problem
coefficient += 1
if len(final_page) > self._iterator_params[BATCH_SIZE] or try_time > MAX_TRY_TIME:
if len(final_page) >= self._iterator_params[BATCH_SIZE]:
break
if try_time > MAX_TRY_TIME:
LOGGER.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break")
break
# if there's a ring containing no vectors matched, then we need to extend
# the ring continually to avoid empty ring problem
coefficient += 1
return final_page

def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage:
def __execute_next_search(
self, next_params: dict, next_expr: str, to_extend_batch: bool
) -> SearchPage:
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
self._iterator_params["ann_field"],
next_params,
extend_batch_size(self._iterator_params[BATCH_SIZE], next_params),
extend_batch_size(self._iterator_params[BATCH_SIZE], next_params, to_extend_batch),
next_expr,
self._iterator_params["partition_names"],
self._iterator_params["output_fields"],
Expand Down Expand Up @@ -542,7 +558,7 @@ def __filtered_duplicated_result_expr(self, expr: str):

def __next_params(self, coefficient: int):
coefficient = max(1, coefficient)
next_params = self._param.copy()
next_params = deepcopy(self._param)
if metrics_positive_related(self._param[METRIC_TYPE]):
next_radius = self._tail_band + self._width * coefficient
if RADIUS in self._param[PARAMS] and next_radius > self._param[PARAMS][RADIUS]:
Expand All @@ -556,6 +572,12 @@ def __next_params(self, coefficient: int):
else:
next_params[PARAMS][RADIUS] = next_radius
next_params[PARAMS][RANGE_FILTER] = self._tail_band
LOGGER.debug(
f"next round search iteration radius:{next_params[PARAMS][RADIUS]},"
f"range_filter:{next_params[PARAMS][RANGE_FILTER]},"
f"coefficient:{coefficient}"
)

return next_params

def close(self):
Expand Down

0 comments on commit 2d915b6

Please sign in to comment.