Skip to content

Commit

Permalink
Add yaml include support (#3133)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Added the support for yaml to include another yaml configuration file.

YAML does not naturally support any kind of "import" or "include"
statement to include another yaml file. Adding this to support the yaml
config in the format like: (include could be single file, or a list of
include files.)

....
include: 1.yml

or:
include:  [1.yml, 2.yml]

The "include" can be used at any level. Also support recursively
include.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.

---------

Co-authored-by: Ziyue Xu <[email protected]>
  • Loading branch information
yhwen and ZiyueXu77 authored Jan 30, 2025
1 parent b5d7ca6 commit 288e790
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 21 deletions.
77 changes: 57 additions & 20 deletions nvflare/lighter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def serialize_cert(cert):


def load_crt(path):
return load_crt_bytes(open(path, "rb").read())
with open(path, "rb") as f:
return load_crt_bytes(f.read())


def load_crt_bytes(data: bytes):
Expand Down Expand Up @@ -116,17 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
for file in files:
if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE:
continue
signatures[file] = sign_content(
content=open(os.path.join(root, file), "rb").read(),
signing_pri_key=signing_pri_key,
)
with open(os.path.join(root, file), "rb") as f:
signatures[file] = sign_content(
content=f.read(),
signing_pri_key=signing_pri_key,
)
for folder in folders:
signatures[folder] = sign_content(
content=folder,
signing_pri_key=signing_pri_key,
)

json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f:
json.dump(signatures, f)
shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
if depth >= max_depth:
break
Expand All @@ -138,7 +141,8 @@ def verify_folder_signature(src_folder, root_ca_path):
root_ca_public_key = root_ca_cert.public_key()
for root, folders, files in os.walk(src_folder):
try:
signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f:
signatures = json.load(f)
cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
public_key = cert.public_key()
except:
Expand All @@ -150,11 +154,12 @@ def verify_folder_signature(src_folder, root_ca_path):
continue
signature = signatures.get(file)
if signature:
verify_content(
content=open(os.path.join(root, file), "rb").read(),
signature=signature,
public_key=public_key,
)
with open(os.path.join(root, file), "rb") as f:
verify_content(
content=f.read(),
signature=signature,
public_key=public_key,
)
for folder in folders:
signature = signatures.get(folder)
if signature:
Expand All @@ -173,20 +178,52 @@ def sign_all(content_folder, signing_pri_key):
for f in os.listdir(content_folder):
path = os.path.join(content_folder, f)
if os.path.isfile(path):
signatures[f] = sign_content(
content=open(path, "rb").read(),
signing_pri_key=signing_pri_key,
)
with open(path, "rb") as file:
signatures[f] = sign_content(
content=file.read(),
signing_pri_key=signing_pri_key,
)
return signatures


def load_yaml(file):

root = os.path.split(file)[0]
yaml_data = None
if isinstance(file, str):
return yaml.safe_load(open(file, "r"))
with open(file, "r") as f:
yaml_data = yaml.safe_load(f)
elif isinstance(file, bytes):
return yaml.safe_load(file)
else:
return None
yaml_data = yaml.safe_load(file)

yaml_data = load_yaml_include(root, yaml_data)

return yaml_data


def load_yaml_include(root, yaml_data):
new_data = {}
for k, v in yaml_data.items():
if k == "include":
if isinstance(v, str):
includes = [v]
elif isinstance(v, list):
includes = v
for item in includes:
new_data.update(load_yaml(os.path.join(root, item)))
elif isinstance(v, list):
new_list = []
for item in v:
if isinstance(item, dict):
item = load_yaml_include(root, item)
new_list.append(item)
new_data[k] = new_list
elif isinstance(v, dict):
new_data[k] = load_yaml_include(root, v)
else:
new_data[k] = v

return new_data


def sh_replace(src, mapping_dict):
Expand Down
15 changes: 15 additions & 0 deletions tests/unit_test/lighter/0.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
api_version: 3
name: example_project

include: 1.yml

participants:
- name: server
port: 123
include: [1.yml]
extra:
location: "east"
include: 3.yml
- name: client
port: 234
include: 2.yml
1 change: 1 addition & 0 deletions tests/unit_test/lighter/1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
server_name: server
1 change: 1 addition & 0 deletions tests/unit_test/lighter/2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
client_name: client-1
2 changes: 2 additions & 0 deletions tests/unit_test/lighter/3.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
size: 4
gpus: large
20 changes: 19 additions & 1 deletion tests/unit_test/lighter/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from cryptography.x509.oid import NameOID

from nvflare.lighter.impl.cert import serialize_cert
from nvflare.lighter.utils import sign_folders, verify_folder_signature
from nvflare.lighter.utils import load_yaml, sign_folders, verify_folder_signature

folders = ["folder1", "folder2"]
files = ["file1", "file2"]
Expand Down Expand Up @@ -144,3 +144,21 @@ def test_verify_updated_folder(self):
os.unlink("client.crt")
os.unlink("root.crt")
shutil.rmtree(folder)

def _get_participant(self, name, participants):
for p in participants:
if p.get("name") == name:
return p

def test_load_yaml(self):
dir_path = os.path.dirname(os.path.realpath(__file__))
data = load_yaml(os.path.join(dir_path, "0.yml"))

assert data.get("server_name") == "server"

participant = self._get_participant("server", data.get("participants"))
assert participant.get("server_name") == "server"
assert participant.get("extra").get("gpus") == "large"

participant = self._get_participant("client", data.get("participants"))
assert participant.get("client_name") == "client-1"

0 comments on commit 288e790

Please sign in to comment.