Skip to content

Commit

Permalink
fix: use correct primary field name in Hits (#2560)
Browse files Browse the repository at this point in the history
issue: #2558

Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored Jan 13, 2025
1 parent 2f6b016 commit f85223e
Show file tree
Hide file tree
Showing 8 changed files with 375 additions and 298 deletions.
62 changes: 62 additions & 0 deletions examples/customize_schema_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
DataType
)

# Example: a collection whose primary key is NOT called "id".
#
# The schema below uses "uid" as the primary key and additionally declares an
# ordinary VARCHAR field that happens to be named "id".  Searching with
# output_fields=["id"] then shows how each hit reports its primary key next
# to the regular "id" field.
#
# Requires a Milvus server reachable at http://localhost:19530.

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

# Start from a clean slate so the example can be re-run.
has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
    milvus_client.drop_collection(collection_name)

# "uid" is the primary key; "id" is just a regular scalar field.
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("uid", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("title", DataType.VARCHAR, max_length=64)
schema.add_field("id", DataType.VARCHAR, max_length=64)

index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="embeddings", metric_type="L2")
milvus_client.create_collection(
    collection_name,
    schema=schema,
    index_params=index_params,
    consistency_level="Strong",
)

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

# Each row also carries one key ("a".."f") that is not declared in the schema;
# the collection was created with enable_dynamic_field=True.
rng = np.random.default_rng(seed=19530)
rows = [
    {"uid": 1, "embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1", "id": "u1"},
    {"uid": 2, "embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2", "id": "u2"},
    {"uid": 3, "embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3", "id": "u3"},
    {"uid": 4, "embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4", "id": "u4"},
    {"uid": 5, "embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5", "id": "u5"},
    {"uid": 6, "embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6", "id": "u6"},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)

print(fmt.format("Start load collection "))
milvus_client.load_collection(collection_name)

# Re-seeding with the same seed makes the first draw identical to the
# embedding generated for the row with uid=1, so that row should come back
# as the closest match.
rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))

print(fmt.format("Start search with retrieve several fields."))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["id"])
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

# Clean up the collection created by this example.
milvus_client.drop_collection(collection_name)
21 changes: 16 additions & 5 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def __init__(
self._nq = res.num_queries
all_topks = res.topks
self.recalls = res.recalls
self._pk_name = res.primary_field_name or "id"

self.cost = int(status.extra_info["report_value"] if status and status.extra_info else "0")

Expand All @@ -500,7 +501,14 @@ def __init__(
start, end = nq_thres, nq_thres + topk
nq_th_fields = self.get_fields_by_range(start, end, fields_data)
data.append(
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
Hits(
topk,
all_pks[start:end],
all_scores[start:end],
nq_th_fields,
output_fields,
self._pk_name,
)
)
nq_thres += topk
self._session_ts = session_ts
Expand Down Expand Up @@ -673,6 +681,7 @@ def __init__(
distances: List[float],
fields: Dict[str, Tuple[List[Any], schema_pb2.FieldData]],
output_fields: List[str],
pk_name: str,
):
"""
Args:
Expand All @@ -681,6 +690,7 @@ def __init__(
"""
self.ids = pks
self.distances = distances
self._pk_name = pk_name

all_fields = list(fields.keys())
dynamic_fields = list(set(output_fields) - set(all_fields))
Expand Down Expand Up @@ -719,7 +729,7 @@ def __init__(
# sparse float vector and other fields
curr_field[fname] = data[i]

hits.append(Hit(pks[i], distances[i], curr_field))
hits.append(Hit(pks[i], distances[i], curr_field, self._pk_name))

super().__init__(hits)

Expand All @@ -739,10 +749,11 @@ class Hit:
distance: float
fields: Dict[str, Any]

def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any]):
def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any], pk_name: str):
self.id = pk
self.distance = distance
self.fields = fields
self._pk_name = pk_name

def __getattr__(self, item: str):
if item not in self.fields:
Expand All @@ -765,13 +776,13 @@ def get(self, field_name: str) -> Any:
return self.fields.get(field_name)

def __str__(self) -> str:
return f"id: {self.id}, distance: {self.distance}, entity: {self.fields}"
return f"{self._pk_name}: {self.id}, distance: {self.distance}, entity: {self.fields}"

__repr__ = __str__

def to_dict(self):
return {
"id": self.id,
self._pk_name: self.id,
"distance": self.distance,
"entity": self.fields,
}
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/grpc_gen/milvus-proto
528 changes: 264 additions & 264 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -994,18 +994,20 @@ class QueryRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., expr: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., partition_names: _Optional[_Iterable[str]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., query_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., not_return_all_meta: bool = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ..., expr_template_values: _Optional[_Mapping[str, _schema_pb2.TemplateValue]] = ...) -> None: ...

class QueryResults(_message.Message):
__slots__ = ("status", "fields_data", "collection_name", "output_fields", "session_ts")
__slots__ = ("status", "fields_data", "collection_name", "output_fields", "session_ts", "primary_field_name")
STATUS_FIELD_NUMBER: _ClassVar[int]
FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
OUTPUT_FIELDS_FIELD_NUMBER: _ClassVar[int]
SESSION_TS_FIELD_NUMBER: _ClassVar[int]
PRIMARY_FIELD_NAME_FIELD_NUMBER: _ClassVar[int]
status: _common_pb2.Status
fields_data: _containers.RepeatedCompositeFieldContainer[_schema_pb2.FieldData]
collection_name: str
output_fields: _containers.RepeatedScalarFieldContainer[str]
session_ts: int
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., session_ts: _Optional[int] = ...) -> None: ...
primary_field_name: str
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., session_ts: _Optional[int] = ..., primary_field_name: _Optional[str] = ...) -> None: ...

