feat: Add sparse float vector support to PyMilvus #1902

Merged (1 commit) on Mar 14, 2024
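
This PR adds DataType.SPARSE_FLOAT_VECTOR to PyMilvus. In brief, a client can
insert and search dict-style sparse rows; a minimal sketch distilled from the
examples in this PR (names and parameters follow those examples):

    from pymilvus import MilvusClient

    client = MilvusClient("http://localhost:19530")
    # each sparse vector is a dict of {dimension_index: float_value}
    client.insert("hello_sparse", [{"random": 0.5, "embeddings": {12: 0.5, 487: 0.33}}])
    client.search("hello_sparse", [{12: 0.5, 487: 0.33}], limit=3,
                  search_params={"metric_type": "IP", "params": {"drop_ratio_search": 0.2}})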
154 changes: 154 additions & 0 deletions examples/hello_sparse.py
@@ -0,0 +1,154 @@
# hello_sparse.py demonstrates the basic operations of PyMilvus, the Python SDK for Milvus,
# while operating on sparse float vectors.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
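#
# A note on input formats (based on this PR's two examples, not an exhaustive
# list): a sparse float vector can be supplied as a scipy.sparse matrix with
# one row per entity (as in this file), or as one dict per row mapping
# dimension index to float value, e.g. {12: 0.5, 487: 0.33, 2999: 0.8}
# (as in examples/milvus_client/sparse.py).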
import time

import numpy as np
from scipy.sparse import rand
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection,
)

fmt = "=== {:30} ==="
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim, density = 1000, 3000, 0.005

def log(msg):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)

# -----------------------------------------------------------------------------
# connect to Milvus
log(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

has = utility.has_collection("hello_sparse")
log(f"Does collection hello_sparse exist in Milvus: {has}")

# -----------------------------------------------------------------------------
# create collection with a sparse float vector column
hello_sparse = None
if not has:
    fields = [
        FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
        FieldSchema(name="random", dtype=DataType.DOUBLE),
        FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR),
    ]
    schema = CollectionSchema(fields, "hello_sparse is the simplest demo to introduce sparse float vector usage")
    log(fmt.format("Create collection `hello_sparse`"))
    hello_sparse = Collection("hello_sparse", schema, consistency_level="Strong")
else:
    hello_sparse = Collection("hello_sparse")

log(f"hello_sparse has {hello_sparse.num_entities} entities({hello_sparse.num_entities/1000000}M), indexed {hello_sparse.has_index()}")
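# (accessing Collection.num_entities may trigger a flush in PyMilvus, so the
# count above reflects what has been persisted so far)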

# -----------------------------------------------------------------------------
# insert
log(fmt.format("Start creating entities to insert"))
rng = np.random.default_rng(seed=19530)
# note: generating a large random sparse matrix with scipy can take a while
matrix_csr = rand(num_entities, dim, density=density, format='csr')
entities = [
    rng.random(num_entities).tolist(),
    matrix_csr,
]

log(fmt.format("Start inserting entities"))
insert_result = hello_sparse.insert(entities)

# -----------------------------------------------------------------------------
# create index
if not hello_sparse.has_index():
    log(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
    index = {
        "index_type": "SPARSE_INVERTED_INDEX",
        "metric_type": "IP",
        "params": {
            "drop_ratio_build": 0.2,
        },
    }
    hello_sparse.create_index("embeddings", index)
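# Note: `drop_ratio_build` above asks the index to ignore that fraction of the
# smallest vector values at build time, trading a little recall for a smaller,
# faster index; this PR's new is_legal_drop_ratio check requires a float in [0, 1).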

log(fmt.format("Start loading"))
hello_sparse.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
log(fmt.format("Start searching based on vector similarity"))
vectors_to_search = entities[-1][-1:]
search_params = {
    "metric_type": "IP",
    "params": {
        "drop_ratio_search": 0.2,
    },
}
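# `drop_ratio_search` plays the same role at query time, dropping that fraction
# of the smallest values in the query vector; it is validated by the same
# is_legal_drop_ratio check added in this PR.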

start_time = time.time()
result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "random", "embeddings"])
end_time = time.time()

for hits in result:
    for hit in hits:
        print(f"hit: {hit}")
log(search_latency_fmt.format(end_time - start_time))
# -----------------------------------------------------------------------------
# query based on scalar filtering (boolean, int, etc.)
print(fmt.format("Start querying with `random > 0.5`"))

start_time = time.time()
result = hello_sparse.query(expr="random > 0.5", output_fields=["random", "embeddings"])
end_time = time.time()

print(f"query result:\n-{result[0]}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# pagination
r1 = hello_sparse.query(expr="random > 0.5", limit=4, output_fields=["random"])
r2 = hello_sparse.query(expr="random > 0.5", offset=1, limit=3, output_fields=["random"])
print(f"query pagination(limit=4):\n\t{r1}")
print(f"query pagination(offset=1, limit=3):\n\t{r2}")


# -----------------------------------------------------------------------------
# hybrid search
print(fmt.format("Start hybrid searching with `random > 0.5`"))

start_time = time.time()
result = hello_sparse.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5", output_fields=["random"])
end_time = time.time()

for hits in result:
    for hit in hits:
        print(f"hit: {hit}, random field: {hit.entity.get('random')}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# delete entities by PK
# You can delete entities by their PK values using boolean expressions.
ids = insert_result.primary_keys

expr = f'pk in ["{ids[0]}" , "{ids[1]}"]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n")

hello_sparse.delete(expr)

result = hello_sparse.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


# -----------------------------------------------------------------------------
# drop collection
print(fmt.format("Drop collection `hello_sparse`"))
utility.drop_collection("hello_sparse")
88 changes: 88 additions & 0 deletions examples/milvus_client/sparse.py
@@ -0,0 +1,88 @@
from pymilvus import (
    MilvusClient,
    FieldSchema, CollectionSchema, DataType,
)

import random

def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict:
    indices = random.sample(range(dimension), non_zero_count)
    values = [random.random() for _ in range(non_zero_count)]
    sparse_vector = {index: value for index, value in zip(indices, values)}
    return sparse_vector
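# e.g. generate_sparse_vector(100, 3) might return {12: 0.91, 47: 0.05, 83: 0.66};
# one {dimension_index: value} dict per row is the sparse format this example
# feeds to MilvusClient.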


fmt = "\n=== {:30} ===\n"
dim = 100
non_zero_count = 20
collection_name = "hello_sparse"
milvus_client = MilvusClient("http://localhost:19530")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
    milvus_client.drop_collection(collection_name)
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR,
                is_primary=True, auto_id=True, max_length=100),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(
    fields, "demo for using sparse float vector with milvus client")
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="embeddings", index_name="sparse_inverted_index",
                       index_type="SPARSE_INVERTED_INDEX", metric_type="IP",
                       params={"drop_ratio_build": 0.2})
milvus_client.create_collection(collection_name, schema=schema,
                                index_params=index_params, timeout=5,
                                consistency_level="Strong")
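# Note: MilvusClient uses prepare_index_params()/add_index() rather than the
# dict-style index spec used in examples/hello_sparse.py; both create the same
# SPARSE_INVERTED_INDEX with the IP metric.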

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

N = 6
rows = [{"random": i, "embeddings": generate_sparse_vector(
    dim, non_zero_count)} for i in range(N)]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
print(fmt.format("Inserting entities done"))
print(insert_result)

print(fmt.format("Start vector ANN search"))
vectors_to_search = [generate_sparse_vector(dim, non_zero_count)]
search_params = {
    "metric_type": "IP",
    "params": {
        "drop_ratio_search": 0.2,
    },
}
# no need to specify anns_field for collections with only 1 vector field
result = milvus_client.search(collection_name, vectors_to_search, limit=3,
                              output_fields=["pk", "random", "embeddings"],
                              search_params=search_params)
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

print(fmt.format("Start query by specifying filtering expression"))
query_results = milvus_client.query(collection_name, filter="random < 3")
pks = [ret['pk'] for ret in query_results]
for ret in query_results:
    print(ret)

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(
    collection_name, filter=f"pk == '{pks[0]}'")
print(query_results[0])

print(f"start to delete by specifying filter in collection {collection_name}")
delete_result = milvus_client.delete(collection_name, ids=pks[:1])
print(delete_result)

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(
    collection_name, filter=f"pk == '{pks[0]}'")
print(f'query result should be empty: {query_results}')

milvus_client.drop_collection(collection_name)
18 changes: 11 additions & 7 deletions pymilvus/client/abstract.py
@@ -7,6 +7,7 @@
 from pymilvus.grpc_gen import schema_pb2
 from pymilvus.settings import Config
 
+from . import entity_helper
 from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
 from .types import DataType
 
@@ -15,7 +16,6 @@ class FieldSchema:
     def __init__(self, raw: Any):
         self._raw = raw
 
-        #
         self.field_id = 0
         self.name = None
         self.is_primary = False
@@ -29,7 +29,6 @@ def __init__(self, raw: Any):
         # For array field
         self.element_type = None
         self.is_clustering_key = False
-        ##
         self.__pack(self._raw)
 
     def __pack(self, raw: Any):
@@ -106,7 +105,6 @@ class CollectionSchema:
     def __init__(self, raw: Any):
         self._raw = raw
 
-        #
         self.collection_name = None
         self.description = None
         self.params = {}
@@ -121,7 +119,6 @@ def __init__(self, raw: Any):
         self.num_partitions = 0
         self.enable_dynamic_field = False
 
-        #
         if self._raw:
             self.__pack(self._raw)
 
@@ -330,7 +327,7 @@ def dict(self):
 class AnnSearchRequest:
     def __init__(
         self,
-        data: List,
+        data: Union[List, entity_helper.SparseMatrixInputType],
         anns_field: str,
         param: Dict,
         limit: int,
@@ -472,6 +469,13 @@ def get_fields_by_range(
                 field_meta,
             )
             continue
+        # TODO(SPARSE): do we want to allow the user to specify the return format?
+        if dtype == DataType.SPARSE_FLOAT_VECTOR:
+            field2data[name] = (
+                entity_helper.sparse_proto_to_rows(vectors.sparse_float_vector, start, end),
+                field_meta,
+            )
+            continue
 
         if dtype == DataType.BFLOAT16_VECTOR:
             field2data[name] = (
@@ -527,7 +531,7 @@ def __init__(
         for fname, (data, field_meta) in fields.items():
             if len(data) <= i:
                 curr_field[fname] = None
-            # Get vectors
+            # Get dense vectors
             if field_meta.type in (
                 DataType.FLOAT_VECTOR,
                 DataType.BINARY_VECTOR,
@@ -552,7 +556,7 @@ def __init__(
                 curr_field.update(data[i])
                 continue
 
-            # normal fields
+            # sparse float vector and other fields
             curr_field[fname] = data[i]
 
         hits.append(Hit(pks[i], distances[i], curr_field))
28 changes: 10 additions & 18 deletions pymilvus/client/check.py
@@ -5,6 +5,7 @@
 from pymilvus.exceptions import ParamError
 from pymilvus.grpc_gen import milvus_pb2 as milvus_types
 
+from . import entity_helper
 from .singleton_utils import Singleton
 
 
@@ -40,24 +41,6 @@ def is_legal_port(port: Any) -> bool:
     return False
 
 
-def is_legal_vector(array: Any) -> bool:
-    if not array or not isinstance(array, list) or len(array) == 0:
-        return False
-
-    return True
-
-
-def is_legal_bin_vector(array: Any) -> bool:
-    if not array or not isinstance(array, bytes) or len(array) == 0:
-        return False
-
-    return True
-
-
-def is_legal_numpy_array(array: Any) -> bool:
-    return not (array is None or array.size == 0)
-
-
 def int_or_str(item: Union[int, str]) -> str:
     if isinstance(item, int):
         return str(item)
@@ -149,6 +132,10 @@ def is_legal_max_iterations(max_iterations: Any) -> bool:
     return isinstance(max_iterations, int)
 
 
+def is_legal_drop_ratio(drop_ratio: Any) -> bool:
+    return isinstance(drop_ratio, float) and 0 <= drop_ratio < 1
+
+
 def is_legal_team_size(team_size: Any) -> bool:
     return isinstance(team_size, int)
 
@@ -197,6 +184,9 @@ def is_legal_anns_field(field: Any) -> bool:
 def is_legal_search_data(data: Any) -> bool:
     import numpy as np
 
+    if entity_helper.entity_is_sparse_matrix(data):
+        return True
+
     if not isinstance(data, (list, np.ndarray)):
         return False
 
@@ -331,6 +321,8 @@ def __init__(self) -> None:
             "team_size": is_legal_team_size,
             "index_name": is_legal_index_name,
             "timeout": is_legal_timeout,
+            "drop_ratio_build": is_legal_drop_ratio,
+            "drop_ratio_search": is_legal_drop_ratio,
         }
 
     def check(self, key: str, value: Callable):
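
A quick sanity check of the new validator's behavior (illustrative only, not
part of the PR):

    from pymilvus.client.check import is_legal_drop_ratio

    assert is_legal_drop_ratio(0.2) is True
    assert is_legal_drop_ratio(1.0) is False   # must be strictly below 1
    assert is_legal_drop_ratio("0.2") is False # only floats are accepted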