From 796d95bde1be0c8afd15652ceacb5ff0c2f5e8d8 Mon Sep 17 00:00:00 2001 From: Cai Zhang Date: Fri, 20 Dec 2024 16:37:10 +0800 Subject: [PATCH] Support hybrid search with expression template Signed-off-by: Cai Zhang --- examples/hybrid_search.py | 22 +++++++++++++- examples/hybrid_search/hybrid_search.py | 29 +++++++++++++++++-- pymilvus/client/abstract.py | 6 ++++ pymilvus/client/grpc_handler.py | 1 + pymilvus/client/prepare.py | 4 ++- pymilvus/milvus_client/async_milvus_client.py | 1 - 6 files changed, 58 insertions(+), 5 deletions(-) diff --git a/examples/hybrid_search.py b/examples/hybrid_search.py index 28ae0b309..e2a87a634 100644 --- a/examples/hybrid_search.py +++ b/examples/hybrid_search.py @@ -68,7 +68,27 @@ req = AnnSearchRequest(**search_param) req_list.append(req) -print("rank by RRFRanker") +print(fmt.format("rank by RRFRanker")) +hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"]) +for hits in hybrid_res: + for hit in hits: + print(f" hybrid search hit: {hit}") + +req_list = [] +for i in range(len(field_names)): + # 4. generate search data + vectors_to_search = rng.random((nq, dim)) + search_param = { + "data": vectors_to_search, + "anns_field": field_names[i], + "param": {"metric_type": "L2"}, + "limit": default_limit, + "expr": "random > {radius}", + "expr_params": {"radius": 0.5}} + req = AnnSearchRequest(**search_param) + req_list.append(req) + +print(fmt.format("rank by RRFRanker with expression template")) hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"]) for hits in hybrid_res: for hit in hits: diff --git a/examples/hybrid_search/hybrid_search.py b/examples/hybrid_search/hybrid_search.py index 6a13045f0..b03d3f95f 100644 --- a/examples/hybrid_search/hybrid_search.py +++ b/examples/hybrid_search/hybrid_search.py @@ -79,14 +79,39 @@ req = AnnSearchRequest(**search_param) req_list.append(req) +print(fmt.format("rank by WightedRanker")) hybrid_res = hello_milvus.hybrid_search(req_list, WeightedRanker(*weights), default_limit, output_fields=["random"]) +for hits in hybrid_res: + for hit in hits: + print(f" hybrid search hit: {hit}") -print("rank by WightedRanker") +print(fmt.format("rank by RRFRanker")) +hybrid_res = hello_milvus.hybrid_search(req_list, RRFRanker(), default_limit, output_fields=["random"]) +for hits in hybrid_res: + for hit in hits: + print(f" hybrid search hit: {hit}") + +req_list = [] +for i in range(len(field_names)): + # 4. generate search data + vectors_to_search = rng.random((nq, dim)) + search_param = { + "data": vectors_to_search, + "anns_field": field_names[i], + "param": {"metric_type": "L2"}, + "limit": default_limit, + "expr": "random > {radius}", + "expr_params": {"radius": 0.5}} + req = AnnSearchRequest(**search_param) + req_list.append(req) + +print(fmt.format("rank by WightedRanker with expression template")) +hybrid_res = hello_milvus.hybrid_search(req_list, WeightedRanker(*weights), default_limit, output_fields=["random"]) for hits in hybrid_res: for hit in hits: print(f" hybrid search hit: {hit}") -print("rank by RRFRanker") +print(fmt.format("rank by RRFRanker with expression template")) hybrid_res = hello_milvus.hybrid_search(req_list, RRFRanker(), default_limit, output_fields=["random"]) for hits in hybrid_res: for hit in hits: diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 9606b96dd..81211463f 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -417,6 +417,7 @@ def __init__( param: Dict, limit: int, expr: Optional[str] = None, + expr_params: Optional[dict] = None, ): self._data = data self._anns_field = anns_field @@ -426,6 +427,7 @@ def __init__( if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) self._expr = expr + self._expr_params = expr_params @property def data(self): @@ -447,6 +449,10 @@ def limit(self): def expr(self): return self._expr + @property + def expr_params(self): + return self._expr_params + def __str__(self): return { "anns_field": self.anns_field, diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 0fe7222e6..56143a7df 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -868,6 +868,7 @@ def hybrid_search( req.expr, partition_names=partition_names, round_decimal=round_decimal, + expr_params=req.expr_params, **kwargs, ) requests.append(search_request) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index b0bbc4484..4b1842dee 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -997,7 +997,9 @@ def search_requests_with_expr( placeholder_group=plg_str, dsl_type=common_types.DslType.BoolExprV1, search_params=req_params, - expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})), + expr_template_values=cls.prepare_expression_template( + {} if kwargs.get("expr_params") is None else kwargs.get("expr_params") + ), ) if expr is not None: request.dsl = expr diff --git a/pymilvus/milvus_client/async_milvus_client.py b/pymilvus/milvus_client/async_milvus_client.py index 64845c1dc..61d5f71db 100644 --- a/pymilvus/milvus_client/async_milvus_client.py +++ b/pymilvus/milvus_client/async_milvus_client.py @@ -286,7 +286,6 @@ async def hybrid_search( partition_names: Optional[List[str]] = None, **kwargs, ) -> List[List[dict]]: - conn = self._get_connection() try: res = await conn.hybrid_search(