From 0f74c2d0dda93e8e88a26205b8f6888001ed563b Mon Sep 17 00:00:00 2001 From: Xu Tong <80471801+Writer-X@users.noreply.github.com> Date: Wed, 8 Nov 2023 17:54:10 +0800 Subject: [PATCH] Add float16 Vector Approve (#1772) Signed-off-by: Writer-X <1256866856@qq.com> --- examples/float16_example.py | 70 ++++++++++++++++++++++++++++++++ pymilvus/client/entity_helper.py | 3 ++ pymilvus/client/types.py | 2 + pymilvus/orm/schema.py | 11 ++++- tests/test_abstract.py | 6 +++ tests/test_collection.py | 3 +- tests/test_create_collection.py | 42 +++++++++++++++++++ tests/test_schema.py | 17 ++++++++ tests/test_types.py | 4 +- 9 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 examples/float16_example.py diff --git a/examples/float16_example.py b/examples/float16_example.py new file mode 100644 index 000000000..3c593e3d3 --- /dev/null +++ b/examples/float16_example.py @@ -0,0 +1,70 @@ +import time +import random +import numpy as np +from pymilvus import ( + connections, + utility, + FieldSchema, CollectionSchema, DataType, + Collection, + ) +from pymilvus import MilvusClient + +fp16_index_types = ["FLAT"] + +default_fp16_index_params = [{"nlist": 128}] + +def gen_fp16_vectors(num, dim): + raw_vectors = [] + fp16_vectors = [] + for _ in range(num): + raw_vector = [random.random() for _ in range(dim)] + raw_vectors.append(raw_vector) + fp16_vector = np.array(raw_vector, dtype=np.float16).tobytes() + fp16_vectors.append(fp16_vector) + return raw_vectors, fp16_vectors + +def fp16_vector_search(): + connections.connect() + + int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True) + dim = 128 + nb = 3000 + vector_field_name = "float16_vector" + fp16_vector = FieldSchema(name=vector_field_name, dtype=DataType.FLOAT16_VECTOR, dim=dim) + schema = CollectionSchema(fields=[int64_field, fp16_vector]) + + has = utility.has_collection("hello_milvus") + if has: + hello_milvus = Collection("hello_milvus_fp16") + hello_milvus.drop() + else: + hello_milvus = Collection("hello_milvus_fp16", schema) + + _, vectors = gen_fp16_vectors(nb, dim) + rows = [ + {vector_field_name: vectors[0]}, + {vector_field_name: vectors[1]}, + {vector_field_name: vectors[2]}, + {vector_field_name: vectors[3]}, + {vector_field_name: vectors[4]}, + {vector_field_name: vectors[5]}, + ] + + hello_milvus.insert(rows) + hello_milvus.flush() + + for i, index_type in enumerate(fp16_index_types): + index_params = default_fp16_index_params[i] + hello_milvus.create_index(vector_field_name, + index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"}) + hello_milvus.load() + print("index_type = ", index_type) + res = hello_milvus.search(vectors[:1], vector_field_name, {"metric_type": "L2"}, limit=1) + print(res) + hello_milvus.release() + hello_milvus.drop_index() + + hello_milvus.drop() + +if __name__ == "__main__": + fp16_vector_search() \ No newline at end of file diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index f0b1fa50d..61f567a10 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -122,6 +122,9 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info elif field_type == DataType.BINARY_VECTOR: field_data.vectors.dim = len(field_value) * 8 field_data.vectors.binary_vector += bytes(field_value) + elif field_type == DataType.FLOAT16_VECTOR: + field_data.vectors.dim = len(field_value) // 2 + field_data.vectors.float16_vector += bytes(field_value) elif field_type == DataType.VARCHAR: field_data.scalars.string_data.data.append( convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY) diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 0c64ba724..11bcf0c18 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -87,6 +87,7 @@ class DataType(IntEnum): BINARY_VECTOR = 100 FLOAT_VECTOR = 101 + FLOAT16_VECTOR = 102 UNKNOWN = 999 @@ -154,6 +155,7 @@ class PlaceholderType(IntEnum): NoneType = 0 BinaryVector = 100 FloatVector = 101 + Float16Vector = 102 class State(IntEnum): diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index a4a295b91..52f808575 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -288,6 +288,7 @@ def _parse_type_params(self): if self._dtype not in ( DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, + DataType.FLOAT16_VECTOR, DataType.VARCHAR, DataType.ARRAY, ): @@ -493,10 +494,16 @@ def prepare_fields_from_dataframe(df: pd.DataFrame): for i, dtype in enumerate(data_types): if dtype == DataType.UNKNOWN: new_dtype = infer_dtype_bydata(values[i]) - if new_dtype in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR): + if new_dtype in ( + DataType.BINARY_VECTOR, + DataType.FLOAT_VECTOR, + DataType.FLOAT16_VECTOR, + ): vector_type_params = {} if new_dtype == DataType.BINARY_VECTOR: vector_type_params["dim"] = len(values[i]) * 8 + elif new_dtype == DataType.FLOAT16_VECTOR: + vector_type_params["dim"] = len(values[i]) / 2 else: vector_type_params["dim"] = len(values[i]) column_params_map[col_names[i]] = vector_type_params @@ -515,7 +522,7 @@ def check_schema(schema: CollectionSchema): raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema) vector_fields = [] for field in schema.fields: - if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): + if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR, DataType.FLOAT16_VECTOR): vector_fields.append(field.name) if len(vector_fields) < 1: raise SchemaNotReadyException(message=ExceptionsMessage.NoVector) diff --git a/tests/test_abstract.py b/tests/test_abstract.py index a1023d7f2..a45ddb51f 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -168,6 +168,12 @@ def test_search_result_with_fields_data(self, pk): binary_vector=os.urandom(6), ), ), + schema_pb2.FieldData(type=DataType.FLOAT16_VECTOR, field_name="float16_vector_field", field_id=115, + vectors=schema_pb2.VectorField( + dim=16, + float16_vector=os.urandom(32), + ), + ), ] result = schema_pb2.SearchResultData( fields_data=fields_data, diff --git a/tests/test_collection.py b/tests/test_collection.py index d0fca7bcd..8f298e86a 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -27,7 +27,8 @@ def test_collection_by_DataFrame(self): fields = [ FieldSchema("int64", DataType.INT64), FieldSchema("float", DataType.FLOAT), - FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128) + FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128), + FieldSchema("float16_vector", DataType.FLOAT16_VECTOR, dim=128) ] prefix = "pymilvus.client.grpc_handler.GrpcHandler" diff --git a/tests/test_create_collection.py b/tests/test_create_collection.py index 09ee0a880..b03d76226 100644 --- a/tests/test_create_collection.py +++ b/tests/test_create_collection.py @@ -118,3 +118,45 @@ def test_create_collection(self, collection_name): return_value = future.result() assert return_value.code == 0 assert return_value.reason == "success" + + def test_create_fp16_collection(self, collection_name): + id_field = { + "name": "my_id", + "type": DataType.INT64, + "auto_id": True, + "is_primary": True, + } + vector_field = { + "name": "embedding", + "type": DataType.FLOAT16_VECTOR, + "metric_type": "L2", + "params": {"dim": "4"}, + } + fields = {"fields": [id_field, vector_field], "enable_dynamic_field": True} + future = self._milvus.create_collection( + collection_name=collection_name, fields=fields, _async=True + ) + + invocation_metadata, request, rpc = self._real_time_channel.take_unary_unary( + self._servicer.methods_by_name["CreateCollection"] + ) + rpc.send_initial_metadata(()) + rpc.terminate( + common_pb2.Status( + code=ErrorCode.SUCCESS, error_code=common_pb2.Success, reason="success" + ), + (), + grpc.StatusCode.OK, + "", + ) + + request_schema = schema_pb2.CollectionSchema() + request_schema.ParseFromString(request.schema) + + assert request.collection_name == collection_name + assert Fields.equal(request_schema.fields, fields["fields"]) + assert request_schema.enable_dynamic_field == fields["enable_dynamic_field"] + + return_value = future.result() + assert return_value.code == 0 + assert return_value.reason == "success" diff --git a/tests/test_schema.py b/tests/test_schema.py index aabcff668..4f808b833 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -74,6 +74,15 @@ def raw_dict_binary_vector(self): _dict["params"] = {"dim": 128} return _dict + @pytest.fixture(scope="function") + def raw_dict_float16_vector(self): + _dict = dict() + _dict["name"] = "TestFieldSchema_name_float16_vector" + _dict["description"] = "TestFieldSchema_description_float16_vector" + _dict["type"] = DataType.FLOAT16_VECTOR + _dict["params"] = {"dim": 128} + return _dict + @pytest.fixture(scope="function") def raw_dict_norm(self): _dict = dict() @@ -93,6 +102,14 @@ def dataframe1(self): df1 = pandas.DataFrame(data) return df1 + def test_constructor_from_float16_dict(self, raw_dict_float16_vector): + field = FieldSchema.construct_from_dict(raw_dict_float16_vector) + assert field.dtype == DataType.FLOAT16_VECTOR + assert field.description == raw_dict_float16_vector['description'] + assert field.is_primary is False + assert field.name == raw_dict_float16_vector['name'] + assert field.dim == raw_dict_float16_vector['params']['dim'] + def test_constructor_from_float_dict(self, raw_dict_float_vector): field = FieldSchema.construct_from_dict(raw_dict_float_vector) assert field.dtype == DataType.FLOAT_VECTOR diff --git a/tests/test_types.py b/tests/test_types.py index 51f389597..e97f81b2a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -83,7 +83,8 @@ def test_infer_dtype_bydata(self): "abc", np.int8(1), np.int16(1), - [np.int8(1)] + [np.int8(1)], + [np.float16(1.0)] ] wants = [ @@ -98,6 +99,7 @@ def test_infer_dtype_bydata(self): DataType.INT8, DataType.INT16, DataType.FLOAT_VECTOR, + DataType.FLOAT16_VECTOR, ] actual = []