Skip to content

Commit

Permalink
fix: Unify output of str(DataType) for different pythons (#2635) (#2638)
Browse files Browse the repository at this point in the history
See also: #2633
pr: #2635

Signed-off-by: yangxuan <[email protected]>
  • Loading branch information
XuanYang-cn authored Feb 17, 2025
1 parent 2139c65 commit 5edb054
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
49 changes: 28 additions & 21 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ExceptionsMessage,
InvalidConsistencyLevel,
)
from pymilvus.grpc_gen import common_pb2, rg_pb2
from pymilvus.grpc_gen import common_pb2, rg_pb2, schema_pb2
from pymilvus.grpc_gen import milvus_pb2 as milvus_types

Status = TypeVar("Status")
Expand Down Expand Up @@ -84,29 +84,36 @@ def OK(self):


class DataType(IntEnum):
NONE = 0
BOOL = 1
INT8 = 2
INT16 = 3
INT32 = 4
INT64 = 5

FLOAT = 10
DOUBLE = 11

STRING = 20
VARCHAR = 21
ARRAY = 22
JSON = 23

BINARY_VECTOR = 100
FLOAT_VECTOR = 101
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103
SPARSE_FLOAT_VECTOR = 104
"""
String of DataType is str of its value, e.g.: str(DataType.BOOL) == "1"
"""

NONE = 0 # schema_pb2.None, this is an invalid representation in python
BOOL = schema_pb2.Bool
INT8 = schema_pb2.Int8
INT16 = schema_pb2.Int16
INT32 = schema_pb2.Int32
INT64 = schema_pb2.Int64

FLOAT = schema_pb2.Float
DOUBLE = schema_pb2.Double

STRING = schema_pb2.String
VARCHAR = schema_pb2.VarChar
ARRAY = schema_pb2.Array
JSON = schema_pb2.JSON

BINARY_VECTOR = schema_pb2.BinaryVector
FLOAT_VECTOR = schema_pb2.FloatVector
FLOAT16_VECTOR = schema_pb2.Float16Vector
BFLOAT16_VECTOR = schema_pb2.BFloat16Vector
SPARSE_FLOAT_VECTOR = schema_pb2.SparseFloatVector

UNKNOWN = 999

def __str__(self) -> str:
return str(self.value)


class RangeType(IntEnum):
LT = 0 # less than
Expand Down
45 changes: 28 additions & 17 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,37 @@
# from ml_dtypes import bfloat16


@pytest.mark.skip("please fix me")
class TestTypes:
@pytest.mark.parametrize("input_expect", [
([1], DataType.FLOAT_VECTOR),
([True], DataType.UNKNOWN),
([1.0, 2.0], DataType.FLOAT_VECTOR),
(["abc"], DataType.UNKNOWN),
(bytes("abc", encoding='ascii'), DataType.BINARY_VECTOR),
(1, DataType.INT64),
(True, DataType.BOOL),
("abc", DataType.VARCHAR),
(np.int8(1), DataType.INT8),
(np.int16(1), DataType.INT16),
([np.int8(1)], DataType.FLOAT_VECTOR),
([np.float16(1.0)], DataType.FLOAT16_VECTOR),
# ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR),
])
@pytest.mark.skip("please fix me")
@pytest.mark.parametrize(
"input_expect",
[
([1], DataType.FLOAT_VECTOR),
([True], DataType.UNKNOWN),
([1.0, 2.0], DataType.FLOAT_VECTOR),
(["abc"], DataType.UNKNOWN),
(bytes("abc", encoding="ascii"), DataType.BINARY_VECTOR),
(1, DataType.INT64),
(True, DataType.BOOL),
("abc", DataType.VARCHAR),
(np.int8(1), DataType.INT8),
(np.int16(1), DataType.INT16),
([np.int8(1)], DataType.FLOAT_VECTOR),
([np.float16(1.0)], DataType.FLOAT16_VECTOR),
# ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR),
],
)
def test_infer_dtype_bydata(self, input_expect):
data, expect = input_expect
got = infer_dtype_bydata(data)
assert got == expect

def test_str_of_data_type(self):
for v in DataType:
assert isinstance(v, DataType)
assert str(v) == str(v.value)
assert str(v) != v.name


class TestConsistencyLevel:
def test_consistency_level_int(self):
Expand Down Expand Up @@ -91,6 +100,8 @@ def test_shard(self):
def test_shard_dup_nodeIDs(self):
s = Shard("channel-1", (1, 1, 1), 1)
assert s.channel_name == "channel-1"
assert s.shard_nodes == {1,}
assert s.shard_nodes == {
1,
}
assert s.shard_leader == 1
print(s)

0 comments on commit 5edb054

Please sign in to comment.