Skip to content

Commit

Permalink
fix: [2.5] ensure create_index and load_collection are fully comp…
Browse files Browse the repository at this point in the history
…leted (#2478)

- Add `wait_for_creating_index`, `wait_for_loading_collection` to ensure
`create_index` and `load_collection` are fully completed

---------

Signed-off-by: Ruichen Bao <[email protected]>
  • Loading branch information
brcarry authored Dec 20, 2024
1 parent 6e169a5 commit 8aa6de4
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 46 deletions.
191 changes: 146 additions & 45 deletions pymilvus/client/async_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
import base64
import copy
import socket
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from urllib import parse

import grpc
from grpc._cython import cygrpc

from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder
from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure
from pymilvus.exceptions import (
AmbiguousIndexName,
DescribeCollectionException,
ExceptionsMessage,
MilvusException,
ParamError,
)
Expand All @@ -32,6 +35,7 @@
from .types import (
DataType,
ExtraList,
IndexState,
Status,
get_cost_extra,
)
Expand Down Expand Up @@ -64,6 +68,7 @@ def __init__(
self._set_authorization(**kwargs)
self._setup_db_name(kwargs.get("db_name"))
self._setup_grpc_channel(**kwargs)
self._is_channel_ready = False
self.callbacks = []

def register_state_change_callback(self, callback: Callable):
Expand Down Expand Up @@ -104,33 +109,10 @@ def __enter__(self):
def __exit__(self: object, exc_type: object, exc_val: object, exc_tb: object):
pass

def _wait_for_channel_ready(self, timeout: Union[float] = 10, retry_interval: float = 1):
try:

async def wait_for_async_channel_ready():
await self._async_channel.channel_ready()

loop = asyncio.get_event_loop()
loop.run_until_complete(wait_for_async_channel_ready())

self._setup_identifier_interceptor(self._user, timeout=timeout)
except grpc.FutureTimeoutError as e:
raise MilvusException(
code=Status.CONNECT_FAILED,
message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
) from e
except Exception as e:
raise e from e

def close(self):
self.deregister_state_change_callbacks()
self._async_channel.close()

def reset_db_name(self, db_name: str):
self._setup_db_name(db_name)
self._setup_grpc_channel()
self._setup_identifier_interceptor(self._user)

def _setup_authorization_interceptor(self, user: str, password: str, token: str):
keys = []
values = []
Expand Down Expand Up @@ -229,33 +211,51 @@ def _setup_grpc_channel(self, **kwargs):
self._request_id = None
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)

def _setup_identifier_interceptor(self, user: str, timeout: int = 10):
host = socket.gethostname()
self._identifier = self.__internal_register(user, host, timeout=timeout)
_async_identifier_interceptor = async_header_adder_interceptor(
["identifier"], [str(self._identifier)]
)
self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)

@property
def server_address(self):
return self._address

def get_server_type(self):
return get_server_type(self.server_address.split(":")[0])

async def ensure_channel_ready(self):
    """Connect the async channel and register this client, once per handler.

    On the first call this waits for the async gRPC channel to become ready,
    sends a ``Connect`` request carrying the user and local hostname, and
    installs an interceptor that adds the returned identifier as the
    ``identifier`` header. The stub is then rebuilt on the channel and the
    handler is marked ready, so later calls return immediately.

    Raises:
        MilvusException: with ``Status.CONNECT_FAILED`` when waiting for the
            channel raises ``grpc.FutureTimeoutError``.
        Exception: any other error (including a failed status check on the
            ``Connect`` response) propagates unchanged.
    """
    if self._is_channel_ready:
        return

    try:
        # wait for channel ready
        await self._async_channel.channel_ready()

        # set identifier interceptor
        host = socket.gethostname()
        req = Prepare.register_request(self._user, host)
        response = await self._async_stub.Connect(request=req)
        check_status(response.status)
        _async_identifier_interceptor = async_header_adder_interceptor(
            ["identifier"], [str(response.identifier)]
        )
        self._async_channel._unary_unary_interceptors.append(_async_identifier_interceptor)
        self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._async_channel)
    except grpc.FutureTimeoutError as e:
        raise MilvusException(
            code=Status.CONNECT_FAILED,
            message=f"Fail connecting to server on {self._address}, illegal connection params or server unavailable",
        ) from e
    # No catch-all handler: re-raising with ``raise e from e`` only made the
    # exception its own __cause__; other errors simply propagate.

    # Only reached when every step above succeeded.
    self._is_channel_ready = True

@retry_on_rpc_failure()
async def create_collection(
    self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs
):
    """Create a collection on the server.

    Makes sure the channel is ready, validates ``collection_name`` and
    ``timeout``, builds the request from ``fields`` and ``kwargs``, sends the
    ``CreateCollection`` RPC and checks the returned status.
    """
    await self.ensure_channel_ready()
    check_pass_param(collection_name=collection_name, timeout=timeout)
    create_req = Prepare.create_collection_request(collection_name, fields, **kwargs)
    rpc_status = await self._async_stub.CreateCollection(create_req, timeout=timeout)
    check_status(rpc_status)

