Skip to content

Commit

Permalink
Support search on multi-vec fields
Browse files Browse the repository at this point in the history
Signed-off-by: xige-16 <[email protected]>
  • Loading branch information
xige-16 committed Dec 8, 2023
1 parent f12777b commit e95d39f
Show file tree
Hide file tree
Showing 12 changed files with 774 additions and 405 deletions.
77 changes: 76 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import ujson

from pymilvus.exceptions import MilvusException
from pymilvus.exceptions import DataTypeNotMatchException, ExceptionsMessage, MilvusException
from pymilvus.grpc_gen import schema_pb2
from pymilvus.settings import Config

Expand Down Expand Up @@ -271,6 +271,81 @@ def __next__(self) -> Any:
raise StopIteration


class RRFRanker:
    """Rerank configuration using the ``"rrf"`` strategy.

    Args:
        k: Value emitted under the ``"k"`` key of the rerank params
            (defaults to 60).
    """

    def __init__(self, k: int = 60):
        self._k = k
        self._strategy = "rrf"

    def dict(self):
        """Return the ranker as ``{"strategy": ..., "params": {"k": ...}}``."""
        return {"strategy": self._strategy, "params": {"k": self._k}}


class WeightedRanker:
    """Rerank configuration using the ``"Weighted"`` strategy.

    Args:
        *nums: The weights, kept in the order they were passed.
    """

    def __init__(self, *nums):
        self._strategy = "Weighted"
        # Materialise the positional arguments into a list, preserving order.
        self._weights = list(nums)

    def dict(self):
        """Return the ranker as ``{"strategy": ..., "params": {"weights": [...]}}``."""
        return {"strategy": self._strategy, "params": {"weights": self._weights}}


class AnnSearchRequest:
    """Parameters of one ANN sub-search used in a multi-vector search.

    The values are stored as given and exposed through read-only properties.

    Args:
        data: Query data for this sub-search.
        anns_field: Name of the field to search on.
        param: Search parameters for this sub-search.
        limit: Result limit for this sub-search.
        expr: Optional filter expression; must be a ``str`` when provided.

    Raises:
        DataTypeNotMatchException: If ``expr`` is neither ``None`` nor a ``str``.
    """

    def __init__(
        self,
        data: List,
        anns_field: str,
        param: Dict,
        limit: int,
        expr: Optional[str] = None,
    ):
        self._data = data
        self._anns_field = anns_field
        self._param = param
        self._limit = limit

        # Only the expression is type-checked here; other values pass through.
        if expr is not None and not isinstance(expr, str):
            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
        self._expr = expr

    @property
    def data(self):
        """Query data passed to the constructor."""
        return self._data

    @property
    def anns_field(self):
        """Name of the field to search on."""
        return self._anns_field

    @property
    def param(self):
        """Search parameters passed to the constructor."""
        return self._param

    @property
    def limit(self):
        """Result limit passed to the constructor."""
        return self._limit

    @property
    def expr(self):
        """Filter expression passed to the constructor, or ``None``."""
        # Must return the expression itself, not the search-params dict.
        return self._expr


class SearchResult(list):
"""nq results: List[Hits]"""

Expand Down
72 changes: 71 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pymilvus.settings import Config

