From 5edb054d2127f29779b35bfd6a6d4515da41c511 Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Mon, 17 Feb 2025 14:18:10 +0800 Subject: [PATCH] fix: Unify output of str(DataType) for different pythons (#2635) (#2638) See also: #2633 pr: #2635 Signed-off-by: yangxuan --- pymilvus/client/types.py | 49 +++++++++++++++++++++++----------------- tests/test_types.py | 45 ++++++++++++++++++++++-------------- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 29e7068dc..f056af3a0 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -7,7 +7,7 @@ ExceptionsMessage, InvalidConsistencyLevel, ) -from pymilvus.grpc_gen import common_pb2, rg_pb2 +from pymilvus.grpc_gen import common_pb2, rg_pb2, schema_pb2 from pymilvus.grpc_gen import milvus_pb2 as milvus_types Status = TypeVar("Status") @@ -84,29 +84,36 @@ def OK(self): class DataType(IntEnum): - NONE = 0 - BOOL = 1 - INT8 = 2 - INT16 = 3 - INT32 = 4 - INT64 = 5 - - FLOAT = 10 - DOUBLE = 11 - - STRING = 20 - VARCHAR = 21 - ARRAY = 22 - JSON = 23 - - BINARY_VECTOR = 100 - FLOAT_VECTOR = 101 - FLOAT16_VECTOR = 102 - BFLOAT16_VECTOR = 103 - SPARSE_FLOAT_VECTOR = 104 + """ + String of DataType is str of its value, e.g.: str(DataType.BOOL) == "1" + """ + + NONE = 0 # schema_pb2.None, this is an invalid representation in python + BOOL = schema_pb2.Bool + INT8 = schema_pb2.Int8 + INT16 = schema_pb2.Int16 + INT32 = schema_pb2.Int32 + INT64 = schema_pb2.Int64 + + FLOAT = schema_pb2.Float + DOUBLE = schema_pb2.Double + + STRING = schema_pb2.String + VARCHAR = schema_pb2.VarChar + ARRAY = schema_pb2.Array + JSON = schema_pb2.JSON + + BINARY_VECTOR = schema_pb2.BinaryVector + FLOAT_VECTOR = schema_pb2.FloatVector + FLOAT16_VECTOR = schema_pb2.Float16Vector + BFLOAT16_VECTOR = schema_pb2.BFloat16Vector + SPARSE_FLOAT_VECTOR = schema_pb2.SparseFloatVector UNKNOWN = 999 + def __str__(self) -> str: + return str(self.value) + class RangeType(IntEnum): LT = 0 # less than diff --git a/tests/test_types.py b/tests/test_types.py index 1fe685fe6..edd34f40d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -31,28 +31,37 @@ # from ml_dtypes import bfloat16 -@pytest.mark.skip("please fix me") class TestTypes: - @pytest.mark.parametrize("input_expect", [ - ([1], DataType.FLOAT_VECTOR), - ([True], DataType.UNKNOWN), - ([1.0, 2.0], DataType.FLOAT_VECTOR), - (["abc"], DataType.UNKNOWN), - (bytes("abc", encoding='ascii'), DataType.BINARY_VECTOR), - (1, DataType.INT64), - (True, DataType.BOOL), - ("abc", DataType.VARCHAR), - (np.int8(1), DataType.INT8), - (np.int16(1), DataType.INT16), - ([np.int8(1)], DataType.FLOAT_VECTOR), - ([np.float16(1.0)], DataType.FLOAT16_VECTOR), - # ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR), - ]) + @pytest.mark.skip("please fix me") + @pytest.mark.parametrize( + "input_expect", + [ + ([1], DataType.FLOAT_VECTOR), + ([True], DataType.UNKNOWN), + ([1.0, 2.0], DataType.FLOAT_VECTOR), + (["abc"], DataType.UNKNOWN), + (bytes("abc", encoding="ascii"), DataType.BINARY_VECTOR), + (1, DataType.INT64), + (True, DataType.BOOL), + ("abc", DataType.VARCHAR), + (np.int8(1), DataType.INT8), + (np.int16(1), DataType.INT16), + ([np.int8(1)], DataType.FLOAT_VECTOR), + ([np.float16(1.0)], DataType.FLOAT16_VECTOR), + # ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR), + ], + ) def test_infer_dtype_bydata(self, input_expect): data, expect = input_expect got = infer_dtype_bydata(data) assert got == expect + def test_str_of_data_type(self): + for v in DataType: + assert isinstance(v, DataType) + assert str(v) == str(v.value) + assert str(v) != v.name + class TestConsistencyLevel: def test_consistency_level_int(self): @@ -91,6 +100,8 @@ def test_shard(self): def test_shard_dup_nodeIDs(self): s = Shard("channel-1", (1, 1, 1), 1) assert s.channel_name == "channel-1" - assert s.shard_nodes == {1,} + assert s.shard_nodes == { + 1, + } assert s.shard_leader == 1 print(s)