Skip to content

Commit

Permalink
Merge branch 'main' into compress_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 authored Nov 6, 2024
2 parents f1a811e + dc6598e commit 72c1401
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 50 deletions.
4 changes: 4 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ class WorkspaceConstants:

ADMIN_STARTUP_CONFIG = "fed_admin.json"

RESOURCE_FILE_NAME_PATTERN = "*__resources.json" # for both parent and job processes
JOB_RESOURCE_FILE_NAME_PATTERN = "*__j_resources.json" # for job process only
PARENT_RESOURCE_FILE_NAME_PATTERN = "*__p_resources.json" # for parent process only


class SiteType:
SERVER = "server"
Expand Down
56 changes: 56 additions & 0 deletions nvflare/apis/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
from typing import List, Union

Expand Down Expand Up @@ -178,3 +179,58 @@ def get_stats_pool_records_path(self, job_id: str, prefix=None) -> str:
if prefix:
file_name = f"{prefix}.{file_name}"
return os.path.join(self.get_run_dir(job_id), file_name)

def get_config_files_for_startup(self, is_server: bool, for_job: bool) -> list:
"""Get all config files to be used for startup of the process (SP, SJ, CP, CJ).
We first get required config files:
- the startup file (fed_server.json or fed_client.json) in "startup" folder
- resource file (resources.json.default or resources.json) in "local" folder
We then try to get resources files (usually generated by different builders of the Provision system):
- resources files from the "startup" folder take precedence
- resources files from the "local" folder are next
These extra resource config files must be json and follow the following patterns:
- *__resources.json: these files are for both parent process and job processes
- *__p_resources.json: these files are for parent process only
- *__j_resources.json: these files are for job process only
Args:
is_server: whether this is for server site or client site
for_job: whether this is for job process or parent process
Returns: a list of config file names
"""
if is_server:
startup_file_path = self.get_server_startup_file_path()
else:
startup_file_path = self.get_client_startup_file_path()

resource_config_path = self.get_resources_file_path()
config_files = [startup_file_path, resource_config_path]
if for_job:
# this is for job process
job_resources_file_path = self.get_job_resources_file_path()
if os.path.exists(job_resources_file_path):
config_files.append(job_resources_file_path)

# add other resource config files
patterns = [WorkspaceConstants.RESOURCE_FILE_NAME_PATTERN]
if for_job:
patterns.append(WorkspaceConstants.JOB_RESOURCE_FILE_NAME_PATTERN)
else:
patterns.append(WorkspaceConstants.PARENT_RESOURCE_FILE_NAME_PATTERN)

# add startup files first, then local files
self._add_resource_files(self.get_startup_kit_dir(), config_files, patterns)
self._add_resource_files(self.get_site_config_dir(), config_files, patterns)
return config_files

@staticmethod
def _add_resource_files(from_dir: str, to_list: list, patterns: [str]):
for p in patterns:
files = glob.glob(os.path.join(from_dir, p))
if files:
to_list.extend(files)
58 changes: 44 additions & 14 deletions nvflare/fuel/sec/security_content_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import os
from base64 import b64decode
Expand Down Expand Up @@ -98,29 +99,29 @@ class SecurityContentService(object):
"""Uses SecurityContentManager to load secure content."""

security_content_manager = None
content_folder = None

@staticmethod
def initialize(content_folder: str, signature_filename="signature.json", root_cert="rootCA.pem"):
if SecurityContentService.security_content_manager is None:
SecurityContentService.security_content_manager = SecurityContentManager(
content_folder, signature_filename, root_cert
)
@classmethod
def initialize(cls, content_folder: str, signature_filename="signature.json", root_cert="rootCA.pem"):
if cls.security_content_manager is None:
cls.content_folder = content_folder
cls.security_content_manager = SecurityContentManager(content_folder, signature_filename, root_cert)

