Skip to content

Commit

Permalink
normalize vector data to a standard form during insertion (#27469)
Browse files Browse the repository at this point in the history
  • Loading branch information
CaoHaiNam committed Nov 4, 2024
1 parent 3ee9e10 commit f77c27a
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
112 changes: 112 additions & 0 deletions examples/example_normalization_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import time

import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
MilvusClient
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 3000, 8


print(fmt.format("start connecting to Milvus"))
# this is milvus standalone
connection = connections.connect(
alias="default",
host='localhost', # or '0.0.0.0' or 'localhost'
port='19530'
)

client = MilvusClient(connections=connection)

has = utility.has_collection("hello_milvus")
print(f"Does collection hello_milvus exist in Milvus: {has}")
if has:
utility.drop_collection("hello_milvus")

fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
FieldSchema(name="random", dtype=DataType.DOUBLE),
FieldSchema(name="embeddings1", dtype=DataType.FLOAT_VECTOR, dim=dim),
FieldSchema(name="embeddings2", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs")

print(fmt.format("Create collection `hello_milvus`"))

print(fmt.format("Message for handling an invalid format in the normalization_fields value")) # you can try with other value like: dict,...
try:
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields='embeddings1')
except BaseException as e:
print(e)


print(fmt.format("Message for handling the invalid vector fields"))
try:
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embddings'])
except BaseException as e:
print(e)

print(fmt.format("Insert data, with conversion to standard form"))

hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embeddings1'])

print(fmt.format("Start inserting a row"))
rng = np.random.default_rng(seed=19530)

row = {
"pk": "19530",
"random": 0.5,
"embeddings1": rng.random((1, dim), np.float32)[0],
"embeddings2": rng.random((1, dim), np.float32)[0]
}
_row = row.copy()
hello_milvus.insert(row)

index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}}
hello_milvus.create_index("embeddings1", index_param)
hello_milvus.create_index("embeddings2", index_param)
hello_milvus.load()

original_vector = _row['embeddings1']
insert_vector = hello_milvus.query(
expr="pk == '19530'",
output_fields=["embeddings1"],
)[0]['embeddings1']

print(fmt.format("Mean and standard deviation before normalization."))
print("Mean: ", np.mean(original_vector))
print("Std: ", np.std(original_vector))

print(fmt.format("Mean and standard deviation after normalization."))
print("Mean: ", np.mean(insert_vector))
print("Std: ", np.std(insert_vector))


print(fmt.format("Start inserting entities"))

entities = [
[str(i) for i in range(num_entities)],
rng.random(num_entities).tolist(),
rng.random((num_entities, dim), np.float32),
rng.random((num_entities, dim), np.float32),
]

insert_result = hello_milvus.insert(entities)

insert_vector = hello_milvus.query(
expr="pk == '1'",
output_fields=["embeddings1"],
)[0]['embeddings1']

print(fmt.format("Mean and standard deviation after normalization."))
print("Mean: ", np.mean(insert_vector))
print("Std: ", np.std(insert_vector))

utility.drop_collection("hello_milvus")
21 changes: 21 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK
from .types import DataType
import numpy as np

MILVUS = "milvus"
ZILLIZ = "zilliz"
Expand Down Expand Up @@ -375,3 +376,23 @@ def is_scipy_sparse(cls, data: Any):
"csr_array",
"spmatrix",
]

def convert_to_standard_form(vector_data):

if len(vector_data.shape) == 1:
# Calculate the mean and standard deviation of the vector
mean = np.mean(vector_data)
std_dev = np.std(vector_data)

# Standardize the vector
standardized_vector = (vector_data - mean) / std_dev if std_dev != 0 else vector_data
return standardized_vector

else:
# Calculate mean and standard deviation for each row
row_means = np.mean(vector_data, axis=1, keepdims=True)
row_stds = np.std(vector_data, axis=1, keepdims=True)

# Standardize each row independently
standardized_matrix = np.where(row_stds != 0, (vector_data - row_means) / row_stds, vector_data)
return standardized_matrix
4 changes: 4 additions & 0 deletions pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,7 @@ class ExceptionsMessage:
DefaultValueInvalid = (
"Default value cannot be None for a field that is defined as nullable == false."
)
InvalidVectorFields = (
"%s is not a valid vector field; expected %s"
)
InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s."
25 changes: 25 additions & 0 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,25 @@ def __init__(
self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level

self._normalization_fields = kwargs.get("normalization_fields", None)
if self._normalization_fields:
self._vector_fields = self._get_vector_fields()
if self._normalization_fields == 'all':
self._normalization_fields = self._vector_fields
elif isinstance(self._normalization_fields, list):
for field in self._normalization_fields:
if field not in self._vector_fields:
raise BaseException(ExceptionsMessage.InvalidVectorFields % (field, ', '.join(self._vector_fields)))
else:
raise BaseException(ExceptionsMessage.InvalidNormalizationParam % (self._normalization_fields))

def _get_vector_fields(self):
vector_fields = []
for field in self._schema_dict.get("fields", []):
if field.get("params", {}).get("dim", None):
vector_fields.append(field.get("name"))
return vector_fields

def __repr__(self) -> str:
_dict = {
"name": self.name,
Expand Down Expand Up @@ -504,6 +523,9 @@ def insert(

conn = self._get_connection()
if is_row_based(data):
if self._normalization_fields:
for norm_fld in self._normalization_fields:
data[norm_fld] = utils.convert_to_standard_form(data[norm_fld])
return conn.insert_rows(
collection_name=self._name,
entities=data,
Expand All @@ -513,6 +535,9 @@ def insert(
**kwargs,
)

for idx, fld in enumerate(self._schema_dict['fields']):
if fld['name'] in self._normalization_fields:
data[idx] = utils.convert_to_standard_form(data[idx])
check_insert_schema(self.schema, data)
entities = Prepare.prepare_data(data, self.schema)
return conn.batch_insert(
Expand Down

0 comments on commit f77c27a

Please sign in to comment.