Skip to content

Commit

Permalink
normalize vector data to a standard form during insertion (#27469)
Browse files Browse the repository at this point in the history
Signed-off-by: NamCaoHai <[email protected]>
  • Loading branch information
CaoHaiNam committed Nov 8, 2024
1 parent 113aeb1 commit 3015862
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 23 deletions.
16 changes: 16 additions & 0 deletions examples/example_normalization_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong", normalization_fields=['embddings'])
except BaseException as e:
print(e)

print(fmt.format("Insert data, without conversion to standard form"))

hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong")

print(fmt.format("Start inserting a row"))
rng = np.random.default_rng(seed=19530)

row = {
"pk": "19530",
"random": 0.5,
"embeddings1": rng.random((1, dim), np.float32)[0],
"embeddings2": rng.random((1, dim), np.float32)[0]
}
hello_milvus.insert(row)
utility.drop_collection("hello_milvus")

print(fmt.format("Insert data, with conversion to standard form"))

Expand Down
25 changes: 10 additions & 15 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import ujson

from pymilvus.exceptions import MilvusException, ParamError
from pymilvus.grpc_gen.common_pb2 import Status

from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK
from .types import DataType
import numpy as np

MILVUS = "milvus"
ZILLIZ = "zilliz"
Expand Down Expand Up @@ -378,23 +378,18 @@ def is_scipy_sparse(cls, data: Any):
]


def convert_to_standard_form(vector_data: Any) -> Any:
    """Standardize vector data to zero mean and unit standard deviation.

    A 1-D input is treated as a single vector and standardized as a whole.
    Any other input is standardized row by row along axis 1.  A vector (or
    row) whose standard deviation is zero is returned unchanged, because it
    cannot be scaled.

    Args:
        vector_data: A NumPy array, or any array-like (e.g. a list of
            lists) that ``np.asarray`` can convert.

    Returns:
        The standardized data as a NumPy array.  When the input is already a
        1-D array with zero standard deviation, that same array is returned.
    """
    # Accept plain Python sequences too; an ndarray passes through uncopied.
    vectors = np.asarray(vector_data)

    if vectors.ndim == 1:
        # Calculate the mean and standard deviation of the vector
        mean = np.mean(vectors)
        std_dev = np.std(vectors)

        # Standardize the vector
        return (vectors - mean) / std_dev if std_dev != 0 else vectors

    # Calculate mean and standard deviation for each row
    row_means = np.mean(vectors, axis=1, keepdims=True)
    row_stds = np.std(vectors, axis=1, keepdims=True)

    # np.where evaluates both branches eagerly, so dividing by the raw
    # deviations would compute 0/0 (and warn) for constant rows. Replace
    # zero deviations with 1 for the division; those rows are discarded by
    # the outer np.where anyway.
    has_spread = row_stds != 0
    safe_stds = np.where(has_spread, row_stds, 1)

    # Standardize each row independently
    return np.where(has_spread, (vectors - row_means) / safe_stds, vectors)
2 changes: 1 addition & 1 deletion pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,4 @@ class ExceptionsMessage:
"Default value cannot be None for a field that is defined as nullable == false."
)
InvalidVectorFields = "%s is not a valid vector field; expected %s"
InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s."
InvalidNormalizationParam = "Unexpected normalization_fields parameters. Expected 'all' or a list of fields (e.g., [field1, field2, ...]), but got %s."
15 changes: 8 additions & 7 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
DataTypeNotSupportException,
ExceptionsMessage,
IndexNotExistException,
MilvusException,
PartitionAlreadyExistException,
SchemaNotReadyException,
MilvusException,
)
from pymilvus.grpc_gen import schema_pb2
from pymilvus.settings import Config
Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(
self._using = using
self._kwargs = kwargs
self._num_shards = None
self._normalization_fields = None
conn = self._get_connection()

has = conn.has_collection(self._name, **kwargs)
Expand Down Expand Up @@ -157,7 +158,7 @@ def __init__(
self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level

self._normalization_fields = kwargs.get("normalization_fields", None)
self._normalization_fields = kwargs.get("normalization_fields")
if self._normalization_fields:
self._vector_fields = self._get_vector_fields()
if self._normalization_fields == "all":
Expand Down Expand Up @@ -540,10 +541,10 @@ def insert(
schema=self._schema_dict,
**kwargs,
)

for idx, fld in enumerate(self._schema_dict["fields"]):
if fld["name"] in self._normalization_fields:
data[idx] = utils.convert_to_standard_form(data[idx])
if self._normalization_fields:
for idx, fld in enumerate(self._schema_dict["fields"]):
if fld["name"] in self._normalization_fields:
data[idx] = utils.convert_to_standard_form(data[idx])
check_insert_schema(self.schema, data)
entities = Prepare.prepare_data(data, self.schema)
return conn.batch_insert(
Expand Down Expand Up @@ -1622,4 +1623,4 @@ def get_replicas(self, timeout: Optional[float] = None, **kwargs) -> Replica:

def describe(self, timeout: Optional[float] = None):
    """Describe this collection via the underlying connection.

    Args:
        timeout: Optional timeout forwarded unchanged to the connection's
            ``describe_collection`` call.

    Returns:
        Whatever ``describe_collection`` returns for ``self.name``.
    """
    conn = self._get_connection()
    return conn.describe_collection(self.name, timeout=timeout)

0 comments on commit 3015862

Please sign in to comment.