Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: [2.4] add authorization_interceptor and db_interceptor to async channel #2472

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/simple_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
index_params.add_index(field_name = "embeddings", index_type = "HNSW", metric_type="L2", nlist=128)
index_params.add_index(field_name = "embeddings2",index_type = "HNSW", metric_type="L2", nlist=128)

# Always use `await` when you want to guarantee the execution order of tasks.
async def recreate_collection():
print(fmt.format("Start dropping collection"))
await async_milvus_client.drop_collection(collection_name)
Expand Down
46 changes: 27 additions & 19 deletions pymilvus/client/async_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pymilvus.grpc_gen import milvus_pb2_grpc
from pymilvus.settings import Config

from . import entity_helper, interceptor, ts_utils, utils
from . import entity_helper, ts_utils, utils
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
from .async_interceptor import async_header_adder_interceptor
from .check import (
Expand Down Expand Up @@ -61,8 +61,8 @@ def __init__(
self._request_id = None
self._user = kwargs.get("user")
self._set_authorization(**kwargs)
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel() # init channel and stub
self._setup_db_name(kwargs.get("db_name"))
self._setup_grpc_channel(**kwargs)
self.callbacks = []

def register_state_change_callback(self, callback: Callable):
Expand Down Expand Up @@ -95,12 +95,7 @@ def _set_authorization(self, **kwargs):
self._server_pem_path = kwargs.get("server_pem_path", "")
self._server_name = kwargs.get("server_name", "")

self._authorization_interceptor = None
self._setup_authorization_interceptor(
kwargs.get("user"),
kwargs.get("password"),
kwargs.get("token"),
)
self._async_authorization_interceptor = None

def __enter__(self):
return self
Expand Down Expand Up @@ -131,7 +126,7 @@ def close(self):
self._async_channel.close()

def reset_db_name(self, db_name: str):
self._setup_db_interceptor(db_name)
self._setup_db_name(db_name)
self._setup_grpc_channel()
self._setup_identifier_interceptor(self._user)

Expand All @@ -147,16 +142,19 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str)
keys.append("authorization")
values.append(authorization)
if len(keys) > 0 and len(values) > 0:
self._authorization_interceptor = interceptor.header_adder_interceptor(keys, values)
self._async_authorization_interceptor = async_header_adder_interceptor(keys, values)
self._final_channel._unary_unary_interceptors.append(
self._async_authorization_interceptor
)

def _setup_db_interceptor(self, db_name: str):
def _setup_db_name(self, db_name: str):
if db_name is None:
self._db_interceptor = None
self._db_name = None
else:
check_pass_param(db_name=db_name)
self._db_interceptor = interceptor.header_adder_interceptor(["dbname"], [db_name])
self._db_name = db_name

def _setup_grpc_channel(self):
def _setup_grpc_channel(self, **kwargs):
if self._async_channel is None:
opts = [
(cygrpc.ChannelArgKey.max_send_message_length, -1),
Expand Down Expand Up @@ -202,21 +200,31 @@ def _setup_grpc_channel(self):

# avoid to add duplicate headers.
self._final_channel = self._async_channel
if self._log_level:

if self._async_authorization_interceptor:
self._final_channel._unary_unary_interceptors.append(
self._async_authorization_interceptor
)
else:
self._setup_authorization_interceptor(
kwargs.get("user"),
kwargs.get("password"),
kwargs.get("token"),
)
if self._db_name:
async_db_interceptor = async_header_adder_interceptor(["dbname"], [self._db_name])
self._final_channel._unary_unary_interceptors.append(async_db_interceptor)
if self._log_level:
async_log_level_interceptor = async_header_adder_interceptor(
["log_level"], [self._log_level]
)
self._final_channel._unary_unary_interceptors.append(async_log_level_interceptor)

self._log_level = None
if self._request_id:

async_request_id_interceptor = async_header_adder_interceptor(
["client_request_id"], [self._request_id]
)
self._final_channel._unary_unary_interceptors.append(async_request_id_interceptor)

self._request_id = None
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)

Expand Down
2 changes: 1 addition & 1 deletion pymilvus/client/async_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def intercept_stream_stream(
return await continuation(new_details, new_request_iterator)


def async_header_adder_interceptor(headers: List[str], values: List[str]):
def async_header_adder_interceptor(headers: List[str], values: Union[List[str], List[bytes]]):
def intercept_call(client_call_details: ClientCallDetails, request: Any):
metadata = []
if client_call_details.metadata:
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/milvus_client/async_milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def _fast_create_collection(
if "consistency_level" not in kwargs:
kwargs["consistency_level"] = DEFAULT_CONSISTENCY_LEVEL
try:
await conn.async_create_collection(collection_name, schema, timeout=timeout, **kwargs)
await conn.create_collection(collection_name, schema, timeout=timeout, **kwargs)
logger.debug("Successfully created collection: %s", collection_name)
except Exception as ex:
logger.error("Failed to create collection: %s", collection_name)
Expand Down
Loading