Skip to content

Commit

Permalink
Merge branch 'NVIDIA:main' into tutorial_nb
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 authored Jan 30, 2025
2 parents dafe214 + 288e790 commit d6dd74c
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 104 deletions.
77 changes: 57 additions & 20 deletions nvflare/lighter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def serialize_cert(cert):


def load_crt(path):
return load_crt_bytes(open(path, "rb").read())
with open(path, "rb") as f:
return load_crt_bytes(f.read())


def load_crt_bytes(data: bytes):
Expand Down Expand Up @@ -116,17 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
for file in files:
if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE:
continue
signatures[file] = sign_content(
content=open(os.path.join(root, file), "rb").read(),
signing_pri_key=signing_pri_key,
)
with open(os.path.join(root, file), "rb") as f:
signatures[file] = sign_content(
content=f.read(),
signing_pri_key=signing_pri_key,
)
for folder in folders:
signatures[folder] = sign_content(
content=folder,
signing_pri_key=signing_pri_key,
)

json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f:
json.dump(signatures, f)
shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
if depth >= max_depth:
break
Expand All @@ -138,7 +141,8 @@ def verify_folder_signature(src_folder, root_ca_path):
root_ca_public_key = root_ca_cert.public_key()
for root, folders, files in os.walk(src_folder):
try:
signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f:
signatures = json.load(f)
cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
public_key = cert.public_key()
except:
Expand All @@ -150,11 +154,12 @@ def verify_folder_signature(src_folder, root_ca_path):
continue
signature = signatures.get(file)
if signature:
verify_content(
content=open(os.path.join(root, file), "rb").read(),
signature=signature,
public_key=public_key,
)
with open(os.path.join(root, file), "rb") as f:
verify_content(
content=f.read(),
signature=signature,
public_key=public_key,
)
for folder in folders:
signature = signatures.get(folder)
if signature:
Expand All @@ -173,20 +178,52 @@ def sign_all(content_folder, signing_pri_key):
for f in os.listdir(content_folder):
path = os.path.join(content_folder, f)
if os.path.isfile(path):
signatures[f] = sign_content(
content=open(path, "rb").read(),
signing_pri_key=signing_pri_key,
)
with open(path, "rb") as file:
signatures[f] = sign_content(
content=file.read(),
signing_pri_key=signing_pri_key,
)
return signatures


def load_yaml(file):

root = os.path.split(file)[0]
yaml_data = None
if isinstance(file, str):
return yaml.safe_load(open(file, "r"))
with open(file, "r") as f:
yaml_data = yaml.safe_load(f)
elif isinstance(file, bytes):
return yaml.safe_load(file)
else:
return None
yaml_data = yaml.safe_load(file)

yaml_data = load_yaml_include(root, yaml_data)

return yaml_data


def load_yaml_include(root, yaml_data):
new_data = {}
for k, v in yaml_data.items():
if k == "include":
if isinstance(v, str):
includes = [v]
elif isinstance(v, list):
includes = v
for item in includes:
new_data.update(load_yaml(os.path.join(root, item)))
elif isinstance(v, list):
new_list = []
for item in v:
if isinstance(item, dict):
item = load_yaml_include(root, item)
new_list.append(item)
new_data[k] = new_list
elif isinstance(v, dict):
new_data[k] = load_yaml_include(root, v)
else:
new_data[k] = v

return new_data


def sh_replace(src, mapping_dict):
Expand Down
13 changes: 0 additions & 13 deletions tests/unit_test/app_opt/quantization/__init__.py

This file was deleted.

70 changes: 0 additions & 70 deletions tests/unit_test/app_opt/quantization/quantization_test.py

This file was deleted.

15 changes: 15 additions & 0 deletions tests/unit_test/lighter/0.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
api_version: 3
name: example_project

include: 1.yml

participants:
- name: server
port: 123
include: [1.yml]
extra:
location: "east"
include: 3.yml
- name: client
port: 234
include: 2.yml
1 change: 1 addition & 0 deletions tests/unit_test/lighter/1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
server_name: server
1 change: 1 addition & 0 deletions tests/unit_test/lighter/2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
client_name: client-1
2 changes: 2 additions & 0 deletions tests/unit_test/lighter/3.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
size: 4
gpus: large
20 changes: 19 additions & 1 deletion tests/unit_test/lighter/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from cryptography.x509.oid import NameOID

from nvflare.lighter.impl.cert import serialize_cert
from nvflare.lighter.utils import sign_folders, verify_folder_signature
from nvflare.lighter.utils import load_yaml, sign_folders, verify_folder_signature

folders = ["folder1", "folder2"]
files = ["file1", "file2"]
Expand Down Expand Up @@ -144,3 +144,21 @@ def test_verify_updated_folder(self):
os.unlink("client.crt")
os.unlink("root.crt")
shutil.rmtree(folder)

def _get_participant(self, name, participants):
for p in participants:
if p.get("name") == name:
return p

def test_load_yaml(self):
dir_path = os.path.dirname(os.path.realpath(__file__))
data = load_yaml(os.path.join(dir_path, "0.yml"))

assert data.get("server_name") == "server"

participant = self._get_participant("server", data.get("participants"))
assert participant.get("server_name") == "server"
assert participant.get("extra").get("gpus") == "large"

participant = self._get_participant("client", data.get("participants"))
assert participant.get("client_name") == "client-1"

0 comments on commit d6dd74c

Please sign in to comment.