Skip to content

Commit

Permalink
Add (b)float16 vector approve
Browse files Browse the repository at this point in the history
Signed-off-by: Writer-X <[email protected]>
  • Loading branch information
Writer-X committed Jan 11, 2024
1 parent f393302 commit c21fade
Show file tree
Hide file tree
Showing 12 changed files with 319 additions and 5 deletions.
71 changes: 71 additions & 0 deletions examples/bfloat16_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import time
import random
import numpy as np
from bfloat16 import bfloat16
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
from pymilvus import MilvusClient

bf16_index_types = ["FLAT"]

default_bf16_index_params = [{"nlist": 128}]

def gen_bf16_vectors(num, dim):
raw_vectors = []
bf16_vectors = []
for _ in range(num):
raw_vector = [random.random() for _ in range(dim)]
raw_vectors.append(raw_vector)
bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
bf16_vectors.append(bytes(bf16_vector))
return raw_vectors, bf16_vectors

def bf16_vector_search():
connections.connect()

int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True)
dim = 128
nb = 3000
vector_field_name = "bfloat16_vector"
bf16_vector = FieldSchema(name=vector_field_name, dtype=DataType.BFLOAT16_VECTOR, dim=dim)
schema = CollectionSchema(fields=[int64_field, bf16_vector])

has = utility.has_collection("hello_milvus")
if has:
hello_milvus = Collection("hello_milvus_fp16")
hello_milvus.drop()
else:
hello_milvus = Collection("hello_milvus_fp16", schema)

_, vectors = gen_bf16_vectors(nb, dim)
rows = [
{vector_field_name: vectors[0]},
{vector_field_name: vectors[1]},
{vector_field_name: vectors[2]},
{vector_field_name: vectors[3]},
{vector_field_name: vectors[4]},
{vector_field_name: vectors[5]},
]

hello_milvus.insert(rows)
hello_milvus.flush()

for i, index_type in enumerate(bf16_index_types):
index_params = default_bf16_index_params[i]
hello_milvus.create_index(vector_field_name,
index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"})
hello_milvus.load()
print("index_type = ", index_type)
res = hello_milvus.search(vectors[0:10], vector_field_name, {"metric_type": "L2"}, limit=1)
print(res)
hello_milvus.release()
hello_milvus.drop_index()

hello_milvus.drop()

if __name__ == "__main__":
bf16_vector_search()
70 changes: 70 additions & 0 deletions examples/float16_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import time
import random
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
from pymilvus import MilvusClient

fp16_index_types = ["FLAT"]

default_fp16_index_params = [{"nlist": 128}]

def gen_fp16_vectors(num, dim):
raw_vectors = []
fp16_vectors = []
for _ in range(num):
raw_vector = [random.random() for _ in range(dim)]
raw_vectors.append(raw_vector)
fp16_vector = np.array(raw_vector, dtype=np.float16).view(np.uint8).tolist()
fp16_vectors.append(bytes(fp16_vector))
return raw_vectors, fp16_vectors

def fp16_vector_search():
connections.connect()

int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True)
dim = 128
nb = 3000
vector_field_name = "float16_vector"
fp16_vector = FieldSchema(name=vector_field_name, dtype=DataType.FLOAT16_VECTOR, dim=dim)
schema = CollectionSchema(fields=[int64_field, fp16_vector])

has = utility.has_collection("hello_milvus")
if has:
hello_milvus = Collection("hello_milvus_fp16")
hello_milvus.drop()
else:
hello_milvus = Collection("hello_milvus_fp16", schema)

_, vectors = gen_fp16_vectors(nb, dim)
rows = [
{vector_field_name: vectors[0]},
{vector_field_name: vectors[1]},
{vector_field_name: vectors[2]},
{vector_field_name: vectors[3]},
{vector_field_name: vectors[4]},
{vector_field_name: vectors[5]},
]

hello_milvus.insert(rows)
hello_milvus.flush()

