diff --git a/tests/common/common_func.py b/tests/common/common_func.py
index d90f9ac9..159dae5b 100644
--- a/tests/common/common_func.py
+++ b/tests/common/common_func.py
@@ -119,6 +119,21 @@ def gen_float_vec_field(name=ct.default_float_vec_field_name, is_primary=False,
                                                                    is_primary=is_primary)
     return float_vec_field
 
+def gen_float16_vec_field(name=ct.default_float_vec_field_name, is_primary=False, dim=ct.default_dim,
+                          description=ct.default_desc):
+    """Build a FLOAT16_VECTOR field schema from the given name, dim, description and primary flag."""
+    float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.FLOAT16_VECTOR,
+                                                                   description=description, dim=dim,
+                                                                   is_primary=is_primary)
+    return float_vec_field
+
+def gen_brain_float16_vec_field(name=ct.default_float_vec_field_name, is_primary=False, dim=ct.default_dim,
+                                description=ct.default_desc):
+    """Build a BFLOAT16_VECTOR field schema from the given name, dim, description and primary flag."""
+    float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.BFLOAT16_VECTOR,
+                                                                   description=description, dim=dim,
+                                                                   is_primary=is_primary)
+    return float_vec_field
 
 def gen_binary_vec_field(name=ct.default_binary_vec_field_name, is_primary=False, dim=ct.default_dim,
                          description=ct.default_desc):
diff --git a/tests/requirements.txt b/tests/requirements.txt
index cfb3ea85..027aaa15 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -44,4 +44,5 @@ minio==7.1.5
 # for benchmark
 h5py==3.1.0
 pytest-benchmark==4.0.0
-
+# for brain float16 datatype
+jax==0.4.23
\ No newline at end of file
diff --git a/tests/testcases/test_restore_backup.py b/tests/testcases/test_restore_backup.py
index a5c0c703..5e53b69d 100644
--- a/tests/testcases/test_restore_backup.py
+++ b/tests/testcases/test_restore_backup.py
@@ -2,6 +2,8 @@
 import pytest
 import json
 import numpy as np
+import jax.numpy as jnp
+import random
 from collections import defaultdict
 from pymilvus import db, list_collections, Collection, DataType
 from base.client_base import TestcaseBase
@@ -481,6 +483,90 @@ def test_milvus_restore_back_with_multi_vector_datatype(self, include_dynamic, i
         assert back_up_name not in all_backup
 
 
+    @pytest.mark.parametrize("include_partition_key", [True, False])
+    @pytest.mark.parametrize("include_dynamic", [True, False])
+    @pytest.mark.tags(CaseLabel.MASTER)
+    def test_milvus_restore_back_with_f16_bf16_datatype(self, include_dynamic, include_partition_key):
+        """Back up and restore a collection with float16 and bfloat16 vector fields, then compare both collections."""
+        self._connect()
+        name_origin = cf.gen_unique_str(prefix)
+        back_up_name = cf.gen_unique_str(backup_prefix)
+        fields = [cf.gen_int64_field(name="int64", is_primary=True),
+                  cf.gen_int64_field(name="key"),
+                  cf.gen_json_field(name="json"),
+                  cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
+                  cf.gen_array_field(name="int_array", element_type=DataType.INT64),
+                  cf.gen_float_vec_field(name="float_vector", dim=128),
+                  cf.gen_float16_vec_field(name="float16_vector", dim=128),
+                  cf.gen_brain_float16_vec_field(name="brain_float16_vector", dim=128),
+                  ]
+        if include_partition_key:
+            partition_key = "key"
+            default_schema = cf.gen_collection_schema(fields,
+                                                      enable_dynamic_field=include_dynamic,
+                                                      partition_key_field=partition_key)
+        else:
+            default_schema = cf.gen_collection_schema(fields,
+                                                      enable_dynamic_field=include_dynamic)
+
+        collection_w = self.init_collection_wrap(name=name_origin, schema=default_schema, active_trace=True)
+        nb = 3000
+        data = [
+            [i for i in range(nb)],
+            [i % 3 for i in range(nb)],
+            [{f"key_{str(i)}": i} for i in range(nb)],
+            [[str(x) for x in range(10)] for i in range(nb)],
+            [[int(x) for x in range(10)] for i in range(nb)],
+            [[np.float32(i) for i in range(128)] for _ in range(nb)],
+            [[np.float16(i) for i in range(128)] for _ in range(nb)],
+            [bytes(np.array(jnp.array([random.random() for _ in range(128)], dtype=jnp.bfloat16)).view(np.uint8).tolist()) for _ in range(nb)]
+        ]
+        collection_w.insert(data=data)
+        if include_dynamic:
+            data = [
+                {
+                    "int64": i,
+                    "key": i % 3,
+                    "json": {f"key_{str(i)}": i},
+                    "var_array": [str(x) for x in range(10)],
+                    "int_array": [int(x) for x in range(10)],
+                    "float_vector": [np.float32(i) for i in range(128)],
+                    "float16_vector": [np.float16(i) for i in range(128)],
+                    "brain_float16_vector": bytes(np.array(jnp.array([random.random() for _ in range(128)], dtype=jnp.bfloat16)).view(np.uint8).tolist()),
+                    f"dynamic_{str(i)}": i
+                } for i in range(nb, nb*2)
+            ]
+            collection_w.insert(data=data)
+        res = client.create_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin]})
+        log.info(f"create_backup {res}")
+        res = client.list_backup()
+        log.info(f"list_backup {res}")
+        if "data" in res:
+            all_backup = [r["name"] for r in res["data"]]
+        else:
+            all_backup = []
+        assert back_up_name in all_backup
+        backup = client.get_backup(back_up_name)
+        assert backup["data"]["name"] == back_up_name
+        backup_collections = [b["collection_name"] for b in backup["data"]["collection_backups"]]
+        assert name_origin in backup_collections
+        res = client.restore_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin],
+                                     "collection_suffix": suffix})
+        log.info(f"restore_backup: {res}")
+        res, _ = self.utility_wrap.list_collections()
+        assert name_origin + suffix in res
+        output_fields = None
+        self.compare_collections(name_origin, name_origin + suffix, output_fields=output_fields)
+        res = client.delete_backup(back_up_name)
+        res = client.list_backup()
+        if "data" in res:
+            all_backup = [r["name"] for r in res["data"]]
+        else:
+            all_backup = []
+        assert back_up_name not in all_backup
+
+
+
     @pytest.mark.tags(CaseLabel.L1)
     def test_milvus_restore_back_with_delete(self):
         self._connect()