Skip to content

Commit

Permalink
Add basic sparse float vector support to pymilvus. for now only scipy…
Browse files Browse the repository at this point in the history
….sparse.csr_matrix as input is supported

Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed Jan 29, 2024
1 parent d53344c commit 81baaaa
Show file tree
Hide file tree
Showing 21 changed files with 1,062 additions and 803 deletions.
153 changes: 153 additions & 0 deletions examples/hello_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# hello_sparse.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus,
# while operating on sparse float vectors.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from scipy.sparse import rand
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)

# Banner template used for section headings; the text is padded to 30 columns.
fmt = "=== {:30} ==="
# Template for reporting how long a request took, in seconds.
search_latency_fmt = "search latency = {:.4f}s"
# Shape of the generated demo data: number of rows, number of columns of the
# sparse matrix, and the fraction of entries that are non-zero.
num_entities = 10000
dim = 60000
density = 0.00005

def log(msg):
    """Print ``msg`` to stdout, prefixed with the current local date and time."""
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(timestamp + " " + msg)

# -----------------------------------------------------------------------------
# connect to Milvus
log(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

has = utility.has_collection("hello_sparse")
log(f"Does collection hello_sparse exist in Milvus: {has}")

# -----------------------------------------------------------------------------
# create collection with a sparse float vector column
if has:
    # The collection is already there: just open it by name.
    hello_sparse = Collection("hello_sparse")
else:
    # Three columns: an auto-generated string primary key, a scalar used for
    # filtering, and the sparse float vector column itself.
    pk_field = FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100)
    random_field = FieldSchema(name="random", dtype=DataType.DOUBLE)
    embeddings_field = FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR)
    schema = CollectionSchema(
        [pk_field, random_field, embeddings_field],
        "hello_sparse is the simplest demo to introduce sparse float vector usage",
    )
    log(fmt.format("Create collection `hello_sparse`"))
    hello_sparse = Collection("hello_sparse", schema, consistency_level="Strong")

# -----------------------------------------------------------------------------
# create index
# Only build the index when the collection does not have one yet.
if not hello_sparse.has_index():
    log(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
    index = dict(
        index_type="SPARSE_INVERTED_INDEX",
        metric_type="IP",
        params=dict(drop_ratio_build=0.2),
    )
    hello_sparse.create_index("embeddings", index)

log(fmt.format("Start loading"))
hello_sparse.load()

log(f"hello_sparse has {hello_sparse.num_entities} entities({hello_sparse.num_entities/1000000}M), indexed {hello_sparse.has_index()}")

# -----------------------------------------------------------------------------
# insert
log(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)
# Draw the scalar column first so its values are the same as before this fix.
random_column = rng.random(num_entities).tolist()
# Pass the seeded generator to scipy as well; without `random_state` the sparse
# matrix came from an unseeded source and the demo data changed on every run.
matrix_csr = rand(num_entities, dim, density=density, format='csr', random_state=rng)
# Column-ordered data for the non-auto-id fields: `random`, then `embeddings`.
entities = [
    random_column,
    matrix_csr,
]

insert_result = hello_sparse.insert(entities)

# -----------------------------------------------------------------------------
# search based on vector similarity
log(fmt.format("Start searching based on vector similarity"))
# The `[-1:]` slice takes the last row of the inserted CSR matrix as the only
# query vector.
vectors_to_search = entities[-1][-1:]
search_params = {
    "metric_type": "IP",
    "params": {
        # A number, not the string "0.2": same type as `drop_ratio_build` in
        # the index parameters.
        "drop_ratio_search": 0.2,
    }
}

start_time = time.time()
result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "random", "embeddings"])
end_time = time.time()