from . import entity_helper, interceptor, ts_utils
from .abstract import CollectionSchema, MutationResult, SearchResult
from .abstract import AnnSearchRequest, CollectionSchema, MutationResult, SearchResult
from .asynch import (
CreateIndexFuture,
FlushFuture,
Expand Down Expand Up @@ -708,6 +708,25 @@ def _execute_search(
return SearchFuture(None, None, e)
raise e from e

def _execute_searchV2(
    self, request: milvus_types.SearchRequestV2, timeout: Optional[float] = None, **kwargs
):
    """Send a ``SearchRequestV2`` through the stub and wrap the outcome.

    When ``kwargs["_async"]`` is truthy the call is issued via
    ``SearchV2.future`` and a ``SearchFuture`` is returned, carrying the
    optional ``_callback``. Otherwise the call blocks, the response status
    is checked and a ``SearchResult`` is built using ``round_decimal``
    from ``kwargs`` (default ``-1``).

    Any exception is re-raised in synchronous mode. In async mode it is
    returned wrapped in a ``SearchFuture`` instead.
    """
    try:
        if not kwargs.get("_async", False):
            reply = self._stub.SearchV2(request, timeout=timeout)
            check_status(reply.status)
            return SearchResult(reply.results, kwargs.get("round_decimal", -1))

        pending = self._stub.SearchV2.future(request, timeout=timeout)
        return SearchFuture(pending, kwargs.get("_callback", None))

    except Exception as e:
        if not kwargs.get("_async", False):
            raise e from e
        # Async callers receive the failure through the future object.
        return SearchFuture(None, None, e)

@retry_on_rpc_failure()
def search(
self,
Expand Down Expand Up @@ -747,6 +766,57 @@ def search(
)
return self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs)

@retry_on_rpc_failure()
def searchV2(
    self,
    collection_name: str,
    reqs: List[AnnSearchRequest],
    rerank_param: Dict,
    limit: int,
    partition_names: Optional[List[str]] = None,
    output_fields: Optional[List[str]] = None,
    round_decimal: int = -1,
    timeout: Optional[float] = None,
    **kwargs,
):
    """Run a multi-request search and return the reranked result.

    Validates the top-level parameters, builds one sub-request per entry
    of ``reqs`` with ``Prepare.search_requests_with_expr``, bundles them
    with ``rerank_param`` via ``Prepare.search_requestV2_with_ranker`` and
    executes the bundle through ``_execute_searchV2``.

    Returns:
        Whatever ``_execute_searchV2`` returns for the built request.
    """
    check_pass_param(
        limit=limit,
        round_decimal=round_decimal,
        partition_name_array=partition_names,
        output_fields=output_fields,
        guarantee_timestamp=kwargs.get("guarantee_timestamp", None),
    )

    # Each AnnSearchRequest carries its own data/field/param/limit/expr; the
    # partition names, output fields and round_decimal are shared by all.
    sub_requests = [
        Prepare.search_requests_with_expr(
            collection_name,
            ann_req.data,
            ann_req.anns_field,
            ann_req.param,
            ann_req.limit,
            ann_req.expr,
            partition_names,
            output_fields,
            round_decimal,
            **kwargs,
        )
        for ann_req in reqs
    ]

    combined_request = Prepare.search_requestV2_with_ranker(
        collection_name,
        sub_requests,
        rerank_param,
        limit,
        partition_names,
        output_fields,
        round_decimal,
        **kwargs,
    )
    return self._execute_searchV2(
        combined_request, timeout, round_decimal=round_decimal, **kwargs
    )

@retry_on_rpc_failure()
def get_query_segment_info(self, collection_name: str, timeout: float = 30, **kwargs):
req = Prepare.get_query_segment_info_request(collection_name)
Expand Down
41 changes: 41 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,47 @@ def dump(v: Dict):

return request

@classmethod
def search_requestV2_with_ranker(
    cls,
    collection_name: str,
    reqs: List,
    rerank_param: Dict,
    limit: int,
    partition_names: Optional[List[str]] = None,
    output_fields: Optional[List[str]] = None,
    round_decimal: int = -1,
    **kwargs,
) -> milvus_types.SearchRequestV2:
    """Build a ``SearchRequestV2`` wrapping ``reqs`` plus rerank parameters.

    Args:
        collection_name: Target collection.
        reqs: Already-built sub-requests placed in the ``requests`` field.
        rerank_param: Rerank settings. It is not modified; ``limit`` and
            ``round_decimal`` are added to a private copy.
        limit: Emitted as the ``"limit"`` rank parameter.
        partition_names: Optional partition names for the request.
        output_fields: Optional output field names for the request.
        round_decimal: Emitted as the ``"round_decimal"`` rank parameter.
        **kwargs: ``guarantee_timestamp`` and ``consistency_level`` are read
            from here (both default to 0).

    Returns:
        The populated ``SearchRequestV2`` message, with every rank parameter
        encoded as a string ``KeyValuePair``.
    """
    # Must run before the request is built: guarantee_timestamp is read from
    # kwargs below. NOTE(review): presumably this helper fills it in - confirm.
    use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)

    # Copy instead of writing into the caller's dict, so a reused
    # rerank_param does not silently gain "limit"/"round_decimal" entries.
    rank_params = {**rerank_param, "limit": limit, "round_decimal": round_decimal}

    def dump(v: object) -> str:
        # Dict values are JSON-encoded; everything else is stringified.
        if isinstance(v, dict):
            return ujson.dumps(v)
        return str(v)

    request = milvus_types.SearchRequestV2(
        collection_name=collection_name,
        partition_names=partition_names,
        requests=reqs,
        output_fields=output_fields,
        guarantee_timestamp=kwargs.get("guarantee_timestamp", 0),
        use_default_consistency=use_default_consistency,
        consistency_level=kwargs.get("consistency_level", 0),
    )

    request.rank_params.extend(
        [
            common_types.KeyValuePair(key=str(key), value=dump(value))
            for key, value in rank_params.items()
        ]
    )

    return request

@classmethod
def create_alias_request(cls, collection_name: str, alias: str):
return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias)
Expand Down
130 changes: 67 additions & 63 deletions pymilvus/grpc_gen/common_pb2.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions pymilvus/grpc_gen/common_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class PlaceholderType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
BinaryVector: _ClassVar[PlaceholderType]
FloatVector: _ClassVar[PlaceholderType]
Float16Vector: _ClassVar[PlaceholderType]
BFloat16Vector: _ClassVar[PlaceholderType]
Int64: _ClassVar[PlaceholderType]
VarChar: _ClassVar[PlaceholderType]

Expand Down Expand Up @@ -369,6 +370,7 @@ None: PlaceholderType
BinaryVector: PlaceholderType
FloatVector: PlaceholderType
Float16Vector: PlaceholderType
BFloat16Vector: PlaceholderType
Int64: PlaceholderType
VarChar: PlaceholderType
Undefined: MsgType
Expand Down Expand Up @@ -543,14 +545,18 @@ PRIVILEGE_EXT_OBJ_FIELD_NUMBER: _ClassVar[int]
privilege_ext_obj: _descriptor.FieldDescriptor

