From 2e4b44d290d299f0eea17300642769631423dd79 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Mon, 11 Mar 2024 15:40:10 +0800 Subject: [PATCH] add sparse float vector support: insert, search, query, delete; supports various sparse matrix representations also supported row based insertion so sparse is also supported in milvus client Signed-off-by: Buqian Zheng --- examples/hello_sparse.py | 154 +++++++++++++++++++++++++ examples/milvus_client/sparse.py | 88 ++++++++++++++ pymilvus/client/abstract.py | 18 +-- pymilvus/client/check.py | 28 ++--- pymilvus/client/entity_helper.py | 192 ++++++++++++++++++++++++++++++- pymilvus/client/grpc_handler.py | 3 +- pymilvus/client/prepare.py | 28 +++-- pymilvus/client/types.py | 2 + pymilvus/client/utils.py | 4 +- pymilvus/exceptions.py | 2 +- pymilvus/orm/collection.py | 24 ++-- pymilvus/orm/iterator.py | 10 +- pymilvus/orm/partition.py | 13 ++- pymilvus/orm/prepare.py | 4 +- pymilvus/orm/schema.py | 3 + pymilvus/orm/types.py | 6 +- requirements.txt | 1 + 17 files changed, 516 insertions(+), 64 deletions(-) create mode 100644 examples/hello_sparse.py create mode 100644 examples/milvus_client/sparse.py diff --git a/examples/hello_sparse.py b/examples/hello_sparse.py new file mode 100644 index 000000000..b6ac8f732 --- /dev/null +++ b/examples/hello_sparse.py @@ -0,0 +1,154 @@ +# hello_sprase.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus, +# while operating on sparse float vectors. +# 1. connect to Milvus +# 2. create collection +# 3. insert data +# 4. create index +# 5. search, query, and hybrid search on entities +# 6. delete entities by PK +# 7. drop collection +import time + +import numpy as np +from scipy.sparse import rand +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, +) + +fmt = "=== {:30} ===" +search_latency_fmt = "search latency = {:.4f}s" +num_entities, dim, density = 1000, 3000, 0.005 + +def log(msg): + print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg) + +# ----------------------------------------------------------------------------- +# connect to Milvus +log(fmt.format("start connecting to Milvus")) +connections.connect("default", host="localhost", port="19530") + +has = utility.has_collection("hello_sparse") +log(f"Does collection hello_sparse exist in Milvus: {has}") + +# ----------------------------------------------------------------------------- +# create collection with a sparse float vector column +hello_sparse = None +if not has: + fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR), + ] + schema = CollectionSchema(fields, "hello_sparse is the simplest demo to introduce sparse float vector usage") + log(fmt.format("Create collection `hello_sparse`")) + hello_sparse = Collection("hello_sparse", schema, consistency_level="Strong") +else: + hello_sparse = Collection("hello_sparse") + +log(f"hello_sparse has {hello_sparse.num_entities} entities({hello_sparse.num_entities/1000000}M), indexed {hello_sparse.has_index()}") + +# ----------------------------------------------------------------------------- +# insert +log(fmt.format("Start creating entities to insert")) +rng = np.random.default_rng(seed=19530) +# this step is so damn slow +matrix_csr = rand(num_entities, dim, density=density, format='csr') +entities = [ + rng.random(num_entities).tolist(), + matrix_csr, +] + +log(fmt.format("Start inserting entities")) +insert_result = hello_sparse.insert(entities) + +# ----------------------------------------------------------------------------- +# create index +if not hello_sparse.has_index(): + log(fmt.format("Start Creating index SPARSE_INVERTED_INDEX")) + index = { + "index_type": "SPARSE_INVERTED_INDEX", + "metric_type": "IP", + "params":{ + "drop_ratio_build": 0.2, + } + } + hello_sparse.create_index("embeddings", index) + +log(fmt.format("Start loading")) +hello_sparse.load() + +# ----------------------------------------------------------------------------- +# search based on vector similarity +log(fmt.format("Start searching based on vector similarity")) +vectors_to_search = entities[-1][-1:] +search_params = { + "metric_type": "IP", + "params": { + "drop_ratio_search": "0.2", + } +} + +start_time = time.time() +result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "random", "embeddings"]) +end_time = time.time() + +for hits in result: + for hit in hits: + print(f"hit: {hit}") +log(search_latency_fmt.format(end_time - start_time)) +# ----------------------------------------------------------------------------- +# query based on scalar filtering(boolean, int, etc.) +print(fmt.format("Start querying with `random > 0.5`")) + +start_time = time.time() +result = hello_sparse.query(expr="random > 0.5", output_fields=["random", "embeddings"]) +end_time = time.time() + +print(f"query result:\n-{result[0]}") +print(search_latency_fmt.format(end_time - start_time)) + +# ----------------------------------------------------------------------------- +# pagination +r1 = hello_sparse.query(expr="random > 0.5", limit=4, output_fields=["random"]) +r2 = hello_sparse.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"]) +print(f"query pagination(limit=4):\n\t{r1}") +print(f"query pagination(offset=1, limit=3):\n\t{r2}") + + +# ----------------------------------------------------------------------------- +# hybrid search +print(fmt.format("Start hybrid searching with `random > 0.5`")) + +start_time = time.time() +result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5", output_fields=["random"]) +end_time = time.time() + +for hits in result: + for hit in hits: + print(f"hit: {hit}, random field: {hit.entity.get('random')}") +print(search_latency_fmt.format(end_time - start_time)) + +# ----------------------------------------------------------------------------- +# delete entities by PK +# You can delete entities by their PK values using boolean expressions. +ids = insert_result.primary_keys + +expr = f'pk in ["{ids[0]}" , "{ids[1]}"]' +print(fmt.format(f"Start deleting with expr `{expr}`")) + +result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"]) +print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n") + +hello_sparse.delete(expr) + +result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"]) +print(f"query after delete by expr=`{expr}` -> result: {result}\n") + + +# ----------------------------------------------------------------------------- +# drop collection +print(fmt.format("Drop collection `hello_sparse`")) +utility.drop_collection("hello_sparse") diff --git a/examples/milvus_client/sparse.py b/examples/milvus_client/sparse.py new file mode 100644 index 000000000..d55daad96 --- /dev/null +++ b/examples/milvus_client/sparse.py @@ -0,0 +1,88 @@ +from pymilvus import ( + MilvusClient, + FieldSchema, CollectionSchema, DataType, +) + +import random + +def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict: + indices = random.sample(range(dimension), non_zero_count) + values = [random.random() for _ in range(non_zero_count)] + sparse_vector = {index: value for index, value in zip(indices, values)} + return sparse_vector + + +fmt = "\n=== {:30} ===\n" +dim = 100 +non_zero_count = 20 +collection_name = "hello_sparse" +milvus_client = MilvusClient("http://localhost:19530") + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) +fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, + is_primary=True, auto_id=True, max_length=100), + FieldSchema(name="random", dtype=DataType.DOUBLE), + FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR), +] +schema = CollectionSchema( + fields, "demo for using sparse float vector with milvus client") +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name="embeddings", index_name="sparse_inverted_index", + index_type="SPARSE_INVERTED_INDEX", metric_type="IP", params={"drop_ratio_build": 0.2}) +milvus_client.create_collection(collection_name, schema=schema, + index_params=index_params, timeout=5, consistency_level="Strong") + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +N = 6 +rows = [{"random": i, "embeddings": generate_sparse_vector( + dim, non_zero_count)} for i in range(N)] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows, progress_bar=True) +print(fmt.format("Inserting entities done")) +print(insert_result) + +print(fmt.format(f"Start vector anns search.")) +vectors_to_search = [generate_sparse_vector(dim, non_zero_count)] +search_params = { + "metric_type": "IP", + "params": { + "drop_ratio_search": 0.2, + } +} +# no need to specify anns_field for collections with only 1 vector field +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=[ + "pk", "random", "embeddings"], search_params=search_params) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter="random < 3") +pks = [ret['pk'] for ret in query_results] +for ret in query_results: + print(ret) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query( + collection_name, filter=f"pk == '{pks[0]}'") +print(query_results[0]) + +print(f"start to delete by specifying filter in collection {collection_name}") +delete_result = milvus_client.delete(collection_name, ids=pks[:1]) +print(delete_result) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query( + collection_name, filter=f"pk == '{pks[0]}'") +print(f'query result should be empty: {query_results}') + +milvus_client.drop_collection(collection_name) diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index fbec16917..bc7c061d1 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -7,6 +7,7 @@ from pymilvus.grpc_gen import schema_pb2 from pymilvus.settings import Config +from . import entity_helper from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED from .types import DataType @@ -15,7 +16,6 @@ class FieldSchema: def __init__(self, raw: Any): self._raw = raw - # self.field_id = 0 self.name = None self.is_primary = False @@ -28,7 +28,6 @@ def __init__(self, raw: Any): self.is_dynamic = False # For array field self.element_type = None - ## self.__pack(self._raw) def __pack(self, raw: Any): @@ -100,7 +99,6 @@ class CollectionSchema: def __init__(self, raw: Any): self._raw = raw - # self.collection_name = None self.description = None self.params = {} @@ -115,7 +113,6 @@ def __init__(self, raw: Any): self.num_partitions = 0 self.enable_dynamic_field = False - # if self._raw: self.__pack(self._raw) @@ -324,7 +321,7 @@ def dict(self): class AnnSearchRequest: def __init__( self, - data: List, + data: Union[List, entity_helper.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, @@ -466,6 +463,13 @@ def get_fields_by_range( field_meta, ) continue + # TODO(SPARSE): do we want to allow the user to specify the return format? + if dtype == DataType.SPARSE_FLOAT_VECTOR: + field2data[name] = ( + entity_helper.sparse_proto_to_rows(vectors.sparse_float_vector, start, end), + field_meta, + ) + continue if dtype == DataType.BFLOAT16_VECTOR: field2data[name] = ( @@ -521,7 +525,7 @@ def __init__( for fname, (data, field_meta) in fields.items(): if len(data) <= i: curr_field[fname] = None - # Get vectors + # Get dense vectors if field_meta.type in ( DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR, @@ -546,7 +550,7 @@ def __init__( curr_field.update(data[i]) continue - # normal fields + # sparse float vector and other fields curr_field[fname] = data[i] hits.append(Hit(pks[i], distances[i], curr_field)) diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index b979926f7..110cd88c0 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -5,6 +5,7 @@ from pymilvus.exceptions import ParamError from pymilvus.grpc_gen import milvus_pb2 as milvus_types +from . import entity_helper from .singleton_utils import Singleton @@ -40,24 +41,6 @@ def is_legal_port(port: Any) -> bool: return False -def is_legal_vector(array: Any) -> bool: - if not array or not isinstance(array, list) or len(array) == 0: - return False - - return True - - -def is_legal_bin_vector(array: Any) -> bool: - if not array or not isinstance(array, bytes) or len(array) == 0: - return False - - return True - - -def is_legal_numpy_array(array: Any) -> bool: - return not (array is None or array.size == 0) - - def int_or_str(item: Union[int, str]) -> str: if isinstance(item, int): return str(item) @@ -149,6 +132,10 @@ def is_legal_max_iterations(max_iterations: Any) -> bool: return isinstance(max_iterations, int) +def is_legal_drop_ratio(drop_ratio: Any) -> bool: + return isinstance(drop_ratio, float) and 0 <= drop_ratio < 1 + + def is_legal_team_size(team_size: Any) -> bool: return isinstance(team_size, int) @@ -197,6 +184,9 @@ def is_legal_anns_field(field: Any) -> bool: def is_legal_search_data(data: Any) -> bool: import numpy as np + if entity_helper.entity_is_sparse_matrix(data): + return True + if not isinstance(data, (list, np.ndarray)): return False @@ -331,6 +321,8 @@ def __init__(self) -> None: "team_size": is_legal_team_size, "index_name": is_legal_index_name, "timeout": is_legal_timeout, + "drop_ratio_build": is_legal_drop_ratio, + "drop_ratio_search": is_legal_drop_ratio, } def check(self, key: str, value: Callable): diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index b0da5f496..6466163f9 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -1,7 +1,10 @@ -from typing import Any, Dict, List, Optional +import math +import struct +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import ujson +from scipy import sparse from pymilvus.exceptions import ( DataNotMatchException, @@ -16,6 +19,172 @@ CHECK_STR_ARRAY = True +# in search results, if output fields includes a sparse float vector field, we +# will return a SparseRowOutputType for each entity. Using Dict for readability. +# TODO(SPARSE): to allow the user to specify output format. +SparseRowOutputType = Dict[int, float] + +# we accept the following types as input for sparse matrix in user facing APIs +# such as insert, search, etc.: +# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil +# - iterable of iterables, each element(iterable) is a sparse vector with index +# as key and value as float. +# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...] +# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...] +# both index/value can be str numbers: {'2': '3.1'} +SparseMatrixInputType = Union[ + Iterable[ + Union[ + SparseRowOutputType, + Iterable[Tuple[int, float]], # only type hint, we accept int/float like types + ] + ], + sparse.csc_array, + sparse.coo_array, + sparse.bsr_array, + sparse.dia_array, + sparse.dok_array, + sparse.lil_array, + sparse.csr_array, + sparse.spmatrix, +] + + +def sparse_is_scipy_matrix(data: Any): + return isinstance(data, sparse.spmatrix) + + +def sparse_is_scipy_array(data: Any): + # sparse.sparray, the common superclass of sparse.*_array, is introduced in + # scipy 1.11.0, which requires python 3.9, higher than pymilvus's current requirement. + return isinstance( + data, + ( + sparse.bsr_array, + sparse.coo_array, + sparse.csc_array, + sparse.csr_array, + sparse.dia_array, + sparse.dok_array, + sparse.lil_array, + ), + ) + + +def sparse_is_scipy_format(data: Any): + return sparse_is_scipy_matrix(data) or sparse_is_scipy_array(data) + + +def entity_is_sparse_matrix(entity: Any): + if sparse_is_scipy_format(entity): + return True + try: + + def is_type_in_str(v: Any, t: Any): + if not isinstance(v, str): + return False + try: + t(v) + except ValueError: + return False + return True + + def is_int_type(v: Any): + return isinstance(v, (int, np.integer)) or is_type_in_str(v, int) + + def is_float_type(v: Any): + return isinstance(v, (float, np.floating)) or is_type_in_str(v, float) + + # must be of multiple rows + for item in entity: + pairs = item.items() if isinstance(item, dict) else item + # each row must be a list of Tuple[int, float] + for pair in pairs: + if len(pair) != 2 or not is_int_type(pair[0]) or not is_float_type(pair[1]): + return False + except Exception: + return False + return True + + +# parses plain bytes to a sparse float vector(SparseRowOutputType) +def sparse_parse_single_row(data: bytes) -> SparseRowOutputType: + if len(data) % 8 != 0: + msg = f"The length of data must be a multiple of 8, got {len(data)}" + raise ValueError(msg) + + return { + struct.unpack("I", data[i : i + 4])[0]: struct.unpack("f", data[i + 4 : i + 8])[0] + for i in range(0, len(data), 8) + } + + +# converts supported sparse matrix to schemapb.SparseFloatArray proto +def sparse_rows_to_proto(data: SparseMatrixInputType) -> schema_types.SparseFloatArray: + # converts a sparse float vector to plain bytes. the format is the same as how + # milvus interprets/persists the data. + def sparse_float_row_to_bytes(indices: Iterable[int], values: Iterable[float]): + if len(indices) != len(values): + msg = f"length of indices and values must be the same, got {len(indices)} and {len(values)}" + raise ValueError(msg) + data = b"" + for i, v in sorted(zip(indices, values), key=lambda x: x[0]): + if not (0 <= i < 2**32 - 1): + msg = f"sparse vector index must be positive and less than 2^32-1: {i}" + raise ValueError(msg) + if math.isnan(v): + msg = "sparse vector value must not be NaN" + raise ValueError(msg) + data += struct.pack("I", i) + data += struct.pack("f", v) + return data + + def unify_sparse_input(data: SparseMatrixInputType) -> sparse.csr_array: + if isinstance(data, sparse.csr_array): + return data + if sparse_is_scipy_array(data): + return data.tocsr() + if sparse_is_scipy_matrix(data): + return sparse.csr_array(data.tocsr()) + row_indices = [] + col_indices = [] + values = [] + for row_id, row_data in enumerate(data): + row = row_data.items() if isinstance(row_data, dict) else row_data + row_indices.extend([row_id] * len(row)) + col_indices.extend( + [int(col_id) if isinstance(col_id, str) else col_id for col_id, _ in row] + ) + values.extend([float(value) if isinstance(value, str) else value for _, value in row]) + return sparse.csr_array((values, (row_indices, col_indices))) + + csr = unify_sparse_input(data) + result = schema_types.SparseFloatArray() + result.dim = csr.shape[1] + for start, end in zip(csr.indptr[:-1], csr.indptr[1:]): + result.contents.append( + sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end]) + ) + return result + + +# converts schema_types.SparseFloatArray proto to Iterable[SparseRowOutputType] +def sparse_proto_to_rows( + sfv: schema_types.SparseFloatArray, start: Optional[int] = None, end: Optional[int] = None +) -> Iterable[SparseRowOutputType]: + if not isinstance(sfv, schema_types.SparseFloatArray): + msg = "Vector must be a sparse float vector" + raise TypeError(msg) + start = start or 0 + end = end or len(sfv.contents) + return [sparse_parse_single_row(row_bytes) for row_bytes in sfv.contents[start:end]] + + +def get_input_num_rows(entity: Any) -> int: + if sparse_is_scipy_format(entity): + return entity.shape[0] + return len(entity) + def entity_type_to_dtype(entity_type: Any): if isinstance(entity_type, int): @@ -139,6 +308,17 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info elif field_type == DataType.BFLOAT16_VECTOR: field_data.vectors.dim = len(field_value) // 2 field_data.vectors.bfloat16_vector += bytes(field_value) + elif field_type == DataType.SPARSE_FLOAT_VECTOR: + # field_value is a single row of sparse float vector in user provided format + if not sparse_is_scipy_format(field_value): + field_value = [field_value] + elif field_value.shape[0] != 1: + raise ParamError(message="invalid input for sparse float vector: expect 1 row") + if not entity_is_sparse_matrix(field_value): + raise ParamError(message="invalid input for sparse float vector") + field_data.vectors.sparse_float_vector.contents.append( + sparse_rows_to_proto(field_value).contents[0] + ) elif field_type == DataType.VARCHAR: field_data.scalars.string_data.data.append( convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY) @@ -190,6 +370,8 @@ def entity_to_field_data(entity: Any, field_info: Any): field_data.scalars.json_data.data.extend(entity_to_json_arr(entity)) elif entity_type == DataType.ARRAY: field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info)) + elif entity_type == DataType.SPARSE_FLOAT_VECTOR: + field_data.vectors.sparse_float_vector.CopyFrom(sparse_rows_to_proto(entity.get("values"))) else: raise ParamError(message=f"UnSupported data type: {entity_type}") @@ -247,6 +429,7 @@ def extract_array_row_data(field_data: Any, index: int): # pylint: disable=R1702 (too-many-nested-blocks) +# pylint: disable=R0915 (too-many-statements) def extract_row_data_from_fields_data( fields_data: Any, index: Any, @@ -305,8 +488,7 @@ def check_append(field_data: Any): entity_row_data.update(json_dict) return - tmp_dict = {k: v for k, v in json_dict.items() if k in dynamic_fields} - entity_row_data.update(tmp_dict) + entity_row_data.update({k: v for k, v in json_dict.items() if k in dynamic_fields}) return if field_data.type == DataType.ARRAY and len(field_data.scalars.array_data.data) >= index: entity_row_data[field_data.field_name] = extract_array_row_data(field_data, index) @@ -339,6 +521,10 @@ def check_append(field_data: Any): entity_row_data[field_data.field_name] = [ field_data.vectors.float16_vector[start_pos:end_pos] ] + elif field_data.type == DataType.SPARSE_FLOAT_VECTOR: + entity_row_data[field_data.field_name] = sparse_parse_single_row( + field_data.vectors.sparse_float_vector.contents[index] + ) for field_data in fields_data: check_append(field_data) diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 2c2dc2bdf..8865b55a3 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -757,7 +757,7 @@ def _execute_hybrid_search( def search( self, collection_name: str, - data: List[List[float]], + data: Union[List[List[float]], entity_helper.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, @@ -940,6 +940,7 @@ def create_index( DataType.BINARY_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, + DataType.SPARSE_FLOAT_VECTOR, }: break diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index c7289f493..b8ec669ef 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -4,14 +4,14 @@ import ujson -from pymilvus.client import __version__ +from pymilvus.client import __version__, entity_helper from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError from pymilvus.grpc_gen import common_pb2 as common_types from pymilvus.grpc_gen import milvus_pb2 as milvus_types from pymilvus.grpc_gen import schema_pb2 as schema_types from pymilvus.orm.schema import CollectionSchema -from . import blob, entity_helper, ts_utils +from . import blob, ts_utils from .check import check_pass_param, is_legal_collection_properties from .constants import ( DEFAULT_CONSISTENCY_LEVEL, @@ -456,11 +456,7 @@ def _pre_batch_check( fields_info: Any, ): for entity in entities: - if ( - not entity.get("name", None) - or not entity.get("values", None) - or not entity.get("type", None) - ): + if not entity.get("name", None) or entity.get("values", None) is None or not entity.get("type", None): raise ParamError( message="Missing param in entities, a field must have type, name and values" ) @@ -492,7 +488,7 @@ def _parse_batch_request( pre_field_size = 0 try: for entity in entities: - latest_field_size = len(entity.get("values")) + latest_field_size = entity_helper.get_input_num_rows(entity.get("values")) if pre_field_size not in (0, latest_field_size): raise ParamError( message=( @@ -576,6 +572,13 @@ def check_str(instr: str, prefix: str): def _prepare_placeholders(cls, vectors: Any, nq: int, tag: Any, pl_type: Any, is_binary: bool): pl = common_types.PlaceholderValue(tag=tag) pl.type = pl_type + # sparse vector + if pl_type == PlaceholderType.SparseFloatVector: + sparse_float_array_proto = entity_helper.sparse_rows_to_proto(vectors) + pl.values.extend(sparse_float_array_proto.contents) + return pl + + # dense or binary vector for i in range(nq): if is_binary: pl.values.append(blob.vector_binary_to_bytes(vectors[i])) @@ -587,7 +590,7 @@ def _prepare_placeholders(cls, vectors: Any, nq: int, tag: Any, pl_type: Any, is def search_requests_with_expr( cls, collection_name: str, - data: List, + data: Union[List, entity_helper.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, @@ -597,7 +600,10 @@ def search_requests_with_expr( round_decimal: int = -1, **kwargs, ) -> milvus_types.SearchRequest: - if isinstance(data[0], bytes): + if entity_helper.entity_is_sparse_matrix(data): + is_binary = False + pl_type = PlaceholderType.SparseFloatVector + elif isinstance(data[0], bytes): is_binary = True pl_type = PlaceholderType.BinaryVector else: @@ -647,7 +653,7 @@ def dump(v: Dict): return ujson.dumps(v) return str(v) - nq = len(data) + nq = entity_helper.get_input_num_rows(data) tag = "$0" pl = cls._prepare_placeholders(data, nq, tag, pl_type, is_binary) plg = common_types.PlaceholderGroup() diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 6a7b1f0b1..3434e3ddd 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -89,6 +89,7 @@ class DataType(IntEnum): FLOAT_VECTOR = 101 FLOAT16_VECTOR = 102 BFLOAT16_VECTOR = 103 + SPARSE_FLOAT_VECTOR = 104 UNKNOWN = 999 @@ -158,6 +159,7 @@ class PlaceholderType(IntEnum): FloatVector = 101 FLOAT16_VECTOR = 102 BFLOAT16_VECTOR = 103 + SparseFloatVector = 104 class State(IntEnum): diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index a114cfebe..d0b029890 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -195,6 +195,8 @@ def len_of(field_data: Any) -> int: message=f"Invalid bfloat16 or float16 vector length: total_len={total_len}, dim={dim}" ) return int(total_len / (dim * data_wide_in_bytes)) + if field_data.vectors.HasField("sparse_float_vector"): + return len(field_data.vectors.sparse_float_vector.contents) total_len = len(field_data.vectors.binary_vector) return int(total_len / (dim / 8)) @@ -239,7 +241,7 @@ def traverse_rows_info(fields_info: Any, entities: List): value = entity.get(field_name, None) if value is None: raise ParamError(message=f"Field {field_name} don't match in entities[{j}]") - + # no special check for sparse float vector field if field_type in [ DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR, diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index 9f122801a..e20b137e0 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -151,7 +151,7 @@ class ExceptionsMessage: ) AliasType = "Alias should be string, but %r is given." ConnLackConf = "You need to pass in the configuration of the connection named %r ." - ConnectFirst = "should create connect first." + ConnectFirst = "should create connection first." CollectionNotExistNoSchema = "Collection %r not exist, or you can pass in schema to create one." NoSchema = "Should be passed into the schema." EmptySchema = "The field of the schema cannot be empty." diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index dccc0b6d5..2e2459721 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -16,6 +16,7 @@ import pandas as pd +from pymilvus.client import entity_helper from pymilvus.client.abstract import BaseRanker, SearchResult from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL from pymilvus.client.types import ( @@ -28,6 +29,7 @@ from pymilvus.exceptions import ( AutoIDException, DataTypeNotMatchException, + DataTypeNotSupportException, ExceptionsMessage, IndexNotExistException, PartitionAlreadyExistException, @@ -168,6 +170,7 @@ def __repr__(self) -> str: def _get_connection(self): return connections._fetch_handler(self._using) + # TODO(SPARSE): support pd.SparseDtype @classmethod def construct_from_dataframe(cls, name: str, dataframe: pd.DataFrame, **kwargs): if not isinstance(dataframe, pd.DataFrame): @@ -449,7 +452,7 @@ def release(self, timeout: Optional[float] = None, **kwargs): def insert( self, - data: Union[List, pd.DataFrame, Dict], + data: Union[List, pd.DataFrame, Dict, entity_helper.SparseMatrixInputType], partition_name: Optional[str] = None, timeout: Optional[float] = None, **kwargs, @@ -457,7 +460,7 @@ def insert( """Insert data into the collection. Args: - data (``list/tuple/pandas.DataFrame``): The specified data to insert + data (``list/tuple/pandas.DataFrame/sparse types``): The specified data to insert partition_name (``str``): The partition name which the data will be inserted to, if partition name is not passed, then the data will be inserted to default partition @@ -573,7 +576,7 @@ def delete( def upsert( self, - data: Union[List, pd.DataFrame, Dict], + data: Union[List, pd.DataFrame, Dict, entity_helper.SparseMatrixInputType], partition_name: Optional[str] = None, timeout: Optional[float] = None, **kwargs, @@ -581,7 +584,7 @@ def upsert( """Upsert data into the collection. Args: - data (``list/tuple/pandas.DataFrame``): The specified data to upsert + data (``list/tuple/pandas.DataFrame/sparse types``): The specified data to upsert partition_name (``str``): The partition name which the data will be upserted at, if partition name is not passed, then the data will be upserted in default partition @@ -645,7 +648,7 @@ def upsert( def search( self, - data: List, + data: Union[List, entity_helper.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, @@ -659,7 +662,7 @@ def search( """Conducts a vector similarity search with an optional boolean expression as filter. Args: - data (``List[List[float]]``): The vectors of search data. + data (``List[List[float]]/sparse types``): The vectors of search data. the length of data is number of query (nq), and the dim of every vector in data must be equal to the vector field of collection. anns_field (``str``): The name of the vector field used to search of collection. @@ -780,7 +783,8 @@ def search( if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) - if isinstance(data, list) and len(data) == 0: + empty_scipy_sparse = entity_helper.sparse_is_scipy_format(data) and (data.shape[0] == 0) + if (isinstance(data, list) and len(data) == 0) or empty_scipy_sparse: resp = SearchResult(schema_pb2.SearchResultData()) return SearchFuture(None) if kwargs.get("_async", False) else resp @@ -946,7 +950,7 @@ def hybrid_search( def search_iterator( self, - data: List, + data: Union[List, entity_helper.SparseMatrixInputType], anns_field: str, param: Dict, batch_size: Optional[int] = 1000, @@ -958,6 +962,10 @@ def search_iterator( round_decimal: int = -1, **kwargs, ): + if entity_helper.entity_is_sparse_matrix(data): + # search iterator is based on range_search, which is not yet supported for sparse. + raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport) + if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) return SearchIterator( diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index ac6b7ca8e..ea25bf5f5 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -1,7 +1,8 @@ import logging from copy import deepcopy -from typing import Any, Dict, List, Optional, TypeVar +from typing import Any, Dict, List, Optional, TypeVar, Union +from pymilvus.client import entity_helper from pymilvus.client.abstract import Hits, LoopBase from pymilvus.exceptions import ( MilvusException, @@ -282,7 +283,7 @@ def __init__( self, connection: Connections, collection_name: str, - data: List, + data: Union[List, entity_helper.SparseMatrixInputType], ann_field: str, param: Dict, batch_size: Optional[int] = 1000, @@ -295,11 +296,12 @@ def __init__( schema: Optional[CollectionSchema] = None, **kwargs, ) -> SearchIterator: - if len(data) > 1: + rows = entity_helper.get_input_num_rows(data) + if rows > 1: raise ParamError( message="Not support search iteration over multiple vectors at present" ) - if len(data) == 0: + if rows == 0: raise ParamError(message="vector_data for search cannot be empty") self._conn = connection self._iterator_params = { diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index a54a3b466..cdcc6d6e2 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -15,6 +15,7 @@ import pandas as pd import ujson +from pymilvus.client import entity_helper from pymilvus.client.abstract import BaseRanker, SearchResult from pymilvus.client.types import Replica from pymilvus.exceptions import MilvusException @@ -238,14 +239,14 @@ def release(self, timeout: Optional[float] = None, **kwargs): def insert( self, - data: Union[List, pd.DataFrame], + data: Union[List, pd.DataFrame, entity_helper.SparseMatrixInputType], timeout: Optional[float] = None, **kwargs, ) -> MutationResult: """Insert data into the partition, the same as Collection.insert(data, [partition]) Args: - data (``list/tuple/pandas.DataFrame``): The specified data to insert + data (``list/tuple/pandas.DataFrame/sparse types``): The specified data to insert partition_name (``str``): The partition name which the data will be inserted to, if partition name is not passed, then the data will be inserted to default partition timeout (``float``, optional): A duration of time in seconds to allow for the RPC @@ -316,14 +317,14 @@ def delete(self, expr: str, timeout: Optional[float] = None, **kwargs): def upsert( self, - data: Union[List, pd.DataFrame], + data: Union[List, pd.DataFrame, entity_helper.SparseMatrixInputType], timeout: Optional[float] = None, **kwargs, ) -> MutationResult: """Upsert data into the collection. Args: - data (``list/tuple/pandas.DataFrame``): The specified data to upsert + data (``list/tuple/pandas.DataFrame/sparse types``): The specified data to upsert partition_name (``str``): The partition name which the data will be upserted at, if partition name is not passed, then the data will be upserted in default partition timeout (``float``, optional): A duration of time in seconds to allow for the RPC. @@ -356,7 +357,7 @@ def upsert( def search( self, - data: List, + data: Union[List, entity_helper.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, @@ -369,7 +370,7 @@ def search( """Conducts a vector similarity search with an optional boolean expression as filter. Args: - data (``List[List[float]]``): The vectors of search data. + data (``List[List[float]]`` or sparse types): The vectors of search data. the length of data is number of query (nq), and the dim of every vector in data must be equal to the vector field of collection. anns_field (``str``): The name of the vector field used to search of collection. diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py index d2823a10b..0d5d320c2 100644 --- a/pymilvus/orm/prepare.py +++ b/pymilvus/orm/prepare.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd +from pymilvus.client import entity_helper from pymilvus.exceptions import ( DataNotMatchException, DataTypeNotSupportException, @@ -46,6 +47,7 @@ def prepare_insert_data( and not data[schema.primary_field.name].isnull().all() ): raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData) + # TODO(SPARSE): support pd.SparseDtype for sparse float vector field for field in fields: if field.is_primary and field.auto_id: continue @@ -79,7 +81,7 @@ def prepare_insert_data( @classmethod def prepare_upsert_data( cls, - data: Union[List, Tuple, pd.DataFrame], + data: Union[List, Tuple, pd.DataFrame, entity_helper.SparseMatrixInputType], schema: CollectionSchema, ) -> List: if schema.auto_id: diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index bbfcacedd..5939eae62 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -295,6 +295,7 @@ def _parse_type_params(self): DataType.BFLOAT16_VECTOR, DataType.VARCHAR, DataType.ARRAY, + DataType.SPARSE_FLOAT_VECTOR, ): return if not self._kwargs: @@ -485,6 +486,7 @@ def construct_fields_from_dataframe(df: pd.DataFrame) -> List[FieldSchema]: def prepare_fields_from_dataframe(df: pd.DataFrame): + # TODO: infer pd.SparseDtype as DataType.SPARSE_FLOAT_VECTOR d_types = list(df.dtypes) data_types = list(map(map_numpy_dtype_to_datatype, d_types)) col_names = list(df.columns) @@ -532,6 +534,7 @@ def check_schema(schema: CollectionSchema): DataType.BINARY_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, + DataType.SPARSE_FLOAT_VECTOR, ): vector_fields.append(field.name) if len(vector_fields) < 1: diff --git a/pymilvus/orm/types.py b/pymilvus/orm/types.py index 1bb4eb715..eebbb50bc 100644 --- a/pymilvus/orm/types.py +++ b/pymilvus/orm/types.py @@ -72,7 +72,7 @@ def is_numeric_datatype(data_type: DataType): # pylint: disable=too-many-return-statements -def infer_dtype_by_scaladata(data: Any): +def infer_dtype_by_scalar_data(data: Any): if isinstance(data, float): return DataType.DOUBLE if isinstance(data, bool): @@ -108,7 +108,7 @@ def infer_dtype_by_scaladata(data: Any): def infer_dtype_bydata(data: Any): d_type = DataType.UNKNOWN if is_scalar(data): - return infer_dtype_by_scaladata(data) + return infer_dtype_by_scalar_data(data) if isinstance(data, dict): return DataType.JSON @@ -130,7 +130,7 @@ def infer_dtype_bydata(data: Any): elem = None if elem is not None and is_scalar(elem): - d_type = infer_dtype_by_scaladata(elem) + d_type = infer_dtype_by_scalar_data(elem) if d_type == DataType.UNKNOWN: _dtype = getattr(data, "dtype", None) diff --git a/requirements.txt b/requirements.txt index a910ad867..1cecf2ea6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ ujson>=2.0.0 urllib3==1.26.18 sklearn==0.0 m2r==0.3.1 +scipy>=1.9.3 Sphinx==4.0.0 sphinx-copybutton sphinx-rtd-theme