Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support extra provision builder generated component files #3056

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ class WorkspaceConstants:

ADMIN_STARTUP_CONFIG = "fed_admin.json"

RESOURCE_FILE_NAME_PATTERN = "*__resources.json" # for both parent and job processes
JOB_RESOURCE_FILE_NAME_PATTERN = "*__j_resources.json" # for job process only
PARENT_RESOURCE_FILE_NAME_PATTERN = "*__p_resources.json" # for parent process only


class SiteType:
SERVER = "server"
Expand Down
56 changes: 56 additions & 0 deletions nvflare/apis/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
from typing import List, Union

Expand Down Expand Up @@ -178,3 +179,58 @@ def get_stats_pool_records_path(self, job_id: str, prefix=None) -> str:
if prefix:
file_name = f"{prefix}.{file_name}"
return os.path.join(self.get_run_dir(job_id), file_name)

def get_config_files_for_startup(self, is_server: bool, for_job: bool) -> list:
"""Get all config files to be used for startup of the process (SP, SJ, CP, CJ).

We first get required config files:
- the startup file (fed_server.json or fed_client.json) in "startup" folder
- resource file (resources.json.default or resources.json) in "local" folder
yanchengnv marked this conversation as resolved.
Show resolved Hide resolved

We then try to get resources files (usually generated by different builders of the Provision system):
- resources files from the "startup" folder take precedence
- resources files from the "local" folder are next

These extra resource config files must be json and follow the following patterns:
- *__resources.json: these files are for both parent process and job processes
- *__p_resources.json: these files are for parent process only
- *__j_resources.json: these files are for job process only

Args:
is_server: whether this is for server site or client site
for_job: whether this is for job process or parent process

Returns: a list of config file names

"""
if is_server:
startup_file_path = self.get_server_startup_file_path()
else:
startup_file_path = self.get_client_startup_file_path()

resource_config_path = self.get_resources_file_path()
config_files = [startup_file_path, resource_config_path]
if for_job:
# this is for job process
job_resources_file_path = self.get_job_resources_file_path()
if os.path.exists(job_resources_file_path):
config_files.append(job_resources_file_path)

# add other resource config files
patterns = [WorkspaceConstants.RESOURCE_FILE_NAME_PATTERN]
if for_job:
patterns.append(WorkspaceConstants.JOB_RESOURCE_FILE_NAME_PATTERN)
else:
patterns.append(WorkspaceConstants.PARENT_RESOURCE_FILE_NAME_PATTERN)

# add startup files first, then local files
self._add_resource_files(self.get_startup_kit_dir(), config_files, patterns)
self._add_resource_files(self.get_site_config_dir(), config_files, patterns)
return config_files

@staticmethod
def _add_resource_files(from_dir: str, to_list: list, patterns: [str]):
for p in patterns:
files = glob.glob(os.path.join(from_dir, p))
if files:
to_list.extend(files)
58 changes: 44 additions & 14 deletions nvflare/fuel/sec/security_content_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import os
from base64 import b64decode
Expand Down Expand Up @@ -98,29 +99,29 @@ class SecurityContentService(object):
"""Uses SecurityContentManager to load secure content."""

security_content_manager = None
content_folder = None

@staticmethod
def initialize(content_folder: str, signature_filename="signature.json", root_cert="rootCA.pem"):
if SecurityContentService.security_content_manager is None:
SecurityContentService.security_content_manager = SecurityContentManager(
content_folder, signature_filename, root_cert
)
@classmethod
def initialize(cls, content_folder: str, signature_filename="signature.json", root_cert="rootCA.pem"):
if cls.security_content_manager is None:
cls.content_folder = content_folder
cls.security_content_manager = SecurityContentManager(content_folder, signature_filename, root_cert)

@staticmethod
def load_content(file_under_verification):
if not SecurityContentService.security_content_manager:
@classmethod
def load_content(cls, file_under_verification):
if not cls.security_content_manager:
return None, LoadResult.NOT_MANAGED

