Skip to content

Commit

Permalink
Add float16 Vector Approve (#1772)
Browse files Browse the repository at this point in the history
Signed-off-by: Writer-X <[email protected]>
  • Loading branch information
Writer-X authored Nov 8, 2023
1 parent 59c6f0c commit 0f74c2d
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 4 deletions.
70 changes: 70 additions & 0 deletions examples/float16_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import time
import random
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
from pymilvus import MilvusClient

# Index types exercised by the float16 example, in the order they are tried.
fp16_index_types = ["FLAT"]

# Index build parameters; entry i is paired with fp16_index_types[i].
default_fp16_index_params = [{"nlist": 128}]

def gen_fp16_vectors(num, dim):
    """Generate ``num`` random vectors of length ``dim`` in two encodings.

    Returns a ``(raw_vectors, fp16_vectors)`` tuple. ``raw_vectors`` is a list
    of lists of Python floats drawn with ``random.random()``;
    ``fp16_vectors`` holds the same vectors converted to ``np.float16`` and
    serialized with ``tobytes()``.
    """
    # Draw every raw vector first; the sequence of random.random() calls is
    # row by row, element by element.
    raw_vectors = [[random.random() for _ in range(dim)] for _ in range(num)]
    fp16_vectors = [
        np.array(vec, dtype=np.float16).tobytes() for vec in raw_vectors
    ]
    return raw_vectors, fp16_vectors

def fp16_vector_search():
    """Run an end-to-end float16-vector example against a Milvus server.

    Connects with the default connection settings, (re)creates the
    ``hello_milvus_fp16`` collection with an auto-id INT64 primary key and a
    FLOAT16_VECTOR field, inserts a few random vectors, then for every
    configured index type builds the index, runs one top-1 L2 search, prints
    the result and drops the index again. The collection is dropped at the end.
    """
    connections.connect()

    collection_name = "hello_milvus_fp16"
    dim = 128
    nb = 3000
    vector_field_name = "float16_vector"

    int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True)
    fp16_vector = FieldSchema(name=vector_field_name, dtype=DataType.FLOAT16_VECTOR, dim=dim)
    schema = CollectionSchema(fields=[int64_field, fp16_vector])

    # Start from a clean slate: check the SAME name that is created below, and
    # after dropping a leftover collection always build a fresh one so the
    # handle used for insert/index/search refers to a live collection.
    if utility.has_collection(collection_name):
        Collection(collection_name).drop()
    hello_milvus = Collection(collection_name, schema)

    _, vectors = gen_fp16_vectors(nb, dim)
    # Only the first six vectors are inserted to keep the example small; each
    # row carries just the vector because the primary key is auto-generated.
    rows = [{vector_field_name: vec} for vec in vectors[:6]]

    hello_milvus.insert(rows)
    hello_milvus.flush()

    for index_type, index_params in zip(fp16_index_types, default_fp16_index_params):
        hello_milvus.create_index(
            vector_field_name,
            index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"},
        )
        hello_milvus.load()
        print("index_type = ", index_type)
        # Query with the first inserted vector.
        res = hello_milvus.search(vectors[:1], vector_field_name, {"metric_type": "L2"}, limit=1)
        print(res)
        # The collection must be released before its index can be dropped.
        hello_milvus.release()
        hello_milvus.drop_index()

    hello_milvus.drop()

# Run the float16 example only when this file is executed as a script.
if __name__ == "__main__":
fp16_vector_search()
3 changes: 3 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
elif field_type == DataType.BINARY_VECTOR:
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector += bytes(field_value)
elif field_type == DataType.FLOAT16_VECTOR:
field_data.vectors.dim = len(field_value) // 2
field_data.vectors.float16_vector += bytes(field_value)
elif field_type == DataType.VARCHAR:
field_data.scalars.string_data.data.append(
convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY)
Expand Down
2 changes: 2 additions & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class DataType(IntEnum):

BINARY_VECTOR = 100
FLOAT_VECTOR = 101
FLOAT16_VECTOR = 102

UNKNOWN = 999

Expand Down Expand Up @@ -154,6 +155,7 @@ class PlaceholderType(IntEnum):
NoneType = 0
BinaryVector = 100
FloatVector = 101
Float16Vector = 102