# One list of hits per query vector.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")
log(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# query based on scalar filtering(boolean, int, etc.)
print(fmt.format("Start querying with `random > 0.5`"))

start_time = time.time()
result = hello_sparse.query(
    expr="random > 0.5",
    output_fields=["random", "embeddings"],
)
end_time = time.time()

# Only the first matching row is shown.
first_row = result[0]
print(f"query result:\n-{first_row}")
elapsed = end_time - start_time
print(search_latency_fmt.format(elapsed))

# -----------------------------------------------------------------------------
# pagination
# Same filter twice: the first four rows, then three rows after skipping one.
r1 = hello_sparse.query(expr="random > 0.5", limit=4, output_fields=["random"])
r2 = hello_sparse.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"])
for page_report in (
    f"query pagination(limit=4):\n\t{r1}",
    f"query pagination(offset=1, limit=3):\n\t{r2}",
):
    print(page_report)


# -----------------------------------------------------------------------------
# hybrid search
print(fmt.format("Start hybrid searching with `random > 0.5`"))

# Vector search restricted by a scalar filter on the `random` column.
start_time = time.time()
result = hello_sparse.search(
    vectors_to_search,
    "embeddings",
    search_params,
    limit=3,
    expr="random > 0.5",
    output_fields=["random"],
)
end_time = time.time()

# Walk every hit of every query, in order.
for hit in (h for hits in result for h in hits):
    print(f"hit: {hit}, random field: {hit.entity.get('random')}")
elapsed = end_time - start_time
print(search_latency_fmt.format(elapsed))

# -----------------------------------------------------------------------------
# delete entities by PK
# You can delete entities by their PK values using boolean expressions.
ids = insert_result.primary_keys

# Target the first two inserted primary keys.
first_pk, second_pk = ids[0], ids[1]
expr = f'pk in ["{first_pk}" , "{second_pk}"]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

# Show both rows, delete them, then run the same query again.
result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"])
row_a, row_b = result[0], result[1]
print(f"query before delete by expr=`{expr}` -> result: \n-{row_a}\n-{row_b}\n")

hello_sparse.delete(expr)

result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


# -----------------------------------------------------------------------------
# drop collection
print(fmt.format("Drop collection `hello_sparse`"))
utility.drop_collection("hello_sparse")
9 changes: 7 additions & 2 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
from .types import DataType
from . import entity_helper


class FieldSchema:
Expand Down Expand Up @@ -467,6 +468,10 @@ def get_fields_by_range(
field_meta,
)
continue
if dtype == DataType.SPARSE_FLOAT_VECTOR:
field2data[name] = entity_helper.sparse_float_array_to_rows(
vectors.sparse_float_vector, start, end), field_meta
continue

return field2data

Expand Down Expand Up @@ -509,7 +514,7 @@ def __init__(
for fname, (data, field_meta) in fields.items():
if len(data) <= i:
curr_field[fname] = None
# Get vectors
# Get dense vectors
if field_meta.type in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
dim = field_meta.vectors.dim
dim = dim // 8 if field_meta.type == DataType.BINARY_VECTOR else dim
Expand All @@ -527,7 +532,7 @@ def __init__(
curr_field.update(data[i])
continue

# normal fields
# sparse float vector and normal fields
curr_field[fname] = data[i]

hits.append(Hit(pks[i], distances[i], curr_field))
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/client/check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import sys
from typing import Any, Callable, Union
import scipy

from pymilvus.exceptions import ParamError
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
Expand Down Expand Up @@ -195,6 +196,9 @@ def is_legal_anns_field(field: Any) -> bool:

def is_legal_search_data(data: Any) -> bool:
import numpy as np
# TODO(SPARSE): support other format of sparse vector representation
if isinstance(data, scipy.sparse.csr_matrix):
return True

if not isinstance(data, (list, np.ndarray)):
return False
Expand Down
27 changes: 27 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
import ujson
import struct
from scipy.sparse import csr_matrix

from pymilvus.exceptions import MilvusException, ParamError
from pymilvus.grpc_gen import schema_pb2 as schema_types
Expand All @@ -11,6 +13,29 @@

CHECK_STR_ARRAY = True

# TODO(SPARSE): support other format of sparse vector representation
def csr_to_sparse_float_array(csr):
    """Convert a float ``scipy.sparse.csr_matrix`` into a ``SparseFloatArray`` proto.

    Each matrix row becomes one entry of ``contents`` holding that row's
    column indices and values; ``dim`` is set to the number of columns.

    Raises:
        ValueError: if ``csr`` is not a ``csr_matrix`` or its dtype does not
            compare equal to the builtin ``float`` (so e.g. float32 matrices
            are rejected).
    """
    if not (isinstance(csr, csr_matrix) and csr.dtype == float):
        raise ValueError("Matrix must be a float CSR matrix")

    sparse_array = schema_types.SparseFloatArray()
    sparse_array.dim = csr.shape[1]

    # Consecutive indptr entries delimit one row inside `indices` / `data`.
    row_bounds = csr.indptr
    for lo, hi in zip(row_bounds, row_bounds[1:]):
        proto_row = sparse_array.contents.add()
        proto_row.indices.data.extend(csr.indices[lo:hi])
        proto_row.values.data.extend(csr.data[lo:hi])
    return sparse_array

# convert sparse proto to List[Dict[Int, Float]], each element(dict) is a sparse
# vector
def sparse_float_array_to_rows(sfv, start=None, end=None):
    """Convert a ``SparseFloatArray`` proto into a list of ``{index: value}`` dicts.

    Args:
        sfv: the ``SparseFloatArray`` message to read.
        start: index of the first row to convert; ``None`` means 0.
        end: index one past the last row to convert; ``None`` means all
            remaining rows.

    Returns:
        One dict per row in ``sfv.contents[start:end]``, mapping each entry of
        the row's indices to the matching entry of its values.

    Raises:
        ValueError: if ``sfv`` is not a ``SparseFloatArray``.
    """
    if not isinstance(sfv, schema_types.SparseFloatArray):
        raise ValueError("Vector must be a sparse float vector")
    # Compare against None explicitly: `end or len(...)` treated an explicit
    # end=0 (an empty range) as "not given" and returned every row.
    if start is None:
        start = 0
    if end is None:
        end = len(sfv.contents)
    return [
        dict(zip(row.indices.data, row.values.data))
        for row in sfv.contents[start:end]
    ]

def entity_type_to_dtype(entity_type: Any):
if isinstance(entity_type, int):
Expand Down Expand Up @@ -179,6 +204,8 @@ def entity_to_field_data(entity: Any, field_info: Any):
field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
elif entity_type == DataType.ARRAY:
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
elif entity_type == DataType.SPARSE_FLOAT_VECTOR:
field_data.vectors.sparse_float_vector.CopyFrom(csr_to_sparse_float_array(entity.get("values")))
else:
raise ParamError(message=f"UnSupported data type: {entity_type}")

Expand Down
1 change: 1 addition & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ def create_index(
DataType.BINARY_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.SPARSE_FLOAT_VECTOR,
}:
break

Expand Down
23 changes: 20 additions & 3 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pymilvus.grpc_gen import schema_pb2 as schema_types
from pymilvus.orm.schema import CollectionSchema

from scipy.sparse import csr_matrix

from . import blob, entity_helper, ts_utils
from .check import check_pass_param, is_legal_collection_properties
from .constants import (
Expand Down Expand Up @@ -452,9 +454,13 @@ def _pre_batch_check(
fields_info: Any,
):
for entity in entities:
values = entity.get("values", None)
# csr_matrix object cannot be used as Boolean
if isinstance(values, csr_matrix):
values = True
if (
not entity.get("name", None)
or not entity.get("values", None)
or not values
or not entity.get("type", None)
):
raise ParamError(
Expand Down Expand Up @@ -488,7 +494,8 @@ def _parse_batch_request(
pre_field_size = 0
try:
for entity in entities:
latest_field_size = len(entity.get("values"))
values = entity.get("values")
latest_field_size = values.shape[0] if isinstance(values, csr_matrix) else len(entity.get("values"))
if pre_field_size not in (0, latest_field_size):
raise ParamError(
message=(
Expand Down Expand Up @@ -570,6 +577,12 @@ def check_str(instr: str, prefix: str):
def _prepare_placeholders(cls, vectors: Any, nq: int, tag: Any, pl_type: Any, is_binary: bool):
pl = common_types.PlaceholderValue(tag=tag)
pl.type = pl_type
# sparse vector
if pl_type == PlaceholderType.SparseFloatVector:
pl.values.append(entity_helper.csr_to_sparse_float_array(vectors).SerializeToString())
return pl

# dense or binary vector
for i in range(nq):
if is_binary:
pl.values.append(blob.vector_binary_to_bytes(vectors[i]))
Expand All @@ -594,6 +607,10 @@ def search_requests_with_expr(
if isinstance(data[0], bytes):
is_binary = True
pl_type = PlaceholderType.BinaryVector
# TODO(SPARSE): support other format of sparse vector representation
elif isinstance(data, csr_matrix):
is_binary = False
pl_type = PlaceholderType.SparseFloatVector
else:
is_binary = False
pl_type = PlaceholderType.FloatVector
Expand Down Expand Up @@ -637,7 +654,7 @@ def dump(v: Dict):
return ujson.dumps(v)
return str(v)

nq = len(data)
nq = data.shape[0] if isinstance(data, csr_matrix) else len(data)
tag = "$0"
pl = cls._prepare_placeholders(data, nq, tag, pl_type, is_binary)
plg = common_types.PlaceholderGroup()
Expand Down
2 changes: 2 additions & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class DataType(IntEnum):
FLOAT_VECTOR = 101
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103
SPARSE_FLOAT_VECTOR = 104

UNKNOWN = 999

Expand Down Expand Up @@ -158,6 +159,7 @@ class PlaceholderType(IntEnum):
FloatVector = 101
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103
SparseFloatVector = 104


class State(IntEnum):
Expand Down
2 changes: 2 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def len_of(field_data: Any) -> int:
message=f"Invalid vector length: total_len={total_len}, dim={dim}"
)
return int(total_len / dim)
if field_data.vectors.HasField("sparse_float_vector"):
return len(field_data.vectors.sparse_float_vector.contents)

total_len = len(field_data.vectors.binary_vector)
return int(total_len / (dim / 8))
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class ExceptionsMessage:
)
AliasType = "Alias should be string, but %r is given."
ConnLackConf = "You need to pass in the configuration of the connection named %r ."
ConnectFirst = "should create connect first."
ConnectFirst = "should create connection first."
CollectionNotExistNoSchema = "Collection %r not exist, or you can pass in schema to create one."
NoSchema = "Should be passed into the schema."
EmptySchema = "The field of the schema cannot be empty."
Expand Down
Loading

0 comments on commit 81baaaa

Please sign in to comment.