From 124b9afce10450784ffbdd4343173ea01b37ee0c Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Fri, 8 Mar 2024 16:30:21 +0800 Subject: [PATCH] support bge style sparse embedding Signed-off-by: Buqian Zheng --- pymilvus/client/entity_helper.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index 8fd554e8e..74e9afc0e 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -59,10 +59,18 @@ def entity_is_sparse_matrix(entity): if sparse_is_scipy_format(entity): return True try: + def is_type_in_str(v, t): + if not isinstance(v, str): + return False + try: + t(v) + return True + except ValueError: + return False def is_int_type(v): - return isinstance(v, (int, np.integer)) + return isinstance(v, (int, np.integer)) or is_type_in_str(v, int) def is_float_type(v): - return isinstance(v, (float, np.floating)) + return isinstance(v, (float, np.floating)) or is_type_in_str(v, float) # must be of multiple rows for item in entity: pairs = item.items() if isinstance(item, dict) else item @@ -98,8 +106,6 @@ def sparse_float_row_to_bytes(indices, values): f"sparse vector index must be positive and less than 2^32-1: {i}") if math.isnan(v): raise ValueError("sparse vector value must not be NaN") - # if i > 2**31-1: - # print(f'seeing large index: {i}') data += struct.pack("I", i) data += struct.pack("f", v) return data @@ -116,8 +122,8 @@ def unify_sparse_input(data: SparseMatrixInputType) -> sparse.csr_array: for row_id, row in enumerate(data): row = row.items() if isinstance(row, dict) else row row_indices.extend([row_id] * len(row)) - col_indices.extend([col_id for col_id, _ in row]) - values.extend([value for _, value in row]) + col_indices.extend([int(col_id) if isinstance(col_id, str) else col_id for col_id, _ in row]) + values.extend([float(value) if isinstance(value, str) else value for _, value in row]) return sparse.csr_array((values, (row_indices, col_indices))) csr = unify_sparse_input(data) result = schema_types.SparseFloatArray()