class State(IntEnum):
Expand Down
11 changes: 9 additions & 2 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _parse_type_params(self):
if self._dtype not in (
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.VARCHAR,
DataType.ARRAY,
):
Expand Down Expand Up @@ -493,10 +494,16 @@ def prepare_fields_from_dataframe(df: pd.DataFrame):
for i, dtype in enumerate(data_types):
if dtype == DataType.UNKNOWN:
new_dtype = infer_dtype_bydata(values[i])
if new_dtype in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR):
if new_dtype in (
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
):
vector_type_params = {}
if new_dtype == DataType.BINARY_VECTOR:
vector_type_params["dim"] = len(values[i]) * 8
elif new_dtype == DataType.FLOAT16_VECTOR:
vector_type_params["dim"] = len(values[i]) / 2
else:
vector_type_params["dim"] = len(values[i])
column_params_map[col_names[i]] = vector_type_params
Expand All @@ -515,7 +522,7 @@ def check_schema(schema: CollectionSchema):
raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema)
vector_fields = []
for field in schema.fields:
if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR, DataType.FLOAT16_VECTOR):
vector_fields.append(field.name)
if len(vector_fields) < 1:
raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)
6 changes: 6 additions & 0 deletions tests/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def test_search_result_with_fields_data(self, pk):
binary_vector=os.urandom(6),
),
),
schema_pb2.FieldData(type=DataType.FLOAT16_VECTOR, field_name="float16_vector_field", field_id=115,
vectors=schema_pb2.VectorField(
dim=16,
float16_vector=os.urandom(32),
),
),
]
result = schema_pb2.SearchResultData(
fields_data=fields_data,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def test_collection_by_DataFrame(self):
fields = [
FieldSchema("int64", DataType.INT64),
FieldSchema("float", DataType.FLOAT),
FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128)
FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128),
FieldSchema("float16_vector", DataType.FLOAT16_VECTOR, dim=128)
]

prefix = "pymilvus.client.grpc_handler.GrpcHandler"
Expand Down
42 changes: 42 additions & 0 deletions tests/test_create_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,45 @@ def test_create_collection(self, collection_name):
return_value = future.result()
assert return_value.code == 0
assert return_value.reason == "success"

def test_create_fp16_collection(self, collection_name):
    """Create a collection with a FLOAT16_VECTOR field and check the request.

    Issues an async ``create_collection`` call, answers the pending
    ``CreateCollection`` RPC on the test channel with a success status, then
    verifies the serialized schema in the request and the returned status.
    """
    schema_dict = {
        "fields": [
            {
                "name": "my_id",
                "type": DataType.INT64,
                "auto_id": True,
                "is_primary": True,
            },
            {
                "name": "embedding",
                "type": DataType.FLOAT16_VECTOR,
                "metric_type": "L2",
                "params": {"dim": "4"},
            },
        ],
        "enable_dynamic_field": True,
    }
    pending = self._milvus.create_collection(
        collection_name=collection_name, fields=schema_dict, _async=True
    )

    # Pull the in-flight RPC off the test channel and answer it with success.
    _metadata, request, rpc = self._real_time_channel.take_unary_unary(
        self._servicer.methods_by_name["CreateCollection"]
    )
    rpc.send_initial_metadata(())
    ok_status = common_pb2.Status(
        code=ErrorCode.SUCCESS, error_code=common_pb2.Success, reason="success"
    )
    rpc.terminate(ok_status, (), grpc.StatusCode.OK, "")

    # The schema travels as serialized bytes; decode it before comparing.
    sent_schema = schema_pb2.CollectionSchema()
    sent_schema.ParseFromString(request.schema)

    assert request.collection_name == collection_name
    assert Fields.equal(sent_schema.fields, schema_dict["fields"])
    assert sent_schema.enable_dynamic_field == schema_dict["enable_dynamic_field"]

    outcome = pending.result()
    assert outcome.code == 0
    assert outcome.reason == "success"
17 changes: 17 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def raw_dict_binary_vector(self):
_dict["params"] = {"dim": 128}
return _dict

@pytest.fixture(scope="function")
def raw_dict_float16_vector(self):
    """Raw field dict describing a FLOAT16_VECTOR field with ``dim`` 128."""
    return {
        "name": "TestFieldSchema_name_float16_vector",
        "description": "TestFieldSchema_description_float16_vector",
        "type": DataType.FLOAT16_VECTOR,
        "params": {"dim": 128},
    }

@pytest.fixture(scope="function")
def raw_dict_norm(self):
_dict = dict()
Expand All @@ -93,6 +102,14 @@ def dataframe1(self):
df1 = pandas.DataFrame(data)
return df1

def test_constructor_from_float16_dict(self, raw_dict_float16_vector):
    """A FieldSchema built from the float16 raw dict mirrors that dict."""
    expected = raw_dict_float16_vector
    schema_field = FieldSchema.construct_from_dict(expected)

    assert schema_field.dtype == DataType.FLOAT16_VECTOR
    assert schema_field.description == expected['description']
    assert schema_field.is_primary is False
    assert schema_field.name == expected['name']
    assert schema_field.dim == expected['params']['dim']

def test_constructor_from_float_dict(self, raw_dict_float_vector):
field = FieldSchema.construct_from_dict(raw_dict_float_vector)
assert field.dtype == DataType.FLOAT_VECTOR
Expand Down
4 changes: 3 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def test_infer_dtype_bydata(self):
"abc",
np.int8(1),
np.int16(1),
[np.int8(1)]
[np.int8(1)],
[np.float16(1.0)]
]

wants = [
Expand All @@ -98,6 +99,7 @@ def test_infer_dtype_bydata(self):
DataType.INT8,
DataType.INT16,
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
]

actual = []
Expand Down

0 comments on commit 0f74c2d

Please sign in to comment.