Skip to content

Commit

Permalink
Enable simulator to run HE (NVIDIA#2339)
Browse files Browse the repository at this point in the history
* Enable simulator to run HE.

* fixed the unittest.

* Created startup folder for simulator run if not exist.

* Changed to use setup and teardown for pytest.

* extract common codes init_security_content_service().

* removed no use import.

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
Co-authored-by: Chester Chen <[email protected]>
  • Loading branch information
3 people authored Apr 10, 2024
1 parent 656cec6 commit d72258c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
4 changes: 3 additions & 1 deletion nvflare/private/fed/app/simulator/simulator_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_bytes, zip_directory_to_bytes
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer
from nvflare.private.fed.app.utils import kill_child_processes
from nvflare.private.fed.app.utils import init_security_content_service, kill_child_processes
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.server.job_meta_validator import JobMetaValidator
from nvflare.private.fed.simulator.simulator_app_runner import SimulatorServerAppRunner
Expand Down Expand Up @@ -153,6 +153,8 @@ def setup(self):
AuthorizationService.initialize(EmptyAuthorizer())
AuditService.the_auditor = SimulatorAuditor()

init_security_content_service(self.args.workspace)

self.simulator_root = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME)
if os.path.exists(self.simulator_root):
shutil.rmtree(self.simulator_root)
Expand Down
4 changes: 3 additions & 1 deletion nvflare/private/fed/app/simulator/simulator_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from nvflare.fuel.hci.server.authz import AuthorizationService
from nvflare.fuel.sec.audit import AuditService
from nvflare.private.fed.app.deployer.base_client_deployer import BaseClientDeployer
from nvflare.private.fed.app.utils import check_parent_alive
from nvflare.private.fed.app.utils import check_parent_alive, init_security_content_service
from nvflare.private.fed.client.client_engine import ClientEngine
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.client.fed_client import FederatedClient
Expand Down Expand Up @@ -241,6 +241,8 @@ def main(args):
# AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG)
AuditService.the_auditor = SimulatorAuditor()

init_security_content_service(args.workspace)

if args.gpu:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
Expand Down
10 changes: 9 additions & 1 deletion nvflare/private/fed/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import psutil

from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_constant import FLContextKey, WorkspaceConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import UnsafeComponentError
from nvflare.apis.workspace import Workspace
from nvflare.fuel.hci.security import hash_password
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.private.defs import SSLConstants
from nvflare.private.fed.runner import Runner
from nvflare.private.fed.server.admin import FedAdminServer
Expand Down Expand Up @@ -103,6 +105,12 @@ def version_check():
raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10")


def init_security_content_service(workspace_dir):
os.makedirs(os.path.join(workspace_dir, WorkspaceConstants.STARTUP_FOLDER_NAME), exist_ok=True)
workspace_obj = Workspace(root_dir=workspace_dir)
SecurityContentService.initialize(content_folder=workspace_obj.get_startup_kit_dir())


def component_security_check(fl_ctx: FLContext):
exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS)
if exceptions:
Expand Down
20 changes: 14 additions & 6 deletions tests/unit_test/private/fed/app/simulator/simulator_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import os
import shutil
import uuid
from unittest.mock import patch

import pytest

from nvflare.apis.fl_constant import WorkspaceConstants
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner
from nvflare.private.fed.utils.fed_utils import split_gpus

Expand All @@ -28,14 +30,22 @@ def get_root_url_for_child(self):


class TestSimulatorRunner:
def setup_method(self, method):
self.workspace_name = str(uuid.uuid4())
self.cwd = os.getcwd()
os.makedirs(os.path.join(self.cwd, self.workspace_name, WorkspaceConstants.STARTUP_FOLDER_NAME))

def teardown_method(self, method):
os.chdir(self.cwd)
shutil.rmtree(os.path.join(self.cwd, self.workspace_name))

@patch("nvflare.private.fed.app.deployer.simulator_deployer.SimulatorServer.deploy")
@patch("nvflare.private.fed.app.utils.FedAdminServer")
@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
@patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell())
def test_valid_job_simulate_setup(self, mock_deploy, mock_admin, mock_register, mock_cell):
workspace_name = str(uuid.uuid4())
job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job")
runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, threads=1)
runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, threads=1)
assert runner.setup()

expected_clients = ["site-1", "site-2"]
Expand All @@ -49,9 +59,8 @@ def test_valid_job_simulate_setup(self, mock_deploy, mock_admin, mock_register,
@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
@patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell())
def test_client_names_setup(self, mock_deploy, mock_admin, mock_register, mock_cell):
workspace_name = str(uuid.uuid4())
job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job")
runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, clients="site-1", threads=1)
runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, clients="site-1", threads=1)
assert runner.setup()

expected_clients = ["site-1"]
Expand All @@ -65,9 +74,8 @@ def test_client_names_setup(self, mock_deploy, mock_admin, mock_register, mock_c
@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
@patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell())
def test_no_app_for_client(self, mock_deploy, mock_admin, mock_register, mock_cell):
workspace_name = str(uuid.uuid4())
job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job")
runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, n_clients=3, threads=1)
runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, n_clients=3, threads=1)
assert not runner.setup()

@pytest.mark.parametrize(
Expand Down

0 comments on commit d72258c

Please sign in to comment.