Skip to content

Commit

Permalink
Support hybrid search with expression template
Browse files (browse the repository at this point in the history)
Signed-off-by: Cai Zhang <[email protected]>
  • Loading branch information
xiaocai2333 committed Dec 20, 2024
1 parent 8aa6de4 commit 796d95b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 5 deletions.
22 changes: 21 additions & 1 deletion examples/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,27 @@
req = AnnSearchRequest(**search_param)
req_list.append(req)

print("rank by RRFRanker")
print(fmt.format("rank by RRFRanker"))
hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"])
for hits in hybrid_res:
for hit in hits:
print(f" hybrid search hit: {hit}")

# Rebuild the request list, this time attaching a templated filter to every
# request: the `{radius}` placeholder in `expr` is supplied separately through
# `expr_params` instead of being written into the expression string.
req_list = []
for field_name in field_names:
    # Fresh random query data of shape (nq, dim) for each entry of field_names.
    vectors_to_search = rng.random((nq, dim))
    search_param = {
        "data": vectors_to_search,
        "anns_field": field_name,
        "param": {"metric_type": "L2"},
        "limit": default_limit,
        "expr": "random > {radius}",
        "expr_params": {"radius": 0.5},
    }
    req_list.append(AnnSearchRequest(**search_param))

print(fmt.format("rank by RRFRanker with expression template"))
hybrid_res = milvus_client.hybrid_search(collection_name, req_list, RRFRanker(), default_limit, output_fields=["random"])
for hits in hybrid_res:
for hit in hits:
Expand Down
29 changes: 27 additions & 2 deletions examples/hybrid_search/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,39 @@
req = AnnSearchRequest(**search_param)
req_list.append(req)

# Run the hybrid search over req_list, reranking with WeightedRanker(*weights),
# and print each returned hit (the "random" field is requested as output).
print(fmt.format("rank by WeightedRanker"))
hybrid_res = hello_milvus.hybrid_search(
    req_list, WeightedRanker(*weights), default_limit, output_fields=["random"]
)
for hits in hybrid_res:
    for hit in hits:
        print(f" hybrid search hit: {hit}")

print("rank by WightedRanker")
print(fmt.format("rank by RRFRanker"))
hybrid_res = hello_milvus.hybrid_search(req_list, RRFRanker(), default_limit, output_fields=["random"])
for hits in hybrid_res:
for hit in hits:
print(f" hybrid search hit: {hit}")

# Rebuild the request list, this time attaching a templated filter to every
# request: the `{radius}` placeholder in `expr` is supplied separately through
# `expr_params` instead of being written into the expression string.
req_list = []
for field_name in field_names:
    # Fresh random query data of shape (nq, dim) for each entry of field_names.
    vectors_to_search = rng.random((nq, dim))
    search_param = {
        "data": vectors_to_search,
        "anns_field": field_name,
        "param": {"metric_type": "L2"},
        "limit": default_limit,
        "expr": "random > {radius}",
        "expr_params": {"radius": 0.5},
    }
    req_list.append(AnnSearchRequest(**search_param))

# Run the hybrid search over the templated requests in req_list, reranking
# with WeightedRanker(*weights), and print each returned hit.
print(fmt.format("rank by WeightedRanker with expression template"))
hybrid_res = hello_milvus.hybrid_search(
    req_list, WeightedRanker(*weights), default_limit, output_fields=["random"]
)
for hits in hybrid_res:
    for hit in hits:
        print(f" hybrid search hit: {hit}")

print("rank by RRFRanker")
print(fmt.format("rank by RRFRanker with expression template"))
hybrid_res = hello_milvus.hybrid_search(req_list, RRFRanker(), default_limit, output_fields=["random"])
for hits in hybrid_res:
for hit in hits:
Expand Down
6 changes: 6 additions & 0 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def __init__(
param: Dict,
limit: int,
expr: Optional[str] = None,
expr_params: Optional[dict] = None,
):
self._data = data
self._anns_field = anns_field
Expand All @@ -426,6 +427,7 @@ def __init__(
if expr is not None and not isinstance(expr, str):
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
self._expr = expr
self._expr_params = expr_params

@property
def data(self):
Expand All @@ -447,6 +449,10 @@ def limit(self):
def expr(self):
return self._expr

@property
def expr_params(self):
    """Return the expression-template values stored on this request.

    This is the ``_expr_params`` attribute exactly as stored, with no
    copying or validation applied.
    """
    return self._expr_params

def __str__(self):
return {
"anns_field": self.anns_field,
Expand Down
1 change: 1 addition & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ def hybrid_search(
req.expr,
partition_names=partition_names,
round_decimal=round_decimal,
expr_params=req.expr_params,
**kwargs,
)
requests.append(search_request)
Expand Down
4 changes: 3 additions & 1 deletion pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,9 @@ def search_requests_with_expr(
placeholder_group=plg_str,
dsl_type=common_types.DslType.BoolExprV1,
search_params=req_params,
expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})),
expr_template_values=cls.prepare_expression_template(
{} if kwargs.get("expr_params") is None else kwargs.get("expr_params")
),
)
if expr is not None:
request.dsl = expr
Expand Down
1 change: 0 additions & 1 deletion pymilvus/milvus_client/async_milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ async def hybrid_search(
partition_names: Optional[List[str]] = None,
**kwargs,
) -> List[List[dict]]:

conn = self._get_connection()
try:
res = await conn.hybrid_search(
Expand Down

0 comments on commit 796d95b

Please sign in to comment.