@retry_on_rpc_failure()
async def drop_collection(self, collection_name: str, timeout: Optional[float] = None):
await self.ensure_channel_ready()
check_pass_param(collection_name=collection_name, timeout=timeout)
request = Prepare.drop_collection_request(collection_name)
response = await self._async_stub.DropCollection(request, timeout=timeout)
Expand All @@ -269,6 +269,7 @@ async def load_collection(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
collection_name=collection_name, replica_number=replica_number, timeout=timeout
)
Expand All @@ -291,10 +292,48 @@ async def load_collection(
response = await self._async_stub.LoadCollection(request, timeout=timeout)
check_status(response)

await self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh)

@retry_on_rpc_failure()
async def wait_for_loading_collection(
    self, collection_name: str, timeout: Optional[float] = None, is_refresh: bool = False
):
    """Poll the loading progress until the collection is fully loaded.

    Args:
        collection_name: collection whose loading progress is polled.
        timeout: overall budget in seconds for the wait; ``None`` waits
            indefinitely. The same value is also passed as the timeout of
            each progress request.
        is_refresh: forwarded to ``get_loading_progress`` to select the
            refresh progress instead of the load progress.

    Raises:
        MilvusException: when the deadline passes before progress reaches 100.
    """
    # A monotonic clock keeps the deadline immune to system clock adjustments,
    # which could otherwise end the wait early or extend it.
    deadline = None if timeout is None else time.monotonic() + timeout

    while deadline is None or time.monotonic() <= deadline:
        progress = await self.get_loading_progress(
            collection_name, timeout=timeout, is_refresh=is_refresh
        )
        if progress >= 100:
            return
        await asyncio.sleep(Config.WaitTimeDurationWhenLoad)

    raise MilvusException(
        message=f"wait for loading collection timeout, collection: {collection_name}"
    )

@retry_on_rpc_failure()
async def get_loading_progress(
    self,
    collection_name: str,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    is_refresh: bool = False,
):
    """Return the loading progress reported by the server.

    Sends a ``GetLoadingProgress`` RPC for the collection (and optional
    partitions), checks the response status, and returns
    ``refresh_progress`` when ``is_refresh`` is truthy, otherwise
    ``progress``.
    """
    progress_req = Prepare.get_loading_progress(collection_name, partition_names)
    reply = await self._async_stub.GetLoadingProgress(progress_req, timeout=timeout)
    check_status(reply.status)
    return reply.refresh_progress if is_refresh else reply.progress

