Commit

enhance: [2.4] Add schema update time verification to insert and upsert to use cache (#2597)

enhance: Add schema update time verification to insert and upsert to use
cache
issue: milvus-io/milvus#39093
pr: #2551

---------

Signed-off-by: Xianhui.Lin <xianhui.lin@zilliz.com>
JsDove authored Jan 24, 2025
1 parent 6d418d7 commit c34bb27
Showing 8 changed files with 478 additions and 409 deletions.
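
At a high level, the change works like this: DescribeCollectionResponse now reports an update_timestamp for the collection schema, GrpcHandler caches each collection's schema together with that timestamp, and every row insert/upsert is stamped with it as schema_timestamp. If the server's schema has changed since the cache was filled, the write comes back with the new SchemaMismatch error code, and the client refreshes its cache and retries once. A minimal, self-contained Python sketch of that pattern (every name except SchemaMismatch is a hypothetical stand-in for the handler internals in the diff below):

SCHEMA_MISMATCH = "SchemaMismatch"  # stands in for common_pb2.SchemaMismatch

class SchemaCache:
    """Toy version of the per-collection schema cache added to GrpcHandler."""

    def __init__(self, describe_collection):
        self._describe = describe_collection  # the remote describe call
        self._cache = {}  # collection_name -> {"schema", "schema_timestamp"}

    def get(self, collection_name):
        """Return (schema, schema_timestamp), fetching remotely on a miss."""
        entry = self._cache.get(collection_name)
        if entry is None:
            schema = self._describe(collection_name)
            entry = {
                "schema": schema,
                "schema_timestamp": schema.get("update_timestamp", 0),
            }
            self._cache[collection_name] = entry
        return entry["schema"], entry["schema_timestamp"]

    def refresh(self, collection_name):
        """Drop the cached entry and re-fetch; mirrors update_schema below."""
        self._cache.pop(collection_name, None)
        return self.get(collection_name)

def insert_with_retry(cache, send_insert, collection_name, rows):
    """Stamp the write with the cached schema_timestamp; on SchemaMismatch,
    refresh the cache and retry exactly once, as insert_rows now does."""
    _, ts = cache.get(collection_name)
    resp = send_insert(collection_name, rows, schema_timestamp=ts)
    if resp.get("error_code") == SCHEMA_MISMATCH:
        _, ts = cache.refresh(collection_name)
        resp = send_insert(collection_name, rows, schema_timestamp=ts)
    return resp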
3 changes: 1 addition & 2 deletions pymilvus/client/abstract.py
@@ -123,7 +123,7 @@ def __init__(self, raw: Any):
self.num_shards = 0
self.num_partitions = 0
self.enable_dynamic_field = False

self.update_timestamp = 0
if self._raw:
self.__pack(self._raw)

@@ -150,7 +150,6 @@ def __pack(self, raw: Any):
# for kv in raw.extra_params:

self.fields = [FieldSchema(f) for f in raw.schema.fields]

# for s in raw.statistics:

for p in raw.properties:
64 changes: 60 additions & 4 deletions pymilvus/client/grpc_handler.py
@@ -93,6 +93,7 @@ def __init__(
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel()
self.callbacks = []
self.schema_cache = {}

def register_state_change_callback(self, callback: Callable):
self.callbacks.append(callback)
@@ -160,6 +161,7 @@ def close(self):
self._channel.close()

def reset_db_name(self, db_name: str):
self.schema_cache.clear()
self._setup_db_interceptor(db_name)
self._setup_grpc_channel()
self._setup_identifier_interceptor(self._user)
@@ -525,10 +527,28 @@ def insert_rows(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
if resp.status.error_code == common_pb2.SchemaMismatch:
schema = self.update_schema(collection_name, timeout)
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
check_status(resp.status)
ts_utils.update_collection_ts(collection_name, resp.timestamp)
return MutationResult(resp)

def update_schema(self, collection_name: str, timeout: Optional[float] = None):
self.schema_cache.pop(collection_name, None)
schema = self.describe_collection(collection_name, timeout=timeout)
schema_timestamp = schema.get("update_timestamp", 0)

self.schema_cache[collection_name] = {
"schema": schema,
"schema_timestamp": schema_timestamp,
}

return schema

def _prepare_row_insert_request(
self,
collection_name: str,
@@ -541,9 +561,9 @@ def _prepare_row_insert_request(
if isinstance(entity_rows, dict):
entity_rows = [entity_rows]

if not isinstance(schema, dict):
schema = self.describe_collection(collection_name, timeout=timeout)

schema, schema_timestamp = self._get_schema_from_cache_or_remote(
collection_name, schema, timeout
)
fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)

@@ -553,8 +573,33 @@ def _prepare_row_insert_request(
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
schema_timestamp=schema_timestamp,
)

def _get_schema_from_cache_or_remote(
self, collection_name: str, schema: Optional[dict] = None, timeout: Optional[float] = None
):
"""
checks the cache for the schema. If not found, it fetches it remotely and updates the cache
"""
if collection_name in self.schema_cache:
# Use the cached schema and timestamp
schema = self.schema_cache[collection_name]["schema"]
schema_timestamp = self.schema_cache[collection_name]["schema_timestamp"]
else:
# Fetch the schema remotely if not in cache
if not isinstance(schema, dict):
schema = self.describe_collection(collection_name, timeout=timeout)
schema_timestamp = schema.get("update_timestamp", 0)

# Cache the fetched schema and timestamp
self.schema_cache[collection_name] = {
"schema": schema,
"schema_timestamp": schema_timestamp,
}

return schema, schema_timestamp
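
To see the retry path exercised, here is a self-contained simulation in which the schema changes between the describe and the write; only the control flow mirrors the diff, and the server-side staleness check is an assumption modeled with plain dicts:

server_ts = 100  # toy stand-in for the server's authoritative schema timestamp

def describe(name):
    # Shaped like describe_collection's result, including update_timestamp.
    return {"update_timestamp": server_ts, "fields": []}

def send_insert(name, rows, schema_timestamp):
    # Rejects writes stamped with a stale schema_timestamp, the behavior
    # the client-side retry in this commit is written against.
    if schema_timestamp != server_ts:
        return {"error_code": "SchemaMismatch"}
    return {"error_code": None, "insert_count": len(rows)}

cache = {}

def cached_ts(name, refresh=False):
    # One remote describe per collection unless a refresh is forced.
    if refresh or name not in cache:
        cache[name] = describe(name)["update_timestamp"]
    return cache[name]

ts = cached_ts("demo")
server_ts = 101  # another client alters the collection schema
resp = send_insert("demo", [{"id": 1}], schema_timestamp=ts)
if resp["error_code"] == "SchemaMismatch":
    ts = cached_ts("demo", refresh=True)
    resp = send_insert("demo", [{"id": 1}], schema_timestamp=ts)
assert resp["insert_count"] == 1  # retry with the fresh timestamp succeeds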

def _prepare_batch_insert_request(
self,
collection_name: str,
@@ -722,13 +767,18 @@ def _prepare_row_upsert_request(
if not isinstance(rows, list):
raise ParamError(message="'rows' must be a list, please provide valid row data.")

fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
schema, schema_timestamp = self._get_schema_from_cache_or_remote(
collection_name, timeout=timeout
)
fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)
return Prepare.row_upsert_param(
collection_name,
rows,
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
schema_timestamp=schema_timestamp,
)

@retry_on_rpc_failure()
@@ -747,6 +797,12 @@ def upsert_rows(
)
rf = self._stub.Upsert.future(request, timeout=timeout)
response = rf.result()
if response.status.error_code == common_pb2.SchemaMismatch:
schema = self.update_schema(collection_name, timeout)
request = self._prepare_row_upsert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
response = self._stub.Upsert(request=request, timeout=timeout)
check_status(response.status)
m = MutationResult(response)
ts_utils.update_collection_ts(collection_name, m.timestamp)
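
None of this changes the user-facing API: writes through the high-level client reuse the cache and recover from a mid-flight schema change transparently. A hedged end-to-end sketch against a local Milvus (the URI, collection name, and row data are assumptions, and the collection is presumed to already exist):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")  # assumed local deployment

# Inserts and upserts now carry the cached schema_timestamp; if another
# client has altered the collection schema, the first write after the
# change is rejected with SchemaMismatch and retried once with a freshly
# fetched schema.
client.insert("demo", [{"id": 1, "vector": [0.1, 0.2, 0.3]}])
client.upsert("demo", [{"id": 1, "vector": [0.3, 0.2, 0.1]}])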
4 changes: 4 additions & 0 deletions pymilvus/client/prepare.py
@@ -516,6 +516,7 @@ def row_insert_param(
entities: List,
partition_name: str,
fields_info: Dict,
schema_timestamp: int = 0,
enable_dynamic: bool = False,
):
if not fields_info:
@@ -527,6 +528,7 @@ def row_insert_param(
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
schema_timestamp=schema_timestamp,
)

return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -539,6 +541,7 @@ def row_upsert_param(
partition_name: str,
fields_info: Any,
enable_dynamic: bool = False,
schema_timestamp: int = 0,
):
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")
@@ -549,6 +552,7 @@ def row_upsert_param(
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
schema_timestamp=schema_timestamp,
)

return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities)
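
Both Prepare helpers forward the value into the generated request message, so the new field is visible when constructing one directly; a minimal sketch (field values are placeholders):

from pymilvus.grpc_gen import milvus_pb2

# InsertRequest (and UpsertRequest) now carry schema_timestamp; the
# Prepare helpers above set it from the cached value, defaulting to 0.
request = milvus_pb2.InsertRequest(
    collection_name="demo",
    partition_name="_default",
    num_rows=1,
    schema_timestamp=42,  # placeholder for a cached update_timestamp
)
print(request.schema_timestamp)  # -> 42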
56 changes: 28 additions & 28 deletions pymilvus/grpc_gen/common_pb2.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pymilvus/grpc_gen/common_pb2.pyi
@@ -70,6 +70,7 @@ class ErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
DropPrivilegeGroupFailure: _ClassVar[ErrorCode]
ListPrivilegeGroupsFailure: _ClassVar[ErrorCode]
OperatePrivilegeGroupFailure: _ClassVar[ErrorCode]
SchemaMismatch: _ClassVar[ErrorCode]
DataCoordNA: _ClassVar[ErrorCode]
DDRequestRace: _ClassVar[ErrorCode]

@@ -409,6 +410,7 @@ CreatePrivilegeGroupFailure: ErrorCode
DropPrivilegeGroupFailure: ErrorCode
ListPrivilegeGroupsFailure: ErrorCode
OperatePrivilegeGroupFailure: ErrorCode
SchemaMismatch: ErrorCode
DataCoordNA: ErrorCode
DDRequestRace: ErrorCode
IndexStateNone: IndexState
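
The regenerated stubs expose the new error code at module level, so client code can compare a response status against it directly; a minimal check, assuming the generated stubs are importable:

from pymilvus.grpc_gen import common_pb2

# SchemaMismatch is a new ErrorCode member; insert_rows and upsert_rows
# compare resp.status.error_code against it to decide whether to retry.
print(common_pb2.ErrorCode.Name(common_pb2.SchemaMismatch))  # -> "SchemaMismatch"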
2 changes: 1 addition & 1 deletion pymilvus/grpc_gen/milvus-proto
736 changes: 368 additions & 368 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

20 changes: 14 additions & 6 deletions pymilvus/grpc_gen/milvus_pb2.pyi
@@ -245,7 +245,7 @@ class DescribeCollectionRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., collectionID: _Optional[int] = ..., time_stamp: _Optional[int] = ...) -> None: ...

class DescribeCollectionResponse(_message.Message):
__slots__ = ("status", "schema", "collectionID", "virtual_channel_names", "physical_channel_names", "created_timestamp", "created_utc_timestamp", "shards_num", "aliases", "start_positions", "consistency_level", "collection_name", "properties", "db_name", "num_partitions", "db_id")
__slots__ = ("status", "schema", "collectionID", "virtual_channel_names", "physical_channel_names", "created_timestamp", "created_utc_timestamp", "shards_num", "aliases", "start_positions", "consistency_level", "collection_name", "properties", "db_name", "num_partitions", "db_id", "request_time", "update_timestamp")
STATUS_FIELD_NUMBER: _ClassVar[int]
SCHEMA_FIELD_NUMBER: _ClassVar[int]
COLLECTIONID_FIELD_NUMBER: _ClassVar[int]
@@ -262,6 +262,8 @@ class DescribeCollectionResponse(_message.Message):
DB_NAME_FIELD_NUMBER: _ClassVar[int]
NUM_PARTITIONS_FIELD_NUMBER: _ClassVar[int]
DB_ID_FIELD_NUMBER: _ClassVar[int]
REQUEST_TIME_FIELD_NUMBER: _ClassVar[int]
UPDATE_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
schema: _schema_pb2.CollectionSchema
collectionID: int
@@ -278,7 +280,9 @@ class DescribeCollectionResponse(_message.Message):
db_name: str
num_partitions: int
db_id: int
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., schema: _Optional[_Union[_schema_pb2.CollectionSchema, _Mapping]] = ..., collectionID: _Optional[int] = ..., virtual_channel_names: _Optional[_Iterable[str]] = ..., physical_channel_names: _Optional[_Iterable[str]] = ..., created_timestamp: _Optional[int] = ..., created_utc_timestamp: _Optional[int] = ..., shards_num: _Optional[int] = ..., aliases: _Optional[_Iterable[str]] = ..., start_positions: _Optional[_Iterable[_Union[_common_pb2.KeyDataPair, _Mapping]]] = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., collection_name: _Optional[str] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., db_name: _Optional[str] = ..., num_partitions: _Optional[int] = ..., db_id: _Optional[int] = ...) -> None: ...
request_time: int
update_timestamp: int
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., schema: _Optional[_Union[_schema_pb2.CollectionSchema, _Mapping]] = ..., collectionID: _Optional[int] = ..., virtual_channel_names: _Optional[_Iterable[str]] = ..., physical_channel_names: _Optional[_Iterable[str]] = ..., created_timestamp: _Optional[int] = ..., created_utc_timestamp: _Optional[int] = ..., shards_num: _Optional[int] = ..., aliases: _Optional[_Iterable[str]] = ..., start_positions: _Optional[_Iterable[_Union[_common_pb2.KeyDataPair, _Mapping]]] = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., collection_name: _Optional[str] = ..., properties: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., db_name: _Optional[str] = ..., num_partitions: _Optional[int] = ..., db_id: _Optional[int] = ..., request_time: _Optional[int] = ..., update_timestamp: _Optional[int] = ...) -> None: ...

class LoadCollectionRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "replica_number", "resource_groups", "refresh", "load_fields", "skip_load_dynamic_field")
@@ -687,40 +691,44 @@ class DropIndexRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., field_name: _Optional[str] = ..., index_name: _Optional[str] = ...) -> None: ...

class InsertRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "partition_name", "fields_data", "hash_keys", "num_rows")
__slots__ = ("base", "db_name", "collection_name", "partition_name", "fields_data", "hash_keys", "num_rows", "schema_timestamp")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
PARTITION_NAME_FIELD_NUMBER: _ClassVar[int]
FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
HASH_KEYS_FIELD_NUMBER: _ClassVar[int]
NUM_ROWS_FIELD_NUMBER: _ClassVar[int]
SCHEMA_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
partition_name: str
fields_data: _containers.RepeatedCompositeFieldContainer[_schema_pb2.FieldData]
hash_keys: _containers.RepeatedScalarFieldContainer[int]
num_rows: int
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., hash_keys: _Optional[_Iterable[int]] = ..., num_rows: _Optional[int] = ...) -> None: ...
schema_timestamp: int
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., hash_keys: _Optional[_Iterable[int]] = ..., num_rows: _Optional[int] = ..., schema_timestamp: _Optional[int] = ...) -> None: ...

class UpsertRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "partition_name", "fields_data", "hash_keys", "num_rows")
__slots__ = ("base", "db_name", "collection_name", "partition_name", "fields_data", "hash_keys", "num_rows", "schema_timestamp")
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
PARTITION_NAME_FIELD_NUMBER: _ClassVar[int]
FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
HASH_KEYS_FIELD_NUMBER: _ClassVar[int]
NUM_ROWS_FIELD_NUMBER: _ClassVar[int]
SCHEMA_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
base: _common_pb2.MsgBase
db_name: str
collection_name: str
partition_name: str
fields_data: _containers.RepeatedCompositeFieldContainer[_schema_pb2.FieldData]
hash_keys: _containers.RepeatedScalarFieldContainer[int]
num_rows: int
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., hash_keys: _Optional[_Iterable[int]] = ..., num_rows: _Optional[int] = ...) -> None: ...
schema_timestamp: int
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_name: _Optional[str] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., hash_keys: _Optional[_Iterable[int]] = ..., num_rows: _Optional[int] = ..., schema_timestamp: _Optional[int] = ...) -> None: ...

class MutationResult(_message.Message):
__slots__ = ("status", "IDs", "succ_index", "err_index", "acknowledged", "insert_cnt", "delete_cnt", "upsert_cnt", "timestamp")