class QueryCursor(_message.Message):
__slots__ = ("session_ts", "str_pk", "int_pk")
Expand Down
40 changes: 20 additions & 20 deletions pymilvus/grpc_gen/schema_pb2.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pymilvus/grpc_gen/schema_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class SearchIteratorV2Results(_message.Message):
def __init__(self, token: _Optional[str] = ..., last_bound: _Optional[float] = ...) -> None: ...

class SearchResultData(_message.Message):
__slots__ = ("num_queries", "top_k", "fields_data", "scores", "ids", "topks", "output_fields", "group_by_field_value", "all_search_count", "distances", "search_iterator_v2_results", "recalls")
__slots__ = ("num_queries", "top_k", "fields_data", "scores", "ids", "topks", "output_fields", "group_by_field_value", "all_search_count", "distances", "search_iterator_v2_results", "recalls", "primary_field_name")
NUM_QUERIES_FIELD_NUMBER: _ClassVar[int]
TOP_K_FIELD_NUMBER: _ClassVar[int]
FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -321,6 +321,7 @@ class SearchResultData(_message.Message):
DISTANCES_FIELD_NUMBER: _ClassVar[int]
SEARCH_ITERATOR_V2_RESULTS_FIELD_NUMBER: _ClassVar[int]
RECALLS_FIELD_NUMBER: _ClassVar[int]
PRIMARY_FIELD_NAME_FIELD_NUMBER: _ClassVar[int]
num_queries: int
top_k: int
fields_data: _containers.RepeatedCompositeFieldContainer[FieldData]
Expand All @@ -333,7 +334,8 @@ class SearchResultData(_message.Message):
distances: _containers.RepeatedScalarFieldContainer[float]
search_iterator_v2_results: SearchIteratorV2Results
recalls: _containers.RepeatedScalarFieldContainer[float]
def __init__(self, num_queries: _Optional[int] = ..., top_k: _Optional[int] = ..., fields_data: _Optional[_Iterable[_Union[FieldData, _Mapping]]] = ..., scores: _Optional[_Iterable[float]] = ..., ids: _Optional[_Union[IDs, _Mapping]] = ..., topks: _Optional[_Iterable[int]] = ..., output_fields: _Optional[_Iterable[str]] = ..., group_by_field_value: _Optional[_Union[FieldData, _Mapping]] = ..., all_search_count: _Optional[int] = ..., distances: _Optional[_Iterable[float]] = ..., search_iterator_v2_results: _Optional[_Union[SearchIteratorV2Results, _Mapping]] = ..., recalls: _Optional[_Iterable[float]] = ...) -> None: ...
primary_field_name: str
def __init__(self, num_queries: _Optional[int] = ..., top_k: _Optional[int] = ..., fields_data: _Optional[_Iterable[_Union[FieldData, _Mapping]]] = ..., scores: _Optional[_Iterable[float]] = ..., ids: _Optional[_Union[IDs, _Mapping]] = ..., topks: _Optional[_Iterable[int]] = ..., output_fields: _Optional[_Iterable[str]] = ..., group_by_field_value: _Optional[_Union[FieldData, _Mapping]] = ..., all_search_count: _Optional[int] = ..., distances: _Optional[_Iterable[float]] = ..., search_iterator_v2_results: _Optional[_Union[SearchIteratorV2Results, _Mapping]] = ..., recalls: _Optional[_Iterable[float]] = ..., primary_field_name: _Optional[str] = ...) -> None: ...

class VectorClusteringInfo(_message.Message):
__slots__ = ("field", "centroid")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestHit:
])
def test_hit_no_fields(self, pk_dist: List[Tuple]):
pk, dist = pk_dist
h = Hit(pk, dist, {})
h = Hit(pk, dist, {}, "id")
assert h.id == pk
assert h.score == dist
assert h.distance == dist
Expand All @@ -30,9 +30,9 @@ def test_hit_no_fields(self, pk_dist: List[Tuple]):
}

@pytest.mark.parametrize("pk_dist_fields", [
(1, 0.1, {"vector": [1., 2., 3., 4.], "description": "This is a test", 'd_a': "dynamic a"}),
(2, 0.3, {"vector": [3., 4., 5., 6.], "description": "This is a test too", 'd_b': "dynamic b"}),
("a", 0.4, {"vector": [4., 4., 4., 4.], "description": "This is a third test", 'd_a': "dynamic a twice"}),
(1, 0.1, {"vector": [1., 2., 3., 4.], "description": "This is a test", 'd_a': "dynamic a"}, "id"),
(2, 0.3, {"vector": [3., 4., 5., 6.], "description": "This is a test too", 'd_b': "dynamic b"}, "id"),
("a", 0.4, {"vector": [4., 4., 4., 4.], "description": "This is a third test", 'd_a': "dynamic a twice"}, "id"),
])
def test_hit_with_fields(self, pk_dist_fields: List[Tuple]):
h = Hit(*pk_dist_fields)
Expand Down

0 comments on commit f85223e

Please sign in to comment.