Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: Add schema update time verification to insert and upsert to use cache #2551

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self, raw: Any):
self.num_shards = 0
self.num_partitions = 0
self.enable_dynamic_field = False

self.update_timestamp = 0
if self._raw:
self.__pack(self._raw)

Expand All @@ -209,7 +209,7 @@ def __pack(self, raw: Any):
# for kv in raw.extra_params:

self.fields = [FieldSchema(f) for f in raw.schema.fields]

self.update_timestamp = raw.update_timestamp
self.functions = [FunctionSchema(f) for f in raw.schema.functions]
function_output_field_names = [f for fn in self.functions for f in fn.output_field_names]
for field in self.fields:
Expand Down Expand Up @@ -247,6 +247,7 @@ def dict(self):
"properties": self.properties,
"num_partitions": self.num_partitions,
"enable_dynamic_field": self.enable_dynamic_field,
"update_timestamp": self.update_timestamp,
}
self._rewrite_schema_dict(_dict)
return _dict
Expand Down
64 changes: 60 additions & 4 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel()
self.callbacks = []
self.schema_cache = {}

def register_state_change_callback(self, callback: Callable):
self.callbacks.append(callback)
Expand Down Expand Up @@ -161,6 +162,7 @@ def close(self):
self._channel.close()

def reset_db_name(self, db_name: str):
self.schema_cache.clear()
self._setup_db_interceptor(db_name)
self._setup_grpc_channel()
self._setup_identifier_interceptor(self._user)
Expand Down Expand Up @@ -526,10 +528,28 @@ def insert_rows(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
if resp.status.error_code == common_pb2.SchemaMismatch:
schema = self.update_schema(collection_name, timeout)
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
check_status(resp.status)
ts_utils.update_collection_ts(collection_name, resp.timestamp)
return MutationResult(resp)

def update_schema(self, collection_name: str, timeout: Optional[float] = None):
    """Drop the cached schema of a collection, re-fetch it and cache it again.

    Args:
        collection_name: Name of the collection whose schema is refreshed.
        timeout: Forwarded to ``describe_collection`` as its ``timeout``.

    Returns:
        The schema dict returned by ``describe_collection``.
    """
    # Evict first: if the describe call below raises, no old entry stays behind.
    self.schema_cache.pop(collection_name, None)

    fresh_schema = self.describe_collection(collection_name, timeout=timeout)

    # A schema without "update_timestamp" is recorded with timestamp 0.
    self.schema_cache[collection_name] = {
        "schema": fresh_schema,
        "schema_timestamp": fresh_schema.get("update_timestamp", 0),
    }
    return fresh_schema

def _prepare_row_insert_request(
self,
collection_name: str,
Expand All @@ -542,9 +562,9 @@ def _prepare_row_insert_request(
if isinstance(entity_rows, dict):
entity_rows = [entity_rows]

if not isinstance(schema, dict):
schema = self.describe_collection(collection_name, timeout=timeout)

schema, schema_timestamp = self._get_schema_from_cache_or_remote(
collection_name, schema, timeout
)
fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)

Expand All @@ -554,8 +574,33 @@ def _prepare_row_insert_request(
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
schema_timestamp=schema_timestamp,
)

def _get_schema_from_cache_or_remote(
    self, collection_name: str, schema: Optional[dict] = None, timeout: Optional[float] = None
):
    """Return ``(schema, schema_timestamp)`` for a collection, preferring the cache.

    A cached entry for ``collection_name`` always wins, even when the caller
    passes ``schema``. On a cache miss the caller's ``schema`` is used if it is
    a dict; otherwise the schema is fetched with ``describe_collection``. The
    result of a miss is stored in the cache before returning.

    Args:
        collection_name: Cache key and name of the collection to describe.
        schema: Optional schema dict, used only on a cache miss.
        timeout: Forwarded to ``describe_collection`` when a fetch is needed.

    Returns:
        Tuple of the schema dict and its timestamp (0 when the schema has no
        ``"update_timestamp"`` key).
    """
    # Cache hit: hand back the stored pair unchanged.
    if collection_name in self.schema_cache:
        entry = self.schema_cache[collection_name]
        return entry["schema"], entry["schema_timestamp"]

    # Cache miss: only go remote when the caller did not provide a dict schema.
    if not isinstance(schema, dict):
        schema = self.describe_collection(collection_name, timeout=timeout)
    schema_timestamp = schema.get("update_timestamp", 0)

    # Remember the result so later requests skip the lookup.
    self.schema_cache[collection_name] = {
        "schema": schema,
        "schema_timestamp": schema_timestamp,
    }
    return schema, schema_timestamp

def _prepare_batch_insert_request(
self,
collection_name: str,
Expand Down Expand Up @@ -723,13 +768,18 @@ def _prepare_row_upsert_request(
if not isinstance(rows, list):
raise ParamError(message="'rows' must be a list, please provide valid row data.")

fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
schema, schema_timestamp = self._get_schema_from_cache_or_remote(
collection_name, timeout=timeout
)
fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)
return Prepare.row_upsert_param(
collection_name,
rows,
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
schema_timestamp=schema_timestamp,
)

@retry_on_rpc_failure()
Expand All @@ -748,6 +798,12 @@ def upsert_rows(
)
rf = self._stub.Upsert.future(request, timeout=timeout)
response = rf.result()
# On a schema-mismatch status: refresh the cached schema, rebuild the request, retry once.
# NOTE(review): the first attempt above is sent through `self._stub.Upsert`, but this
# retry builds the request with `_prepare_row_insert_request` and sends it through
# `self._stub.Insert`. Presumably the retry should rebuild an upsert request and go
# through `Upsert` again — confirm before merging.
if response.status.error_code == common_pb2.SchemaMismatch:
schema = self.update_schema(collection_name, timeout)
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
response = self._stub.Insert(request=request, timeout=timeout)
check_status(response.status)
m = MutationResult(response)
ts_utils.update_collection_ts(collection_name, m.timestamp)
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def row_insert_param(
entities: List,
partition_name: str,
fields_info: Dict,
schema_timestamp: int = 0,
enable_dynamic: bool = False,
):
if not fields_info:
Expand All @@ -617,6 +618,7 @@ def row_insert_param(
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
schema_timestamp=schema_timestamp,
)

return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
Expand All @@ -629,6 +631,7 @@ def row_upsert_param(
partition_name: str,
fields_info: Any,
enable_dynamic: bool = False,
schema_timestamp: int = 0,
):
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")
Expand All @@ -639,6 +642,7 @@ def row_upsert_param(
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
schema_timestamp=schema_timestamp,
)

return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities)
Expand Down
Loading
Loading