@staticmethod
def load_content(file_under_verification):
if not SecurityContentService.security_content_manager:
@classmethod
def load_content(cls, file_under_verification):
if not cls.security_content_manager:
return None, LoadResult.NOT_MANAGED

return SecurityContentService.security_content_manager.load_content(file_under_verification)
return cls.security_content_manager.load_content(file_under_verification)

@staticmethod
def load_json(file_under_verification):
if not SecurityContentService.security_content_manager:
@classmethod
def load_json(cls, file_under_verification):
if not cls.security_content_manager:
return None, LoadResult.NOT_MANAGED

json_data = None

data_bytes, result = SecurityContentService.security_content_manager.load_content(file_under_verification)
data_bytes, result = cls.security_content_manager.load_content(file_under_verification)

if data_bytes:
try:
Expand All @@ -130,3 +131,32 @@ def load_json(file_under_verification):
return None, LoadResult.INVALID_CONTENT

return json_data, result

@classmethod
def check_json_files(cls, patterns: [str]) -> [str]:
"""Check JSON files that match the specified patterns
Args:
patterns: the patterns to be checked
Returns: full paths of invalid files if any.
A file is considered invalid in any of the cases:
- The file is not signed
- The file does not match signature
"""
bad_files = []
if not cls.security_content_manager:
return bad_files

if not patterns:
return bad_files

for p in patterns:
files = glob.glob(os.path.join(cls.content_folder, p))
if files:
for f in files:
_, result = cls.load_json(os.path.basename(f))
if result != LoadResult.OK:
bad_files.append(f)
return bad_files
15 changes: 6 additions & 9 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@
import sys
import threading

from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, JobConstants, SystemConfigs
from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, JobConstants, SiteType, SystemConfigs
from nvflare.apis.overseer_spec import SP
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import EngineConstant
Expand All @@ -38,6 +36,8 @@
create_stats_pool_files_for_job,
fobs_initialize,
register_ext_decomposers,
security_close,
security_init_for_job,
set_stats_pool_config_for_job,
)
from nvflare.security.logging import secure_format_exception
Expand Down Expand Up @@ -72,12 +72,9 @@ def main(args):
os.remove(restart_file)

fobs_initialize(workspace=workspace, job_id=args.job_id)
# Initialize audit service since the job execution will need it!
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)

# print("starting the client .....")
SecurityContentService.initialize(content_folder=workspace.get_startup_kit_dir())
# initialize security processing and ensure that content in the startup has not been tampered with.
security_init_for_job(secure_train, workspace, SiteType.CLIENT)

thread = None
stop_event = threading.Event()
Expand Down Expand Up @@ -139,7 +136,7 @@ def main(args):
stop_event.set()
if thread and thread.is_alive():
thread.join()
AuditService.close()
security_close()
err = create_stats_pool_files_for_job(workspace, args.job_id)
if err:
logger.warning(err)
Expand Down
19 changes: 2 additions & 17 deletions nvflare/private/fed/app/fl_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,7 @@ def __init__(self, workspace: Workspace, args, kv_list=None):

configure_logging(workspace)

server_startup_file_path = workspace.get_server_startup_file_path()
resource_config_path = workspace.get_resources_file_path()
config_files = [server_startup_file_path, resource_config_path]
if args.job_id:
# this is for job process
job_resources_file_path = workspace.get_job_resources_file_path()
if os.path.exists(job_resources_file_path):
config_files.append(job_resources_file_path)
config_files = workspace.get_config_files_for_startup(is_server=True, for_job=True if args.job_id else False)