@retry_on_rpc_failure()
async def describe_collection(
self, collection_name: str, timeout: Optional[float] = None, **kwargs
):
await self.ensure_channel_ready()
check_pass_param(collection_name=collection_name, timeout=timeout)
request = Prepare.describe_collection_request(collection_name)
response = await self._async_stub.DescribeCollection(request, timeout=timeout)
Expand Down Expand Up @@ -325,6 +364,7 @@ async def insert_rows(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
request = await self._prepare_row_insert_request(
collection_name, entities, partition_name, schema, timeout, **kwargs
)
Expand Down Expand Up @@ -359,6 +399,7 @@ async def _prepare_row_insert_request(
enable_dynamic=enable_dynamic,
)

@retry_on_rpc_failure()
async def delete(
self,
collection_name: str,
Expand All @@ -367,6 +408,7 @@ async def delete(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(collection_name=collection_name, timeout=timeout)
try:
req = Prepare.delete_request(
Expand Down Expand Up @@ -421,6 +463,7 @@ async def upsert(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
if not check_invalid_binary_vector(entities):
raise ParamError(message="Invalid binary vector data exists")

Expand Down Expand Up @@ -466,6 +509,7 @@ async def upsert_rows(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
if isinstance(entities, dict):
entities = [entities]
request = await self._prepare_row_upsert_request(
Expand Down Expand Up @@ -520,6 +564,7 @@ async def search(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
limit=limit,
round_decimal=round_decimal,
Expand Down Expand Up @@ -557,6 +602,7 @@ async def hybrid_search(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
check_pass_param(
limit=limit,
round_decimal=round_decimal,
Expand Down Expand Up @@ -610,7 +656,7 @@ async def create_index(
collection_desc = await self.describe_collection(
collection_name, timeout=timeout, **copy_kwargs
)

await self.ensure_channel_ready()
valid_field = False
for fields in collection_desc["fields"]:
if field_name != fields["name"]:
Expand All @@ -635,8 +681,67 @@ async def create_index(
status = await self._async_stub.CreateIndex(index_param, timeout=timeout)
check_status(status)

index_success, fail_reason = await self.wait_for_creating_index(
collection_name=collection_name,
index_name=index_name,
timeout=timeout,
field_name=field_name,
)

if not index_success:
raise MilvusException(message=fail_reason)

return Status(status.code, status.reason)

@retry_on_rpc_failure()
async def wait_for_creating_index(
    self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs
):
    """Poll the index state until the build finishes, fails, or times out.

    Args:
        collection_name: collection that owns the index.
        index_name: name of the index being built.
        timeout: overall budget in seconds for the wait (int or float);
            ``None`` waits indefinitely. The same value is also passed as the
            timeout of each state request.
        **kwargs: forwarded to ``get_index_state`` (e.g. ``field_name``).

    Returns:
        ``(True, fail_reason)`` when the state is ``IndexState.Finished``,
        ``(False, fail_reason)`` when it is ``IndexState.Failed``.

    Raises:
        MilvusException: when ``timeout`` seconds elapse without reaching a
            terminal state.
    """
    timestamp = await self.alloc_timestamp()
    # Monotonic clock: the elapsed time must not be affected by wall-clock jumps.
    start = time.monotonic()
    while True:
        await asyncio.sleep(0.5)
        state, fail_reason = await self.get_index_state(
            collection_name, index_name, timeout=timeout, timestamp=timestamp, **kwargs
        )
        if state == IndexState.Finished:
            return True, fail_reason
        if state == IndexState.Failed:
            return False, fail_reason
        # Enforce any numeric timeout. The previous ``isinstance(timeout, int)``
        # check ignored float timeouts, so the loop could poll forever.
        if timeout is not None and time.monotonic() - start > timeout:
            msg = (
                f"collection {collection_name} create index {index_name} "
                f"timeout in {timeout}s"
            )
            raise MilvusException(message=msg)

@retry_on_rpc_failure()
async def get_index_state(
    self,
    collection_name: str,
    index_name: str,
    timeout: Optional[float] = None,
    timestamp: Optional[int] = None,
    **kwargs,
):
    """Return ``(state, fail_reason)`` for an index via ``DescribeIndex``.

    When the response holds exactly one index description, that one is used.
    Otherwise, if a non-empty ``field_name`` keyword is given (used by
    ``create_index``), the first description whose field matches is used.

    Raises:
        AmbiguousIndexName: when no single description can be selected.
    """
    describe_req = Prepare.describe_index_request(collection_name, index_name, timestamp)
    reply = await self._async_stub.DescribeIndex(describe_req, timeout=timeout)
    check_status(reply.status)

    descriptions = reply.index_descriptions
    if len(descriptions) == 1:
        only = descriptions[0]
        return only.state, only.index_state_fail_reason

    # just for create_index.
    wanted_field = kwargs.pop("field_name", "")
    if wanted_field != "":
        for candidate in descriptions:
            if candidate.field_name == wanted_field:
                return candidate.state, candidate.index_state_fail_reason

    raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName)

@retry_on_rpc_failure()
async def get(
self,
Expand All @@ -647,6 +752,7 @@ async def get(
timeout: Optional[float] = None,
):
# TODO: some check
await self.ensure_channel_ready()
request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names)
return await self._async_stub.Retrieve.get(request, timeout=timeout)

Expand All @@ -660,6 +766,7 @@ async def query(
timeout: Optional[float] = None,
**kwargs,
):
await self.ensure_channel_ready()
if output_fields is not None and not isinstance(output_fields, (list,)):
raise ParamError(message="Invalid query format. 'output_fields' must be a list")
request = Prepare.query_request(
Expand Down Expand Up @@ -693,15 +800,9 @@ async def query(
return ExtraList(results, extra=extra_dict)

@retry_on_rpc_failure()
@upgrade_reminder
def __internal_register(self, user: str, host: str, **kwargs) -> int:
req = Prepare.register_request(user, host)

async def wait_for_connect_response():
return await self._async_stub.Connect(request=req)

loop = asyncio.get_event_loop()
response = loop.run_until_complete(wait_for_connect_response())

@ignore_unimplemented(0)
async def alloc_timestamp(self, timeout: Optional[float] = None) -> int:
request = milvus_types.AllocTimestampRequest()
response = await self._async_stub.AllocTimestamp(request, timeout=timeout)
check_status(response.status)
return response.identifier
return response.timestamp
4 changes: 3 additions & 1 deletion pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,9 @@ def connect_milvus(**kwargs):
t = kwargs.get("timeout")
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT

gh._wait_for_channel_ready(timeout=timeout)
if not _async:
gh._wait_for_channel_ready(timeout=timeout)

if kwargs.get("keep_alive", False):
gh.register_state_change_callback(
ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle
Expand Down

0 comments on commit 8aa6de4

Please sign in to comment.