diff --git a/pymilvus/client/async_grpc_handler.py b/pymilvus/client/async_grpc_handler.py
index 1d8825be6..c9fe1320e 100644
--- a/pymilvus/client/async_grpc_handler.py
+++ b/pymilvus/client/async_grpc_handler.py
@@ -2,6 +2,7 @@
 import base64
 import copy
 import socket
+import time
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
 from urllib import parse
@@ -9,9 +10,11 @@
 import grpc
 from grpc._cython import cygrpc
 
-from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder
+from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure
 from pymilvus.exceptions import (
+    AmbiguousIndexName,
     DescribeCollectionException,
+    ExceptionsMessage,
     MilvusException,
     ParamError,
 )
@@ -32,6 +35,7 @@
 from .types import (
     DataType,
     ExtraList,
+    IndexState,
     Status,
     get_cost_extra,
 )
@@ -64,6 +68,7 @@ def __init__(
         self._set_authorization(**kwargs)
         self._setup_db_name(kwargs.get("db_name"))
         self._setup_grpc_channel(**kwargs)
+        self._is_channel_ready = False
         self.callbacks = []
 
     def register_state_change_callback(self, callback: Callable):
@@ -104,33 +109,10 @@ def __enter__(self):
     def __exit__(self: object, exc_type: object, exc_val: object, exc_tb: object):
         pass
 
-    def _wait_for_channel_ready(self, timeout: Union[float] = 10, retry_interval: float = 1):
-        try:
-
-            async def wait_for_async_channel_ready():
-                await self._async_channel.channel_ready()
-
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(wait_for_async_channel_ready())
-
-            self._setup_identifier_interceptor(self._user, timeout=timeout)
-        except grpc.FutureTimeoutError as e:
-            raise MilvusException(
-                code=Status.CONNECT_FAILED,
-                message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
-            ) from e
-        except Exception as e:
-            raise e from e
-
     def close(self):
         self.deregister_state_change_callbacks()
         self._async_channel.close()
 
-    def reset_db_name(self, db_name: str):
-        self._setup_db_name(db_name)
-        self._setup_grpc_channel()
-        self._setup_identifier_interceptor(self._user)
-
     def _setup_authorization_interceptor(self, user: str, password: str, token: str):
         keys = []
         values = []
@@ -229,15 +211,6 @@ def _setup_grpc_channel(self, **kwargs):
             self._request_id = None
         self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)
 
-    def _setup_identifier_interceptor(self, user: str, timeout: int = 10):
-        host = socket.gethostname()
-        self._identifier = self.__internal_register(user, host, timeout=timeout)
-        _async_identifier_interceptor = async_header_adder_interceptor(
-            ["identifier"], [str(self._identifier)]
-        )
-        self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
-        self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)
-
     @property
     def server_address(self):
         return self._address
@@ -245,10 +218,36 @@ def server_address(self):
     def get_server_type(self):
         return get_server_type(self.server_address.split(":")[0])
 
+    async def ensure_channel_ready(self):
+        try:
+            if not self._is_channel_ready:
+                # wait for channel ready
+                await self._async_channel.channel_ready()
+                # set identifier interceptor
+                host = socket.gethostname()
+                req = Prepare.register_request(self._user, host)
+                response = await self._async_stub.Connect(request=req)
+                check_status(response.status)
+                _async_identifier_interceptor = async_header_adder_interceptor(
+                    ["identifier"], [str(response.identifier)]
+                )
+                self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
+                self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)
+
+                self._is_channel_ready = True
+        except grpc.FutureTimeoutError as e:
+            raise MilvusException(
+                code=Status.CONNECT_FAILED,
+                message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
+            ) from e
+        except Exception as e:
+            raise e from e
+
     @retry_on_rpc_failure()
     async def create_collection(
         self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs
     ):
+        await self.ensure_channel_ready()
         check_pass_param(collection_name=collection_name, timeout=timeout)
         request = Prepare.create_collection_request(collection_name, fields, **kwargs)
         response = await self._async_stub.CreateCollection(request, timeout=timeout)
@@ -256,6 +255,7 @@ async def create_collection(
 
     @retry_on_rpc_failure()
     async def drop_collection(self, collection_name: str, timeout: Optional[float] = None):
+        await self.ensure_channel_ready()
         check_pass_param(collection_name=collection_name, timeout=timeout)
         request = Prepare.drop_collection_request(collection_name)
         response = await self._async_stub.DropCollection(request, timeout=timeout)
@@ -269,6 +269,7 @@ async def load_collection(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         check_pass_param(
             collection_name=collection_name, replica_number=replica_number, timeout=timeout
         )
@@ -291,10 +292,48 @@ async def load_collection(
         response = await self._async_stub.LoadCollection(request, timeout=timeout)
         check_status(response)
+        await self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh)
+
+    @retry_on_rpc_failure()
+    async def wait_for_loading_collection(
+        self, collection_name: str, timeout: Optional[float] = None, is_refresh: bool = False
+    ):
+        start = time.time()
+
+        def can_loop(t: int) -> bool:
+            return True if timeout is None else t <= (start + timeout)
+
+        while can_loop(time.time()):
+            progress = await self.get_loading_progress(
+                collection_name, timeout=timeout, is_refresh=is_refresh
+            )
+            if progress >= 100:
+                return
+            await asyncio.sleep(Config.WaitTimeDurationWhenLoad)
+        raise MilvusException(
+            message=f"wait for loading collection timeout, collection: {collection_name}"
+        )
+
+    @retry_on_rpc_failure()
+    async def get_loading_progress(
+        self,
+        collection_name: str,
+        partition_names: Optional[List[str]] = None,
+        timeout: Optional[float] = None,
+        is_refresh: bool = False,
+    ):
+        request = Prepare.get_loading_progress(collection_name, partition_names)
+        response = await self._async_stub.GetLoadingProgress(request, timeout=timeout)
+        check_status(response.status)
+        if is_refresh:
+            return response.refresh_progress
+        return response.progress
+
     @retry_on_rpc_failure()
     async def describe_collection(
         self, collection_name: str, timeout: Optional[float] = None, **kwargs
     ):
+        await self.ensure_channel_ready()
         check_pass_param(collection_name=collection_name, timeout=timeout)
         request = Prepare.describe_collection_request(collection_name)
         response = await self._async_stub.DescribeCollection(request, timeout=timeout)
@@ -325,6 +364,7 @@ async def insert_rows(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         request = await self._prepare_row_insert_request(
             collection_name, entities, partition_name, schema, timeout, **kwargs
         )
@@ -359,6 +399,7 @@ async def _prepare_row_insert_request(
             enable_dynamic=enable_dynamic,
         )
 
+    @retry_on_rpc_failure()
     async def delete(
         self,
         collection_name: str,
@@ -367,6 +408,7 @@ async def delete(
         expression: str,
         partition_name: str,
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         check_pass_param(collection_name=collection_name, timeout=timeout)
         try:
             req = Prepare.delete_request(
@@ -421,6 +463,7 @@ async def upsert(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         if not check_invalid_binary_vector(entities):
             raise ParamError(message="Invalid binary vector data exists")
 
@@ -466,6 +509,7 @@ async def upsert_rows(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         if isinstance(entities, dict):
             entities = [entities]
         request = await self._prepare_row_upsert_request(
@@ -520,6 +564,7 @@ async def search(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         check_pass_param(
             limit=limit,
             round_decimal=round_decimal,
@@ -557,6 +602,7 @@ async def hybrid_search(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         check_pass_param(
             limit=limit,
             round_decimal=round_decimal,
@@ -610,7 +656,7 @@ async def create_index(
         collection_desc = await self.describe_collection(
             collection_name, timeout=timeout, **copy_kwargs
         )
-
+        await self.ensure_channel_ready()
         valid_field = False
         for fields in collection_desc["fields"]:
             if field_name != fields["name"]:
@@ -635,8 +681,67 @@ async def create_index(
         status = await self._async_stub.CreateIndex(index_param, timeout=timeout)
         check_status(status)
 
+        index_success, fail_reason = await self.wait_for_creating_index(
+            collection_name=collection_name,
+            index_name=index_name,
+            timeout=timeout,
+            field_name=field_name,
+        )
+
+        if not index_success:
+            raise MilvusException(message=fail_reason)
+
         return Status(status.code, status.reason)
 
+    @retry_on_rpc_failure()
+    async def wait_for_creating_index(
+        self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs
+    ):
+        timestamp = await self.alloc_timestamp()
+        start = time.time()
+        while True:
+            await asyncio.sleep(0.5)
+            state, fail_reason = await self.get_index_state(
+                collection_name, index_name, timeout=timeout, timestamp=timestamp, **kwargs
+            )
+            if state == IndexState.Finished:
+                return True, fail_reason
+            if state == IndexState.Failed:
+                return False, fail_reason
+            end = time.time()
+            if isinstance(timeout, int) and end - start > timeout:
+                msg = (
+                    f"collection {collection_name} create index {index_name} "
+                    f"timeout in {timeout}s"
+                )
+                raise MilvusException(message=msg)
+
+    @retry_on_rpc_failure()
+    async def get_index_state(
+        self,
+        collection_name: str,
+        index_name: str,
+        timeout: Optional[float] = None,
+        timestamp: Optional[int] = None,
+        **kwargs,
+    ):
+        request = Prepare.describe_index_request(collection_name, index_name, timestamp)
+        response = await self._async_stub.DescribeIndex(request, timeout=timeout)
+        status = response.status
+        check_status(status)
+
+        if len(response.index_descriptions) == 1:
+            index_desc = response.index_descriptions[0]
+            return index_desc.state, index_desc.index_state_fail_reason
+
+        # just for create_index.
+        field_name = kwargs.pop("field_name", "")
+        if field_name != "":
+            for index_desc in response.index_descriptions:
+                if index_desc.field_name == field_name:
+                    return index_desc.state, index_desc.index_state_fail_reason
+
+        raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName)
+
     @retry_on_rpc_failure()
     async def get(
         self,
@@ -647,6 +752,7 @@ async def get(
         timeout: Optional[float] = None,
     ):
         # TODO: some check
+        await self.ensure_channel_ready()
         request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names)
         return await self._async_stub.Retrieve.get(request, timeout=timeout)
@@ -660,6 +766,7 @@ async def query(
         timeout: Optional[float] = None,
         **kwargs,
     ):
+        await self.ensure_channel_ready()
         if output_fields is not None and not isinstance(output_fields, (list,)):
             raise ParamError(message="Invalid query format. 'output_fields' must be a list")
         request = Prepare.query_request(
@@ -693,15 +800,9 @@ async def query(
         return ExtraList(results, extra=extra_dict)
 
     @retry_on_rpc_failure()
-    @upgrade_reminder
-    def __internal_register(self, user: str, host: str, **kwargs) -> int:
-        req = Prepare.register_request(user, host)
-
-        async def wait_for_connect_response():
-            return await self._async_stub.Connect(request=req)
-
-        loop = asyncio.get_event_loop()
-        response = loop.run_until_complete(wait_for_connect_response())
-
+    @ignore_unimplemented(0)
+    async def alloc_timestamp(self, timeout: Optional[float] = None) -> int:
+        request = milvus_types.AllocTimestampRequest()
+        response = await self._async_stub.AllocTimestamp(request, timeout=timeout)
         check_status(response.status)
-        return response.identifier
+        return response.timestamp
diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py
index 5d6306661..151b31e7d 100644
--- a/pymilvus/orm/connections.py
+++ b/pymilvus/orm/connections.py
@@ -400,7 +400,9 @@ def connect_milvus(**kwargs):
             t = kwargs.get("timeout")
             timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
 
-            gh._wait_for_channel_ready(timeout=timeout)
+            if not _async:
+                gh._wait_for_channel_ready(timeout=timeout)
+
             if kwargs.get("keep_alive", False):
                 gh.register_state_change_callback(
                     ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle
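
Note: the snippet below is a minimal, self-contained sketch of the lazy "channel ready" pattern this patch applies, where every RPC method awaits an idempotent ensure_channel_ready() guard so connecting and identifier registration happen on first use rather than at construction time. It is not part of the patch and not the pymilvus API; the names (_FakeChannel, LazyHandler, describe_collection) are illustrative stand-ins only.

import asyncio


class _FakeChannel:
    # Stand-in for grpc.aio.Channel: channel_ready() resolves once the transport is up.
    async def channel_ready(self) -> None:
        await asyncio.sleep(0.05)


class LazyHandler:
    def __init__(self) -> None:
        self._channel = _FakeChannel()
        self._is_channel_ready = False

    async def ensure_channel_ready(self) -> None:
        # Connect (and, in the real handler, register the identifier interceptor)
        # only once; subsequent calls return immediately.
        if not self._is_channel_ready:
            await self._channel.channel_ready()
            self._is_channel_ready = True

    async def describe_collection(self, name: str) -> str:
        await self.ensure_channel_ready()  # every RPC guards itself the same way
        return f"described {name}"


async def main() -> None:
    handler = LazyHandler()
    print(await handler.describe_collection("demo"))  # first call triggers the connect
    print(await handler.describe_collection("demo"))  # later calls skip it


asyncio.run(main())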