JsonConfigurator.__init__(
self,
Expand Down Expand Up @@ -235,15 +228,7 @@ def __init__(self, workspace: Workspace, args, kv_list=None):

configure_logging(workspace)

client_startup_file_path = workspace.get_client_startup_file_path()
resources_file_path = workspace.get_resources_file_path()
config_files = [client_startup_file_path, resources_file_path]

if args.job_id:
# this is for job process
job_resources_file_path = workspace.get_job_resources_file_path()
if os.path.exists(job_resources_file_path):
config_files.append(job_resources_file_path)
config_files = workspace.get_config_files_for_startup(is_server=False, for_job=True if args.job_id else False)

JsonConfigurator.__init__(
self,
Expand Down
17 changes: 7 additions & 10 deletions nvflare/private/fed/app/server/runner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@
import sys
import threading

from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SystemConfigs
from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SiteType, SystemConfigs
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import AppFolderConstants
Expand All @@ -38,6 +36,8 @@
create_stats_pool_files_for_job,
fobs_initialize,
register_ext_decomposers,
security_close,
security_init_for_job,
set_stats_pool_config_for_job,
)
from nvflare.security.logging import secure_format_exception, secure_log_traceback
Expand All @@ -63,16 +63,14 @@ def main(args):
stop_event = threading.Event()
workspace = Workspace(root_dir=args.workspace, site_name="server")
set_stats_pool_config_for_job(workspace, args.job_id)
secure_train = kv_list.get("secure_train", False)

try:
os.chdir(args.workspace)
fobs_initialize(workspace=workspace, job_id=args.job_id)

SecurityContentService.initialize(content_folder=workspace.get_startup_kit_dir())

# Initialize audit service since the job execution will need it!
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)
# initialize security processing and ensure that content in the startup has not been tampered with.
security_init_for_job(secure_train, workspace, SiteType.SERVER)

conf = FLServerStarterConfiger(
workspace=workspace,
Expand All @@ -97,7 +95,6 @@ def main(args):
conf.configure()
event_handlers = conf.handlers
deployer = conf.deployer
secure_train = conf.cmd_vars.get("secure_train", False)

decomposer_module = ConfigService.get_str_var(
name=ConfigVarName.DECOMPOSER_MODULE, conf=SystemConfigs.RESOURCES_CONF
Expand Down Expand Up @@ -130,7 +127,7 @@ def main(args):
if deployer:
deployer.close()
stop_event.set()
AuditService.close()
security_close()
err = create_stats_pool_files_for_job(workspace, args.job_id)
if err:
logger.warning(err)
Expand Down
39 changes: 39 additions & 0 deletions nvflare/private/fed/utils/fed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ def _check_secure_content(site_type: str) -> List[str]:
if sig != LoadResult.OK:
insecure_list.append(WorkspaceConstants.AUTHORIZATION_CONFIG)

# every resource file in the startup must be signed and not tampered with!
bad_files = SecurityContentService.check_json_files(
[
WorkspaceConstants.RESOURCE_FILE_NAME_PATTERN,
WorkspaceConstants.PARENT_RESOURCE_FILE_NAME_PATTERN,
WorkspaceConstants.JOB_RESOURCE_FILE_NAME_PATTERN,
]
)

if bad_files:
insecure_list.extend(bad_files)

return insecure_list


Expand Down Expand Up @@ -186,6 +198,33 @@ def security_init(secure_train: bool, site_org: str, workspace: Workspace, app_v
sys.exit(1)


def security_init_for_job(secure_train: bool, workspace: Workspace, site_type: str):
"""Initialize security processing for a job process (SJ or CJ).
Args:
secure_train (bool): if run in secure mode or not.
workspace: the workspace object.
site_type (str): server or client. fed_client.json or fed_server.json
"""
# initialize the SecurityContentService.
# must do this before initializing other services since it may be needed by them!
startup_dir = workspace.get_startup_kit_dir()
SecurityContentService.initialize(content_folder=startup_dir)

if secure_train:
insecure_list = _check_secure_content(site_type=site_type)
if len(insecure_list):
print("The following files are not secure content.")
for item in insecure_list:
print(item)
sys.exit(1)

# initialize the AuditService, which is used by command processing.
# The Audit Service can be used in other places as well.
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)


def security_close():
AuditService.close()

Expand Down

0 comments on commit 72c1401

Please sign in to comment.