# Type stub for the `Status` protobuf message.
# NOTE(review): this span is a rendered diff, so `__slots__` and `__init__`
# each appear twice; the second of each pair (with `retriable` and `detail`)
# looks like the added version - confirm against the generated file.
class Status(_message.Message):
__slots__ = ["error_code", "reason", "code"]
__slots__ = ["error_code", "reason", "code", "retriable", "detail"]
# Field-number constants, one per message field.
ERROR_CODE_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
CODE_FIELD_NUMBER: _ClassVar[int]
RETRIABLE_FIELD_NUMBER: _ClassVar[int]
DETAIL_FIELD_NUMBER: _ClassVar[int]
# Typed message fields.
error_code: ErrorCode
reason: str
code: int
def __init__(self, error_code: _Optional[_Union[ErrorCode, str]] = ..., reason: _Optional[str] = ..., code: _Optional[int] = ...) -> None: ...
retriable: bool
detail: str
def __init__(self, error_code: _Optional[_Union[ErrorCode, str]] = ..., reason: _Optional[str] = ..., code: _Optional[int] = ..., retriable: bool = ..., detail: _Optional[str] = ...) -> None: ...

class KeyValuePair(_message.Message):
__slots__ = ["key", "value"]
Expand Down
612 changes: 318 additions & 294 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,20 @@ class CreateIndexRequest(_message.Message):
index_name: str
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., field_name: _Optional[str] = ..., extra_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., index_name: _Optional[str] = ...) -> None: ...

# Type stub for the `AlterIndexRequest` protobuf message: identifies an index
# by db / collection / index name and carries `extra_params` key-value pairs.
class AlterIndexRequest(_message.Message):
__slots__ = ["base", "db_name", "collection_name", "index_name", "extra_params"]
# Field-number constants, one per message field.
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
INDEX_NAME_FIELD_NUMBER: _ClassVar[int]
EXTRA_PARAMS_FIELD_NUMBER: _ClassVar[int]
# Typed message fields.
base: _common_pb2.MsgBase
db_name: str
collection_name: str
index_name: str
extra_params: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., index_name: _Optional[str] = ..., extra_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ...) -> None: ...

class DescribeIndexRequest(_message.Message):
__slots__ = ["base", "db_name", "collection_name", "field_name", "index_name", "timestamp"]
BASE_FIELD_NUMBER: _ClassVar[int]
Expand Down Expand Up @@ -757,6 +771,34 @@ class SearchResults(_message.Message):
collection_name: str
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Union[_schema_pb2.SearchResultData, _Mapping]] = ..., collection_name: _Optional[str] = ...) -> None: ...

# Type stub for the `SearchRequestV2` protobuf message: a bundle of
# `SearchRequest` sub-requests (`requests`) plus `rank_params` key-value pairs
# and the shared collection / partition / consistency settings.
class SearchRequestV2(_message.Message):
__slots__ = ["base", "db_name", "collection_name", "partition_names", "requests", "rank_params", "travel_timestamp", "guarantee_timestamp", "not_return_all_meta", "output_fields", "consistency_level", "use_default_consistency"]
# Field-number constants, one per message field.
BASE_FIELD_NUMBER: _ClassVar[int]
DB_NAME_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
PARTITION_NAMES_FIELD_NUMBER: _ClassVar[int]
REQUESTS_FIELD_NUMBER: _ClassVar[int]
RANK_PARAMS_FIELD_NUMBER: _ClassVar[int]
TRAVEL_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
GUARANTEE_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
NOT_RETURN_ALL_META_FIELD_NUMBER: _ClassVar[int]
OUTPUT_FIELDS_FIELD_NUMBER: _ClassVar[int]
CONSISTENCY_LEVEL_FIELD_NUMBER: _ClassVar[int]
USE_DEFAULT_CONSISTENCY_FIELD_NUMBER: _ClassVar[int]
# Typed message fields.
base: _common_pb2.MsgBase
db_name: str
collection_name: str
partition_names: _containers.RepeatedScalarFieldContainer[str]
requests: _containers.RepeatedCompositeFieldContainer[SearchRequest]
rank_params: _containers.RepeatedCompositeFieldContainer[_common_pb2.KeyValuePair]
travel_timestamp: int
guarantee_timestamp: int
not_return_all_meta: bool
output_fields: _containers.RepeatedScalarFieldContainer[str]
consistency_level: _common_pb2.ConsistencyLevel
use_default_consistency: bool
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., partition_names: _Optional[_Iterable[str]] = ..., requests: _Optional[_Iterable[_Union[SearchRequest, _Mapping]]] = ..., rank_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., not_return_all_meta: bool = ..., output_fields: _Optional[_Iterable[str]] = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ...) -> None: ...

class FlushRequest(_message.Message):
__slots__ = ["base", "db_name", "collection_names"]
BASE_FIELD_NUMBER: _ClassVar[int]
Expand Down
Loading

0 comments on commit e95d39f

Please sign in to comment.