Skip to content

Commit

Permalink
fix: iterator alias mismatch(#2555)
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Jan 13, 2025
1 parent 3b236f0 commit c2631f9
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 13 deletions.
10 changes: 10 additions & 0 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ def __init__(
round_decimal: Optional[int] = None,
status: Optional[common_pb2.Status] = None,
session_ts: Optional[int] = 0,
collection_id: Optional[int] = 0,
db_name: Optional[str] = None,
):
self._nq = res.num_queries
all_topks = res.topks
Expand Down Expand Up @@ -505,8 +507,16 @@ def __init__(
nq_thres += topk
self._session_ts = session_ts
self._search_iterator_v2_results = res.search_iterator_v2_results
self._collection_id = collection_id
self._db_name = db_name
super().__init__(data)

def get_collection_id(self):
    """Return the collection id stored on this search result."""
    collection_id = self._collection_id
    return collection_id

def get_db_name(self):
    """Return the database name stored on this search result (may be None)."""
    db_name = self._db_name
    return db_name

def get_session_ts(self):
    """Return the session timestamp stored on this search result."""
    session_ts = self._session_ts
    return session_ts

Expand Down
2 changes: 2 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
COLLECTION_ID = "collection_id"
DB_NAME = "db_name"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
Expand Down
6 changes: 5 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
is_legal_host,
is_legal_port,
)
from .constants import ITERATOR_SESSION_TS_FIELD
from .constants import COLLECTION_ID, DB_NAME, ITERATOR_SESSION_TS_FIELD
from .prepare import Prepare
from .types import (
BulkInsertState,
Expand Down Expand Up @@ -770,6 +770,8 @@ def _execute_search(
round_decimal,
status=response.status,
session_ts=response.session_ts,
collection_id=response.collection_id,
db_name=response.db_name,
)
except Exception as e:
if kwargs.get("_async", False):
Expand Down Expand Up @@ -1625,6 +1627,8 @@ def query(

extra_dict = get_cost_extra(response.status)
extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts
extra_dict[COLLECTION_ID] = response.collection_id
extra_dict[DB_NAME] = response.db_name
return ExtraList(results, extra=extra_dict)

@retry_on_rpc_failure()
Expand Down
20 changes: 20 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from . import __version__, blob, check, entity_helper, ts_utils, utils
from .check import check_pass_param, is_legal_collection_properties
from .constants import (
COLLECTION_ID,
DB_NAME,
DEFAULT_CONSISTENCY_LEVEL,
DYNAMIC_FIELD_NAME,
GROUP_BY_FIELD,
Expand Down Expand Up @@ -961,6 +963,14 @@ def search_requests_with_expr(
if is_iterator is not None:
search_params[ITERATOR_FIELD] = is_iterator

db_name = kwargs.get(DB_NAME)
if db_name is not None:
search_params[DB_NAME] = db_name

collection_id = kwargs.get(COLLECTION_ID)
if collection_id is not None:
search_params[COLLECTION_ID] = collection_id

is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY)
if is_search_iter_v2 is not None:
search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2
Expand Down Expand Up @@ -1327,6 +1337,16 @@ def query_request(
common_types.KeyValuePair(key=ITERATOR_FIELD, value=is_iterator)
)

collection_id = kwargs.get(COLLECTION_ID)
if collection_id is not None:
req.query_params.append(
common_types.KeyValuePair(key=COLLECTION_ID, value=str(collection_id))
)

db_name = kwargs.get(DB_NAME)
if db_name is not None:
req.query_params.append(common_types.KeyValuePair(key=DB_NAME, value=str(db_name)))

req.query_params.append(
common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing))
)
Expand Down
2 changes: 2 additions & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
COLLECTION_ID = "collection_id"
DB_NAME = "db_name"
DEFAULT_MAX_L2_DISTANCE = 99999999.0
DEFAULT_MIN_IP_DISTANCE = -99999999.0
DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0
Expand Down
82 changes: 70 additions & 12 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
CALC_DIST_JACCARD,
CALC_DIST_L2,
CALC_DIST_TANIMOTO,
COLLECTION_ID,
DB_NAME,
DEFAULT_SEARCH_EXTENSION_RATE,
EF,
FIELDS,
Expand Down Expand Up @@ -101,6 +103,8 @@ def __init__(
) -> QueryIterator:
self._conn = connection
self._collection_name = collection_name
self._collection_id = 0
self._db_name = ""
self._output_fields = output_fields
self._partition_names = partition_names
self._schema = schema
Expand All @@ -120,6 +124,40 @@ def __init__(
self.__set_up_ts_cp()
self.__seek_to_offset()

def __query_request(
    self,
    collection_name: str,
    expr: Optional[str] = None,
    output_fields: Optional[List[str]] = None,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """Run one query through the connection and remember the collection identity.

    Once a response has reported a collection id / db name, every later
    request made through this helper carries them in ``kwargs`` under
    ``COLLECTION_ID`` / ``DB_NAME``.

    Args:
        collection_name: Name of the collection to query.
        expr: Filter expression forwarded to the connection.
        output_fields: Fields to return, forwarded as ``output_fields``.
        partition_names: Partitions to query, forwarded as ``partition_names``.
        timeout: Timeout forwarded to the connection.
        **kwargs: Extra query parameters forwarded unchanged.

    Returns:
        Whatever ``self._conn.query`` returns (``None`` is passed through so
        that callers can perform their own ``None`` check).
    """
    # Attach the previously learned collection id / db name, if any.
    if self._collection_id > 0:
        kwargs[COLLECTION_ID] = self._collection_id
    if self._db_name != "":
        kwargs[DB_NAME] = self._db_name

    # The keyword names must be the plural forms used by the other direct
    # ``self._conn.query`` call in this file; singular ``output_field`` /
    # ``partition_name`` would just be swallowed by ``**kwargs``.
    res = self._conn.query(
        collection_name=collection_name,
        expr=expr,
        output_fields=output_fields,
        partition_names=partition_names,
        timeout=timeout,
        **kwargs,
    )

    # Let the caller handle a missing response instead of failing on
    # ``res.extra`` here.
    if res is None:
        return res

    # Record the collection id / db name from the first response that
    # carries them; values that are already set are never overwritten.
    collection_id = res.extra.get(COLLECTION_ID, 0) or 0
    if collection_id > 0 and self._collection_id == 0:
        self._collection_id = collection_id
    db_name = res.extra.get(DB_NAME, "")
    if db_name is not None and db_name != "" and self._db_name == "":
        self._db_name = db_name
    return res

def __seek_to_offset(self):
# read pk cursor from cp file, no need to seek offset
if self._next_id is not None:
Expand All @@ -134,11 +172,11 @@ def __seek_to_offset(self):

def seek_offset_by_batch(batch: int, expr: str) -> int:
seek_params[MILVUS_LIMIT] = batch
res = self._conn.query(
collection_name=self._collection_name,
res = self.__query_request(
self._collection_name,
expr=expr,
output_field=[],
partition_name=self._partition_names,
output_fields=[],
partition_names=self._partition_names,
timeout=self._timeout,
**seek_params,
)
Expand Down Expand Up @@ -234,14 +272,15 @@ def __setup_ts_by_request(self):
init_ts_kwargs[OFFSET] = 0
init_ts_kwargs[MILVUS_LIMIT] = 1
# just to set up mvccTs for iterator, no need correct limit
res = self._conn.query(
collection_name=self._collection_name,
res = self.__query_request(
self._collection_name,
expr=self._expr,
output_field=self._output_fields,
partition_name=self._partition_names,
output_fields=[],
partition_names=self._partition_names,
timeout=self._timeout,
**init_ts_kwargs,
)

if res is None:
raise MilvusException(
message="failed to connect to milvus for setting up "
Expand Down Expand Up @@ -322,14 +361,15 @@ def next(self):
iterator_cache.release_cache(self._cache_id_in_use)
current_expr = self.__setup_next_expr()
log.debug(f"query_iterator_next_expr:{current_expr}")
res = self._conn.query(
collection_name=self._collection_name,
res = self.__query_request(
self._collection_name,
expr=current_expr,
output_fields=self._output_fields,
partition_names=self._partition_names,
timeout=self._timeout,
**self._kwargs,
)

self.__maybe_cache(res)
ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))]

Expand Down Expand Up @@ -406,13 +446,27 @@ class SearchPage(LoopBase):
"""Since we only support nq=1 in search iteration, so search iteration response
should be different from raw response of search operation"""

def __init__(self, res: Hits, session_ts: Optional[int] = 0):
def __init__(
    self,
    res: Hits,
    session_ts: Optional[int] = 0,
    collection_id: Optional[int] = 0,
    db_name: Optional[str] = None,
):
    """Build a single-page wrapper around one search response.

    Args:
        res: Hits of the page; when ``None`` the page holds no results.
        session_ts: Session timestamp to expose via ``get_session_ts``.
        collection_id: Collection id to expose via ``get_collection_id``.
        db_name: Database name to expose via ``get_db_name``.
    """
    super().__init__()
    # Metadata carried alongside the hits.
    self._db_name = db_name
    self._collection_id = collection_id
    self._session_ts = session_ts
    # At most one entry: the hits of this page, when provided.
    self._results = [] if res is None else [res]

def get_collection_id(self):
    """Return the collection id this page was created with."""
    collection_id = self._collection_id
    return collection_id

def get_db_name(self):
    """Return the database name this page was created with (may be None)."""
    db_name = self._db_name
    return db_name

def get_session_ts(self):
    """Return the session timestamp this page was created with."""
    session_ts = self._session_ts
    return session_ts

Expand Down Expand Up @@ -515,6 +569,10 @@ def __init__(

def __init_search_iterator(self):
init_page = self.__execute_next_search(self._param, self._expr, False)
self._db_name = init_page.get_db_name()
self._collection_id = init_page.get_collection_id()
self._kwargs[COLLECTION_ID] = self._collection_id
self._kwargs[DB_NAME] = self._db_name
self._session_ts = init_page.get_session_ts()
if self._session_ts <= 0:
log.warning("failed to set up mvccTs from milvus server, use client-side ts instead")
Expand Down Expand Up @@ -736,7 +794,7 @@ def __execute_next_search(
schema=self._schema,
**self._kwargs,
)
return SearchPage(res[0], res.get_session_ts())
return SearchPage(res[0], res.get_session_ts(), res.get_collection_id(), res.get_db_name())

# at present, the range_filter parameter means 'larger/less and equal',
# so there would be vectors with same distances returned multiple times in different pages
Expand Down

0 comments on commit c2631f9

Please sign in to comment.