for i, index_type in enumerate(fp16_index_types):
index_params = default_fp16_index_params[i]
hello_milvus.create_index(vector_field_name,
index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"})
hello_milvus.load()
print("index_type = ", index_type)
res = hello_milvus.search(vectors[0:10], vector_field_name, {"metric_type": "L2"}, limit=1)
print(res)
hello_milvus.release()
hello_milvus.drop_index()

hello_milvus.drop()

if __name__ == "__main__":
fp16_vector_search()
12 changes: 12 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
elif field_type == DataType.BINARY_VECTOR:
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector += bytes(field_value)
elif field_type == DataType.FLOAT16_VECTOR:
field_data.vectors.dim = len(field_value) // 2
field_data.vectors.float16_vector += bytes(field_value)
elif field_type == DataType.BFLOAT16_VECTOR:
field_data.vectors.dim = len(field_value) // 2
field_data.vectors.bfloat16_vector += bytes(field_value)
elif field_type == DataType.VARCHAR:
field_data.scalars.string_data.data.append(
convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY)
Expand Down Expand Up @@ -159,6 +165,12 @@ def entity_to_field_data(entity: Any, field_info: Any):
elif entity_type == DataType.BINARY_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0]) * 8
field_data.vectors.binary_vector = b"".join(entity.get("values"))
elif entity_type == DataType.FLOAT16_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0]) // 2
field_data.vectors.float16_vector = b"".join(entity.get("values"))
elif entity_type == DataType.BFLOAT16_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0]) // 2
field_data.vectors.bfloat16_vector = b"".join(entity.get("values"))
elif entity_type == DataType.VARCHAR:
field_data.scalars.string_data.data.extend(
entity_to_str_arr(entity, field_info, CHECK_STR_ARRAY)
Expand Down
7 changes: 6 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,12 @@ def create_index(
if field_name != fields["name"]:
continue
valid_field = True
if fields["type"] not in {DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR}:
if fields["type"] not in {
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
}:
break

if not valid_field:
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class DataType(IntEnum):

BINARY_VECTOR = 100
FLOAT_VECTOR = 101
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103

UNKNOWN = 999

Expand Down Expand Up @@ -154,6 +156,8 @@ class PlaceholderType(IntEnum):
NoneType = 0
BinaryVector = 100
FloatVector = 101
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103


class State(IntEnum):
Expand Down
18 changes: 16 additions & 2 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def _parse_type_params(self):
if self._dtype not in (
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.VARCHAR,
DataType.ARRAY,
):
Expand Down Expand Up @@ -493,10 +495,17 @@ def prepare_fields_from_dataframe(df: pd.DataFrame):
for i, dtype in enumerate(data_types):
if dtype == DataType.UNKNOWN:
new_dtype = infer_dtype_bydata(values[i])
if new_dtype in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR):
if new_dtype in (
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
):
vector_type_params = {}
if new_dtype == DataType.BINARY_VECTOR:
vector_type_params["dim"] = len(values[i]) * 8
elif new_dtype in (DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR):
vector_type_params["dim"] = len(values[i]) / 2
else:
vector_type_params["dim"] = len(values[i])
column_params_map[col_names[i]] = vector_type_params
Expand All @@ -515,7 +524,12 @@ def check_schema(schema: CollectionSchema):
raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema)
vector_fields = []
for field in schema.fields:
if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
if field.dtype in (
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
):
vector_fields.append(field.name)
if len(vector_fields) < 1:
raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ grpcio-testing
sklearn==0.0
ruff
black
git + https://github.com/GreenWaves-Technologies/bfloat16.git
12 changes: 12 additions & 0 deletions tests/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ def test_search_result_with_fields_data(self, pk):
binary_vector=os.urandom(6),
),
),
schema_pb2.FieldData(type=DataType.FLOAT16_VECTOR, field_name="float16_vector_field", field_id=115,
vectors=schema_pb2.VectorField(
dim=16,
float16_vector=os.urandom(32),
),
),
schema_pb2.FieldData(type=DataType.BFLOAT16_VECTOR, field_name="bfloat16_vector_field", field_id=116,
vectors=schema_pb2.VectorField(
dim=16,
bfloat16_vector=os.urandom(32),
),
),
]
result = schema_pb2.SearchResultData(
fields_data=fields_data,
Expand Down
4 changes: 3 additions & 1 deletion tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def test_collection_by_DataFrame(self):
fields = [
FieldSchema("int64", DataType.INT64),
FieldSchema("float", DataType.FLOAT),
FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128)
FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128),
FieldSchema("float16_vector", DataType.FLOAT16_VECTOR, dim=128),
FieldSchema("bfloat16_vector", DataType.BFLOAT16_VECTOR, dim=128)
]

prefix = "pymilvus.client.grpc_handler.GrpcHandler"
Expand Down
84 changes: 84 additions & 0 deletions tests/test_create_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,87 @@ def test_create_collection(self, collection_name):
return_value = future.result()
assert return_value.code == 0
assert return_value.reason == "success"

def test_create_fp16_collection(self, collection_name):
id_field = {
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
}
vector_field = {
"name": "embedding",
"type": DataType.FLOAT16_VECTOR,
"metric_type": "L2",
"params": {"dim": "4"},
}
fields = {"fields": [id_field, vector_field], "enable_dynamic_field": True}
future = self._milvus.create_collection(
collection_name=collection_name, fields=fields, _async=True
)

invocation_metadata, request, rpc = self._real_time_channel.take_unary_unary(
self._servicer.methods_by_name["CreateCollection"]
)
rpc.send_initial_metadata(())
rpc.terminate(
common_pb2.Status(
code=ErrorCode.SUCCESS, error_code=common_pb2.Success, reason="success"
),
(),
grpc.StatusCode.OK,
"",
)

request_schema = schema_pb2.CollectionSchema()
request_schema.ParseFromString(request.schema)

assert request.collection_name == collection_name
assert Fields.equal(request_schema.fields, fields["fields"])
assert request_schema.enable_dynamic_field == fields["enable_dynamic_field"]

return_value = future.result()
assert return_value.code == 0
assert return_value.reason == "success"

def test_create_bf16_collection(self, collection_name):
id_field = {
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
}
vector_field = {
"name": "embedding",
"type": DataType.BFLOAT16_VECTOR,
"metric_type": "L2",
"params": {"dim": "4"},
}
fields = {"fields": [id_field, vector_field], "enable_dynamic_field": True}
future = self._milvus.create_collection(
collection_name=collection_name, fields=fields, _async=True
)

invocation_metadata, request, rpc = self._real_time_channel.take_unary_unary(
self._servicer.methods_by_name["CreateCollection"]
)
rpc.send_initial_metadata(())
rpc.terminate(
common_pb2.Status(
code=ErrorCode.SUCCESS, error_code=common_pb2.Success, reason="success"
),
(),
grpc.StatusCode.OK,
"",
)

request_schema = schema_pb2.CollectionSchema()
request_schema.ParseFromString(request.schema)

assert request.collection_name == collection_name
assert Fields.equal(request_schema.fields, fields["fields"])
assert request_schema.enable_dynamic_field == fields["enable_dynamic_field"]

return_value = future.result()
assert return_value.code == 0
assert return_value.reason == "success"
Loading

0 comments on commit c21fade

Please sign in to comment.