return SecurityContentService.security_content_manager.load_content(file_under_verification)
return cls.security_content_manager.load_content(file_under_verification)

@staticmethod
def load_json(file_under_verification):
if not SecurityContentService.security_content_manager:
@classmethod
def load_json(cls, file_under_verification):
if not cls.security_content_manager:
return None, LoadResult.NOT_MANAGED

json_data = None

data_bytes, result = SecurityContentService.security_content_manager.load_content(file_under_verification)
data_bytes, result = cls.security_content_manager.load_content(file_under_verification)

if data_bytes:
try:
Expand All @@ -130,3 +131,32 @@ def load_json(file_under_verification):
return None, LoadResult.INVALID_CONTENT

return json_data, result

@classmethod
def check_json_files(cls, patterns: [str]) -> [str]:
"""Check JSON files that match the specified patterns

Args:
patterns: the patterns to be checked

Returns: full paths of invalid files if any.
A file is considered invalid in any of the cases:
- The file is not signed
- The file does not match signature

"""
bad_files = []
if not cls.security_content_manager:
return bad_files

if not patterns:
return bad_files

for p in patterns:
files = glob.glob(os.path.join(cls.content_folder, p))
if files:
for f in files:
_, result = cls.load_json(os.path.basename(f))
if result != LoadResult.OK:
bad_files.append(f)
return bad_files
15 changes: 6 additions & 9 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@
import sys
import threading

from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, JobConstants, SystemConfigs
from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, JobConstants, SiteType, SystemConfigs
from nvflare.apis.overseer_spec import SP
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import EngineConstant
Expand All @@ -38,6 +36,8 @@
create_stats_pool_files_for_job,
fobs_initialize,
register_ext_decomposers,
security_close,
security_init_for_job,
set_stats_pool_config_for_job,
)
from nvflare.security.logging import secure_format_exception
Expand Down Expand Up @@ -72,12 +72,9 @@ def main(args):
os.remove(restart_file)

fobs_initialize(workspace=workspace, job_id=args.job_id)
# Initialize audit service since the job execution will need it!
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)

# print("starting the client .....")
SecurityContentService.initialize(content_folder=workspace.get_startup_kit_dir())
# initialize security processing and ensure that content in the startup has not been tampered with.
security_init_for_job(secure_train, workspace, SiteType.CLIENT)

thread = None
stop_event = threading.Event()
Expand Down Expand Up @@ -139,7 +136,7 @@ def main(args):
stop_event.set()
if thread and thread.is_alive():
thread.join()
AuditService.close()
security_close()
err = create_stats_pool_files_for_job(workspace, args.job_id)
if err:
logger.warning(err)
Expand Down
19 changes: 2 additions & 17 deletions nvflare/private/fed/app/fl_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,7 @@ def __init__(self, workspace: Workspace, args, kv_list=None):

configure_logging(workspace)

server_startup_file_path = workspace.get_server_startup_file_path()
resource_config_path = workspace.get_resources_file_path()
config_files = [server_startup_file_path, resource_config_path]
if args.job_id:
# this is for job process
job_resources_file_path = workspace.get_job_resources_file_path()
if os.path.exists(job_resources_file_path):
config_files.append(job_resources_file_path)
config_files = workspace.get_config_files_for_startup(is_server=True, for_job=True if args.job_id else False)

