From b5d7ca6618479480402bdbfce9636c857ebeeec8 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Wed, 29 Jan 2025 18:29:59 -0500 Subject: [PATCH 1/2] Quant unittest fix (#3191) Fixes # . ### Description Remove quantization unit test, mainly due to limited testable scope and random errors caused by imported but not used packages ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --- .../app_opt/quantization/__init__.py | 13 ---- .../app_opt/quantization/quantization_test.py | 70 ------------------- 2 files changed, 83 deletions(-) delete mode 100644 tests/unit_test/app_opt/quantization/__init__.py delete mode 100644 tests/unit_test/app_opt/quantization/quantization_test.py diff --git a/tests/unit_test/app_opt/quantization/__init__.py b/tests/unit_test/app_opt/quantization/__init__.py deleted file mode 100644 index d9155f923f..0000000000 --- a/tests/unit_test/app_opt/quantization/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py deleted file mode 100644 index 5f452ca3a5..0000000000 --- a/tests/unit_test/app_opt/quantization/quantization_test.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest -import torch - -from nvflare.apis.dxo import DXO, DataKind -from nvflare.apis.fl_context import FLContext -from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor -from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor - -TEST_CASES = [ - ( - {"a": np.array([1.0, 2.0, 3.0, 70000.0], dtype="float32")}, - "float16", - {"a": np.array([1.0, 2.0, 3.0, 65504.0], dtype="float32")}, - ), - # ( - # {"a": np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")}, - # "blockwise8", - # {"a": np.array([0.99062496, 2.003125, 3.015625, 4.0], dtype="float32")}, - # ), - ( - {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)}, - "float16", - {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)}, - ), - # ( - # {"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)}, - # "blockwise8", - # {"a": torch.tensor([0.99062496, 2.003125, 3.015625, 4.0], dtype=torch.float32)}, - # ), -] - - -class TestQuantization: - @pytest.mark.parametrize("input_data, quantization_type, expected_data", TEST_CASES) - def test_quantization(self, input_data, quantization_type, expected_data): - dxo = DXO( - data_kind=DataKind.WEIGHTS, - data=input_data, - ) - fl_ctx = FLContext() - f_quant = ModelQuantizor(quantization_type=quantization_type) - quant_dxo = f_quant.process_dxo(dxo, dxo.to_shareable(), fl_ctx) - f_dequant = ModelDequantizor() - dequant_dxo = f_dequant.process_dxo(quant_dxo, dxo.to_shareable(), fl_ctx) - dequant_data = dequant_dxo.data - for key in dequant_data.keys(): - dequant_array = dequant_data[key] - expected_array = expected_data[key] - # print the values - print(f"dequant_array: {dequant_array}") - print(f"expected_array: {expected_array}") - if isinstance(dequant_array, torch.Tensor): - assert torch.allclose(dequant_array, expected_array) - else: - assert np.allclose(dequant_array, expected_array) From 288e79081e40e27eb36d464e5ae8d1fe75d0b1ae Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 29 Jan 2025 19:01:50 -0500 Subject: [PATCH 2/2] Add yaml include support (#3133) Fixes # . ### Description Added the support for yaml to include another yaml configuration file. YAML does not naturally support any kind of "import" or "include" statement to include another yaml file. Adding this to support the yaml config in the format like: (include could be single file, or a list of include files.) .... include: 1.yml or: include: [1.yml, 2.yml] The "include" can be used at any level. Also support recursively include. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Ziyue Xu --- nvflare/lighter/utils.py | 77 ++++++++++++++++++++------- tests/unit_test/lighter/0.yml | 15 ++++++ tests/unit_test/lighter/1.yml | 1 + tests/unit_test/lighter/2.yml | 1 + tests/unit_test/lighter/3.yml | 2 + tests/unit_test/lighter/utils_test.py | 20 ++++++- 6 files changed, 95 insertions(+), 21 deletions(-) create mode 100644 tests/unit_test/lighter/0.yml create mode 100644 tests/unit_test/lighter/1.yml create mode 100644 tests/unit_test/lighter/2.yml create mode 100644 tests/unit_test/lighter/3.yml diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index ae71e409fa..8cee919cca 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -40,7 +40,8 @@ def serialize_cert(cert): def load_crt(path): - return load_crt_bytes(open(path, "rb").read()) + with open(path, "rb") as f: + return load_crt_bytes(f.read()) def load_crt_bytes(data: bytes): @@ -116,17 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999): for file in files: if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE: continue - signatures[file] = sign_content( - content=open(os.path.join(root, file), "rb").read(), - signing_pri_key=signing_pri_key, - ) + with open(os.path.join(root, file), "rb") as f: + signatures[file] = sign_content( + content=f.read(), + signing_pri_key=signing_pri_key, + ) for folder in folders: signatures[folder] = sign_content( content=folder, signing_pri_key=signing_pri_key, ) - json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt")) + with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f: + json.dump(signatures, f) shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) if depth >= max_depth: break @@ -138,7 +141,8 @@ def verify_folder_signature(src_folder, root_ca_path): root_ca_public_key = root_ca_cert.public_key() for root, folders, files in os.walk(src_folder): try: - signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt")) + with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f: + signatures = json.load(f) cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) public_key = cert.public_key() except: @@ -150,11 +154,12 @@ def verify_folder_signature(src_folder, root_ca_path): continue signature = signatures.get(file) if signature: - verify_content( - content=open(os.path.join(root, file), "rb").read(), - signature=signature, - public_key=public_key, - ) + with open(os.path.join(root, file), "rb") as f: + verify_content( + content=f.read(), + signature=signature, + public_key=public_key, + ) for folder in folders: signature = signatures.get(folder) if signature: @@ -173,20 +178,52 @@ def sign_all(content_folder, signing_pri_key): for f in os.listdir(content_folder): path = os.path.join(content_folder, f) if os.path.isfile(path): - signatures[f] = sign_content( - content=open(path, "rb").read(), - signing_pri_key=signing_pri_key, - ) + with open(path, "rb") as file: + signatures[f] = sign_content( + content=file.read(), + signing_pri_key=signing_pri_key, + ) return signatures def load_yaml(file): + + root = os.path.split(file)[0] + yaml_data = None if isinstance(file, str): - return yaml.safe_load(open(file, "r")) + with open(file, "r") as f: + yaml_data = yaml.safe_load(f) elif isinstance(file, bytes): - return yaml.safe_load(file) - else: - return None + yaml_data = yaml.safe_load(file) + + yaml_data = load_yaml_include(root, yaml_data) + + return yaml_data + + +def load_yaml_include(root, yaml_data): + new_data = {} + for k, v in yaml_data.items(): + if k == "include": + if isinstance(v, str): + includes = [v] + elif isinstance(v, list): + includes = v + for item in includes: + new_data.update(load_yaml(os.path.join(root, item))) + elif isinstance(v, list): + new_list = [] + for item in v: + if isinstance(item, dict): + item = load_yaml_include(root, item) + new_list.append(item) + new_data[k] = new_list + elif isinstance(v, dict): + new_data[k] = load_yaml_include(root, v) + else: + new_data[k] = v + + return new_data def sh_replace(src, mapping_dict): diff --git a/tests/unit_test/lighter/0.yml b/tests/unit_test/lighter/0.yml new file mode 100644 index 0000000000..bc9ee7e07c --- /dev/null +++ b/tests/unit_test/lighter/0.yml @@ -0,0 +1,15 @@ +api_version: 3 +name: example_project + +include: 1.yml + +participants: + - name: server + port: 123 + include: [1.yml] + extra: + location: "east" + include: 3.yml + - name: client + port: 234 + include: 2.yml diff --git a/tests/unit_test/lighter/1.yml b/tests/unit_test/lighter/1.yml new file mode 100644 index 0000000000..4dece82e9e --- /dev/null +++ b/tests/unit_test/lighter/1.yml @@ -0,0 +1 @@ +server_name: server \ No newline at end of file diff --git a/tests/unit_test/lighter/2.yml b/tests/unit_test/lighter/2.yml new file mode 100644 index 0000000000..18d2519c61 --- /dev/null +++ b/tests/unit_test/lighter/2.yml @@ -0,0 +1 @@ +client_name: client-1 \ No newline at end of file diff --git a/tests/unit_test/lighter/3.yml b/tests/unit_test/lighter/3.yml new file mode 100644 index 0000000000..08117a92cc --- /dev/null +++ b/tests/unit_test/lighter/3.yml @@ -0,0 +1,2 @@ +size: 4 +gpus: large \ No newline at end of file diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index f1bdb7266a..60c3eaa9c8 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -25,7 +25,7 @@ from cryptography.x509.oid import NameOID from nvflare.lighter.impl.cert import serialize_cert -from nvflare.lighter.utils import sign_folders, verify_folder_signature +from nvflare.lighter.utils import load_yaml, sign_folders, verify_folder_signature folders = ["folder1", "folder2"] files = ["file1", "file2"] @@ -144,3 +144,21 @@ def test_verify_updated_folder(self): os.unlink("client.crt") os.unlink("root.crt") shutil.rmtree(folder) + + def _get_participant(self, name, participants): + for p in participants: + if p.get("name") == name: + return p + + def test_load_yaml(self): + dir_path = os.path.dirname(os.path.realpath(__file__)) + data = load_yaml(os.path.join(dir_path, "0.yml")) + + assert data.get("server_name") == "server" + + participant = self._get_participant("server", data.get("participants")) + assert participant.get("server_name") == "server" + assert participant.get("extra").get("gpus") == "large" + + participant = self._get_participant("client", data.get("participants")) + assert participant.get("client_name") == "client-1"