Skip to content

Commit

Permalink
add check db collection code
Browse files Browse the repository at this point in the history
  • Loading branch information
MrPresent-Han committed Jan 13, 2025
1 parent e04c169 commit a194c0d
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 21 deletions.
2 changes: 2 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
COLLECTION_ID = "collection_id"
DB_NAME = "db_name"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
Expand Down
4 changes: 3 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
is_legal_host,
is_legal_port,
)
from .constants import ITERATOR_SESSION_TS_FIELD
from .constants import ITERATOR_SESSION_TS_FIELD, COLLECTION_ID, DB_NAME
from .prepare import Prepare
from .types import (
BulkInsertState,
Expand Down Expand Up @@ -1625,6 +1625,8 @@ def query(

extra_dict = get_cost_extra(response.status)
extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts
extra_dict[COLLECTION_ID] = response.collection_id
extra_dict[DB_NAME] = response.db_name
return ExtraList(results, extra=extra_dict)

@retry_on_rpc_failure()
Expand Down
7 changes: 7 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
ITERATOR_COLLECTION_ID,
PAGE_RETAIN_ORDER_FIELD,
RANK_GROUP_SCORER,
REDUCE_STOP_FOR_BEST,
Expand Down Expand Up @@ -1327,6 +1328,12 @@ def query_request(
common_types.KeyValuePair(key=ITERATOR_FIELD, value=is_iterator)
)

iterator_collection_id = kwargs.get(ITERATOR_COLLECTION_ID)
if is_iterator is not None:
req.query_params.append(
common_types.KeyValuePair(key=ITERATOR_COLLECTION_ID, value=str(iterator_collection_id))
)

req.query_params.append(
common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing))
)
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
COLLECTION_ID = "collection_id"
DB_NAME = "db_name"
DEFAULT_MAX_L2_DISTANCE = 99999999.0
DEFAULT_MIN_IP_DISTANCE = -99999999.0
DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0
Expand All @@ -65,3 +67,4 @@
ITERATOR_SESSION_CP_FILE = "iterator_cp_file"
BM25_k1 = "bm25_k1"
BM25_b = "bm25_b"

54 changes: 34 additions & 20 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from pymilvus.grpc_gen import milvus_pb2 as milvus_types

from .connections import Connections
from .connections import Connections, connections
from .constants import (
BATCH_SIZE,
CALC_DIST_BM25,
Expand All @@ -29,6 +29,8 @@
GUARANTEE_TIMESTAMP,
INT64_MAX,
IS_PRIMARY,
COLLECTION_ID,
DB_NAME,
ITERATOR_FIELD,
ITERATOR_SESSION_CP_FILE,
ITERATOR_SESSION_TS_FIELD,
Expand Down Expand Up @@ -101,7 +103,6 @@ def __init__(
) -> QueryIterator:
self._conn = connection
self._collection_name = collection_name
self.__set_up_collection_id()
self._output_fields = output_fields
self._partition_names = partition_names
self._schema = schema
Expand All @@ -121,9 +122,37 @@ def __init__(
self.__set_up_ts_cp()
self.__seek_to_offset()

def __set_up_collection_id(self):
col_desc = self._conn.describe_collection(self._collection_name, timeout=60.0)
self._collection_id = col_desc.get("collection_id")
def __query_request(self,
collection_name: str,
expr: Optional[str] = None,
output_fields: Optional[List[str]] = None,
partition_names: Optional[List[str]] = None,
timeout: Optional[float] = None,
**kwargs):
# set db name and collection_id is existed
if self._collection_id is not None:
kwargs[COLLECTION_ID] = self._collection_id
if self._db_name is not None:
kwargs[DB_NAME] = self._db_name

# query
res = self._conn.query(
collection_name=collection_name,
expr=expr,
output_field=output_fields,
partition_name=partition_names,
timeout=timeout,
**kwargs,
)

# reset db_name and collection_id if existed
collection_id = res.extra.get(COLLECTION_ID, 0)
if collection_id> 0 and self._collection_id is None:
self._collection_id = collection_id
db_name = res.extra.get(DB_NAME, "")
if db_name is not None and db_name != "" and self._db_name is None:
self._db_name = db_name
return res

def __seek_to_offset(self):
# read pk cursor from cp file, no need to seek offset
Expand All @@ -147,7 +176,6 @@ def seek_offset_by_batch(batch: int, expr: str) -> int:
timeout=self._timeout,
**seek_params,
)
self.__check_collection_match(res)
self.__update_cursor(res)
return len(res)

Expand Down Expand Up @@ -340,7 +368,6 @@ def next(self):
ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))]

ret = self.__check_reached_limit(ret)
self.__check_collection_match(res)
self.__update_cursor(ret)
io_operation(self.__save_pk_cursor, "failed to save pk cursor")
self._returned_count += len(ret)
Expand Down Expand Up @@ -381,19 +408,6 @@ def __setup_next_expr(self) -> str:
return filtered_pk_str
return "(" + current_expr + ")" + " and " + filtered_pk_str

def __check_collection_match(self, res: List):
res_collection_id = res[-1]["collection_id"]
if (
res_collection_id is not None
and res_collection_id > 0
and res_collection_id != self._collection_id
):
raise MilvusException(
message="collection_id in the result is not the "
"same as the inited collection id, the alias may be changed, cut off"
"iterator connection"
)

def __update_cursor(self, res: List) -> None:
if len(res) == 0:
return
Expand Down

0 comments on commit a194c0d

Please sign in to comment.