JsonConfigurator.__init__(
self,
Expand Down Expand Up @@ -235,15 +228,7 @@ def __init__(self, workspace: Workspace, args, kv_list=None):

configure_logging(workspace)

client_startup_file_path = workspace.get_client_startup_file_path()
resources_file_path = workspace.get_resources_file_path()
config_files = [client_startup_file_path, resources_file_path]

if args.job_id:
# this is for job process
job_resources_file_path = workspace.get_job_resources_file_path()
if os.path.exists(job_resources_file_path):
config_files.append(job_resources_file_path)
config_files = workspace.get_config_files_for_startup(is_server=False, for_job=True if args.job_id else False)

JsonConfigurator.__init__(
self,
Expand Down
17 changes: 7 additions & 10 deletions nvflare/private/fed/app/server/runner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@
import sys
import threading

from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SystemConfigs
from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SiteType, SystemConfigs
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import AppFolderConstants
Expand All @@ -38,6 +36,8 @@
create_stats_pool_files_for_job,
fobs_initialize,
register_ext_decomposers,
security_close,
security_init_for_job,
set_stats_pool_config_for_job,
)
from nvflare.security.logging import secure_format_exception, secure_log_traceback
Expand All @@ -63,16 +63,14 @@ def main(args):
stop_event = threading.Event()
workspace = Workspace(root_dir=args.workspace, site_name="server")
set_stats_pool_config_for_job(workspace, args.job_id)
secure_train = kv_list.get("secure_train", False)

try:
os.chdir(args.workspace)
fobs_initialize(workspace=workspace, job_id=args.job_id)

SecurityContentService.initialize(content_folder=workspace.get_startup_kit_dir())

# Initialize audit service since the job execution will need it!
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)
# initialize security processing and ensure that content in the startup has not been tampered with.
security_init_for_job(secure_train, workspace, SiteType.SERVER)

conf = FLServerStarterConfiger(
workspace=workspace,
Expand All @@ -97,7 +95,6 @@ def main(args):
conf.configure()
event_handlers = conf.handlers
deployer = conf.deployer
secure_train = conf.cmd_vars.get("secure_train", False)

decomposer_module = ConfigService.get_str_var(
name=ConfigVarName.DECOMPOSER_MODULE, conf=SystemConfigs.RESOURCES_CONF
Expand Down Expand Up @@ -130,7 +127,7 @@ def main(args):
if deployer:
deployer.close()
stop_event.set()
AuditService.close()
security_close()
err = create_stats_pool_files_for_job(workspace, args.job_id)
if err:
logger.warning(err)
Expand Down
39 changes: 39 additions & 0 deletions nvflare/private/fed/utils/fed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ def _check_secure_content(site_type: str) -> List[str]:
if sig != LoadResult.OK:
insecure_list.append(WorkspaceConstants.AUTHORIZATION_CONFIG)

# every resource file in the startup must be signed and not tampered with!
bad_files = SecurityContentService.check_json_files(
[
WorkspaceConstants.RESOURCE_FILE_NAME_PATTERN,
WorkspaceConstants.PARENT_RESOURCE_FILE_NAME_PATTERN,
WorkspaceConstants.JOB_RESOURCE_FILE_NAME_PATTERN,
]
)

if bad_files:
insecure_list.extend(bad_files)

return insecure_list


Expand Down Expand Up @@ -186,6 +198,33 @@ def security_init(secure_train: bool, site_org: str, workspace: Workspace, app_v
sys.exit(1)


def security_init_for_job(secure_train: bool, workspace: Workspace, site_type: str):
"""Initialize security processing for a job process (SJ or CJ).

Args:
secure_train (bool): if run in secure mode or not.
workspace: the workspace object.
site_type (str): server or client. fed_client.json or fed_server.json
"""
# initialize the SecurityContentService.
# must do this before initializing other services since it may be needed by them!
startup_dir = workspace.get_startup_kit_dir()
SecurityContentService.initialize(content_folder=startup_dir)

if secure_train:
insecure_list = _check_secure_content(site_type=site_type)
if len(insecure_list):
print("The following files are not secure content.")
for item in insecure_list:
print(item)
sys.exit(1)
yanchengnv marked this conversation as resolved.
Show resolved Hide resolved

# initialize the AuditService, which is used by command processing.
# The Audit Service can be used in other places as well.
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)


def security_close():
yanchengnv marked this conversation as resolved.
Show resolved Hide resolved
AuditService.close()

Expand Down
Loading