diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 8b833205db..3b609c9b0d 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -162,6 +162,7 @@ class FLContextKey(object): RECONNECTED_CLIENT_NAME = "_reconnected_client_name" SITE_OBJ = "_site_obj_" JOB_LAUNCHER = "_job_launcher" + SNAPSHOT = "job_snapshot" CLIENT_REGISTER_DATA = "_client_register_data" SECURITY_ITEMS = "_security_items" diff --git a/nvflare/app_common/job_launcher/client_process_launcher.py b/nvflare/app_common/job_launcher/client_process_launcher.py new file mode 100644 index 0000000000..aecd1c8502 --- /dev/null +++ b/nvflare/app_common/job_launcher/client_process_launcher.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +from nvflare.apis.fl_constant import FLContextKey, JobConstants +from nvflare.apis.workspace import Workspace +from nvflare.app_common.job_launcher.process_launcher import ProcessJobLauncher +from nvflare.private.fed.utils.fed_utils import add_custom_dir_to_path + + +class ClientProcessJobLauncher(ProcessJobLauncher): + def get_command(self, launch_data, fl_ctx) -> (str, dict): + new_env = os.environ.copy() + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + args = fl_ctx.get_prop(FLContextKey.ARGS) + client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + job_id = launch_data.get(JobConstants.JOB_ID) + server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) + if not server_config: + raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") + service = server_config[0].get("service", {}) + if not isinstance(service, dict): + raise RuntimeError(f"expect server config data to be dict but got {type(service)}") + + app_custom_folder = workspace_obj.get_app_custom_dir(job_id) + if app_custom_folder != "": + add_custom_dir_to_path(app_custom_folder, new_env) + + command_options = "" + for t in args.set: + command_options += " " + t + command = ( + f"{sys.executable} -m nvflare.private.fed.app.client.worker_process -m " + + args.workspace + + " -w " + + (workspace_obj.get_startup_kit_dir()) + + " -t " + + client.token + + " -d " + + client.ssid + + " -n " + + job_id + + " -c " + + client.client_name + + " -p " + + str(client.cell.get_internal_listener_url()) + + " -g " + + service.get("target") + + " -scheme " + + service.get("scheme", "grpc") + + " -s fed_client.json " + " --set" + command_options + " print_conf=True" + ) + return command, new_env diff --git a/nvflare/app_common/job_launcher/process_launcher.py b/nvflare/app_common/job_launcher/process_launcher.py index 912893feff..c374f61013 100644 --- a/nvflare/app_common/job_launcher/process_launcher.py +++ b/nvflare/app_common/job_launcher/process_launcher.py @@ -15,15 +15,13 @@ import os import shlex import subprocess -import sys +from abc import abstractmethod from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext -from nvflare.apis.job_def import JobMetaKey from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobReturnCode, add_launcher -from nvflare.apis.workspace import Workspace -from nvflare.private.fed.utils.fed_utils import add_custom_dir_to_path, extract_job_image +from nvflare.private.fed.utils.fed_utils import extract_job_image JOB_RETURN_CODE_MAPPING = {0: JobReturnCode.SUCCESS, 1: JobReturnCode.EXECUTION_ERROR, 9: JobReturnCode.ABORTED} @@ -64,51 +62,11 @@ def __init__(self): def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: - new_env = os.environ.copy() - workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - args = fl_ctx.get_prop(FLContextKey.ARGS) - client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) - job_id = job_meta.get(JobMetaKey.JOB_ID) - server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) - if not server_config: - raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") - service = server_config[0].get("service", {}) - if not isinstance(service, dict): - raise RuntimeError(f"expect server config data to be dict but got {type(service)}") - - app_custom_folder = workspace_obj.get_app_custom_dir(job_id) - if app_custom_folder != "": - add_custom_dir_to_path(app_custom_folder, new_env) - - command_options = "" - for t in args.set: - command_options += " " + t - command = ( - f"{sys.executable} -m nvflare.private.fed.app.client.worker_process -m " - + args.workspace - + " -w " - + (workspace_obj.get_startup_kit_dir()) - + " -t " - + client.token - + " -d " - + client.ssid - + " -n " - + job_id - + " -c " - + client.client_name - + " -p " - + str(client.cell.get_internal_listener_url()) - + " -g " - + service.get("target") - + " -scheme " - + service.get("scheme", "grpc") - + " -s fed_client.json " - " --set" + command_options + " print_conf=True" - ) + command, new_env = self.get_command(job_meta, fl_ctx) # use os.setsid to create new process group ID process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) - self.logger.info("Worker child process ID: {}".format(process.pid)) + self.logger.info("Launch the job in process ID: {}".format(process.pid)) return ProcessHandle(process) @@ -118,3 +76,17 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) if not job_image: add_launcher(self, fl_ctx) + + @abstractmethod + def get_command(self, launch_data, fl_ctx) -> (str, dict): + """To generate the command to launcher the job in sub-process + + Args: + fl_ctx: FLContext + launch_data: job launcher data + + Returns: + launch command, environment dict + + """ + pass diff --git a/nvflare/app_common/job_launcher/server_process_launcher.py b/nvflare/app_common/job_launcher/server_process_launcher.py new file mode 100644 index 0000000000..a7197f7aae --- /dev/null +++ b/nvflare/app_common/job_launcher/server_process_launcher.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +from nvflare.apis.fl_constant import FLContextKey, JobConstants +from nvflare.apis.workspace import Workspace +from nvflare.app_common.job_launcher.process_launcher import ProcessJobLauncher +from nvflare.private.fed.utils.fed_utils import add_custom_dir_to_path + + +class ServerProcessJobLauncher(ProcessJobLauncher): + def get_command(self, launch_data, fl_ctx) -> (str, dict): + new_env = os.environ.copy() + + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + args = fl_ctx.get_prop(FLContextKey.ARGS) + server = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + job_id = launch_data.get(JobConstants.JOB_ID) + restore_snapshot = fl_ctx.get_prop(FLContextKey.SNAPSHOT, False) + + app_root = workspace_obj.get_app_dir(job_id) + cell = server.cell + server_state = server.server_state + + app_custom_folder = workspace_obj.get_app_custom_dir(job_id) + if app_custom_folder != "": + add_custom_dir_to_path(app_custom_folder, new_env) + + command_options = "" + for t in args.set: + command_options += " " + t + + command = ( + sys.executable + + " -m nvflare.private.fed.app.server.runner_process -m " + + args.workspace + + " -s fed_server.json -r " + + app_root + + " -n " + + str(job_id) + + " -p " + + str(cell.get_internal_listener_url()) + + " -u " + + str(cell.get_root_url_for_child()) + + " --host " + + str(server_state.host) + + " --port " + + str(server_state.service_port) + + " --ssid " + + str(server_state.ssid) + + " --ha_mode " + + str(server.ha_mode) + + " --set" + + command_options + + " print_conf=True restore_snapshot=" + + str(restore_snapshot) + ) + + return command, new_env diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index bba0df01e0..de939cb4c3 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import time +from abc import abstractmethod from enum import Enum from kubernetes import config @@ -82,7 +83,7 @@ def __init__(self, job_id: str, api_instance: core_v1_api, job_config: dict, nam "imagePullPolicy": "Always", } ] - self.container_args_python_args_list = ["-u", "-m", "nvflare.private.fed.app.client.worker_process"] + self.container_args_python_args_list = ["-u", "-m", job_config.get("command")] self.container_args_module_args_dict = { "-m": None, "-w": None, @@ -218,39 +219,19 @@ def __init__( def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: - workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - args = fl_ctx.get_prop(FLContextKey.ARGS) - client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) job_id = job_meta.get(JobConstants.JOB_ID) - server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) - if not server_config: - raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") - service = server_config[0].get("service", {}) - if not isinstance(service, dict): - raise RuntimeError(f"expect server config data to be dict but got {type(service)}") - - self.logger.info(f"K8sJobLauncher start to launch job: {job_id} for client: {client.client_name}") + args = fl_ctx.get_prop(FLContextKey.ARGS) job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) self.logger.info(f"launch job use image: {job_image}") job_config = { "name": job_id, "image": job_image, "container_name": f"container-{job_id}", + "command": self.get_command(), "volume_mount_list": [{"name": self.workspace, "mountPath": self.mount_path}], "volume_list": [{"name": self.workspace, "hostPath": {"path": self.root_hostpath, "type": "Directory"}}], - "module_args": { - "-m": args.workspace, - "-w": (workspace_obj.get_startup_kit_dir()), - "-t": client.token, - "-d": client.ssid, - "-n": job_id, - "-c": client.client_name, - "-p": "tcp://parent-pod:8004", - "-g": service.get("target"), - "-scheme": service.get("scheme", "grpc"), - "-s": "fed_client.json", - }, - "set_list": args.set, + "module_args": self.get_module_args(job_id, fl_ctx), + "set_list": self.get_set_list(args, fl_ctx), } self.logger.info(f"launch job with k8s_launcher. Job_id:{job_id}") @@ -273,3 +254,101 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) if job_image: add_launcher(self, fl_ctx) + + @abstractmethod + def get_command(self): + """To get the run command of the launcher + + Returns: the command for the launcher process + + """ + pass + + @abstractmethod + def get_module_args(self, job_id, fl_ctx: FLContext): + """To get the args to run the launcher + + Args: + job_id: run job_id + fl_ctx: FLContext + + Returns: + + """ + pass + + @abstractmethod + def get_set_list(self, args, fl_ctx: FLContext): + """To get the command set_list + + Args: + args: command args + fl_ctx: FLContext + + Returns: set_list command options + + """ + pass + + +class ClientK8sJobLauncher(K8sJobLauncher): + def get_command(self): + return "nvflare.private.fed.app.client.worker_process" + + def get_module_args(self, job_id, fl_ctx: FLContext): + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + args = fl_ctx.get_prop(FLContextKey.ARGS) + client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) + if not server_config: + raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") + service = server_config[0].get("service", {}) + if not isinstance(service, dict): + raise RuntimeError(f"expect server config data to be dict but got {type(service)}") + self.logger.info(f"K8sJobLauncher start to launch job: {job_id} for client: {client.client_name}") + + return { + "-m": args.workspace, + "-w": (workspace_obj.get_startup_kit_dir()), + "-t": client.token, + "-d": client.ssid, + "-n": job_id, + "-c": client.client_name, + "-p": str(client.cell.get_internal_listener_url()), + "-g": service.get("target"), + "-scheme": service.get("scheme", "grpc"), + "-s": "fed_client.json", + } + + def get_set_list(self, args, fl_ctx: FLContext): + args.set.append("print_conf=True") + return args.set + + +class ServerK8sJobLauncher(K8sJobLauncher): + def get_command(self): + return "nvflare.private.fed.app.server.runner_process" + + def get_module_args(self, job_id, fl_ctx: FLContext): + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + args = fl_ctx.get_prop(FLContextKey.ARGS) + server = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + + return { + "-m": args.workspace, + "-s": "fed_server.json", + "-r": workspace_obj.get_app_dir(), + "-n": str(job_id), + "-p": str(server.cell.get_internal_listener_url()), + "-u": str(server.cell.get_root_url_for_child()), + "--host": str(server.server_state.host), + "--port": str(server.server_state.service_port), + "--ssid": str(server.server_state.ssid), + "--ha_mode": str(server.ha_mode), + } + + def get_set_list(self, args, fl_ctx: FLContext): + restore_snapshot = fl_ctx.get_prop(FLContextKey.SNAPSHOT, False) + args.set.append("print_conf=True") + args.set.append("restore_snapshot=" + str(restore_snapshot)) + return args.set diff --git a/nvflare/lighter/impl/master_template.yml b/nvflare/lighter/impl/master_template.yml index 1265d1206e..7b1e51af88 100644 --- a/nvflare/lighter/impl/master_template.yml +++ b/nvflare/lighter/impl/master_template.yml @@ -95,7 +95,7 @@ local_client_resources: | }, { "id": "process_launcher", - "path": "nvflare.app_common.job_launcher.process_launcher.ProcessJobLauncher", + "path": "nvflare.app_common.job_launcher.client_process_launcher.ClientProcessJobLauncher", "args": {} } ] @@ -221,6 +221,11 @@ local_server_resources: | { "id": "job_store", "path": "nvflare.app_common.storages.filesystem_storage.FilesystemStorage" + }, + { + "id": "process_launcher", + "path": "nvflare.app_common.job_launcher.server_process_launcher.ServerProcessJobLauncher", + "args": {} } ] } diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index f472a8c955..de17259c0c 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -124,6 +124,8 @@ def deploy(self, args): with services.engine.new_context() as fl_ctx: fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) + fl_ctx.set_prop(FLContextKey.ARGS, args, private=True, sticky=True) + fl_ctx.set_prop(FLContextKey.SITE_OBJ, services, private=True, sticky=True) services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) component_security_check(fl_ctx) diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index 6a8a111239..898dd2e93c 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -27,7 +27,7 @@ from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode from nvflare.fuel.utils.config_service import ConfigService from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message -from nvflare.private.fed.utils.fed_utils import get_return_code +from nvflare.private.fed.utils.fed_utils import get_job_launcher, get_return_code from nvflare.security.logging import secure_format_exception, secure_log_traceback from .client_status import ClientStatus, get_status_message @@ -165,7 +165,7 @@ def start_app( fl_ctx: FLContext """ - job_launcher: JobLauncherSpec = self._get_job_launcher(job_meta, fl_ctx) + job_launcher: JobLauncherSpec = get_job_launcher(job_meta, fl_ctx) job_handle = job_launcher.launch_job(job_meta, fl_ctx) self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 89e7175ab7..d883d7024a 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -987,7 +987,6 @@ def _turn_to_hot(self): self.logger.info(f"Restore the previous snapshot. Run_number: {run_number}") with self.engine.new_context() as fl_ctx: self.engine.job_runner.restore_running_job( - run_number=run_number, job_id=job_id, job_clients=job_clients, snapshot=snapshot, diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index 7c38294027..4b8e79a5de 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -246,7 +246,7 @@ def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo """ engine = fl_ctx.get_engine() job_clients = engine.get_job_clients(client_sites) - err = engine.start_app_on_server(job_id, job=job, job_clients=job_clients) + err = engine.start_app_on_server(fl_ctx, job=job, job_clients=job_clients) if err: raise RuntimeError(f"Could not start the server App for job: {job_id}.") @@ -498,13 +498,13 @@ def _check_job_status(job_manager, job_id, job_run_status, fl_ctx: FLContext): def stop(self): self.ask_to_stop = True - def restore_running_job(self, run_number: str, job_id: str, job_clients, snapshot, fl_ctx: FLContext): + def restore_running_job(self, job_id: str, job_clients, snapshot, fl_ctx: FLContext): engine = fl_ctx.get_engine() try: job_manager = engine.get_component(SystemComponents.JOB_MANAGER) job = job_manager.get_job(jid=job_id, fl_ctx=fl_ctx) - err = engine.start_app_on_server(run_number, job=job, job_clients=job_clients, snapshot=snapshot) + err = engine.start_app_on_server(fl_ctx, job=job, job_clients=job_clients, snapshot=snapshot) if err: raise RuntimeError(f"Could not restore the server App for job: {job_id}.") with self.lock: diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index be9823414e..ee753bf848 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -16,9 +16,7 @@ import logging import os import re -import shlex import shutil -import subprocess import sys import threading import time @@ -43,10 +41,11 @@ from nvflare.apis.fl_snapshot import RunSnapshot from nvflare.apis.impl.job_def_manager import JobDefManagerSpec from nvflare.apis.job_def import Job +from nvflare.apis.job_launcher_spec import JobLauncherSpec from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.utils.fl_context_utils import get_serializable_data from nvflare.apis.workspace import Workspace -from nvflare.fuel.f3.cellnet.core_cell import FQCN, CoreCell +from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellMsgReturnCode from nvflare.fuel.utils.argument_utils import parse_vars @@ -55,9 +54,8 @@ from nvflare.private.aux_runner import AuxMsgTarget from nvflare.private.defs import CellChannel, CellMessageHeaderKeys, RequestHeader, TrainingTopic, new_cell_message from nvflare.private.fed.server.server_json_config import ServerJsonConfigurator -from nvflare.private.fed.server.server_state import ServerState from nvflare.private.fed.utils.fed_utils import ( - add_custom_dir_to_path, + get_job_launcher, get_return_code, security_close, set_message_security_data, @@ -164,32 +162,21 @@ def get_clients(self) -> [Client]: def validate_targets(self, client_names: List[str]) -> Tuple[List[Client], List[str]]: return self.client_manager.get_all_clients_from_inputs(client_names) - def start_app_on_server(self, run_number: str, job: Job = None, job_clients=None, snapshot=None) -> str: - if run_number in self.run_processes.keys(): - return f"Server run: {run_number} already started." + def start_app_on_server(self, fl_ctx: FLContext, job: Job = None, job_clients=None, snapshot=None) -> str: + if not isinstance(job, Job): + return "Must provide a job object to start the server app." + + if job.job_id in self.run_processes.keys(): + return f"Server run: {job.job_id} already started." else: workspace = Workspace(root_dir=self.args.workspace, site_name="server") - app_root = workspace.get_app_dir(run_number) + app_root = workspace.get_app_dir(job.job_id) if not os.path.exists(app_root): return "Server app does not exist. Please deploy the server app before starting." self.engine_info.status = MachineStatus.STARTING - app_custom_folder = workspace.get_app_custom_dir(run_number) - - if not isinstance(job, Job): - return "Must provide a job object to start the server app." - - self._start_runner_process( - self.args, - app_root, - run_number, - app_custom_folder, - job.job_id, - job_clients, - snapshot, - self.server.cell, - self.server.server_state, - ) + + self._start_runner_process(job, job_clients, snapshot, fl_ctx) self.engine_info.status = MachineStatus.STARTED return "" @@ -225,73 +212,30 @@ def wait_for_complete(self, workspace, job_id, process): self.run_processes.pop(job_id, None) self.engine_info.status = MachineStatus.STOPPED - def _start_runner_process( - self, - args, - app_root, - run_number, - app_custom_folder, - job_id, - job_clients, - snapshot, - cell: CoreCell, - server_state: ServerState, - ): - new_env = os.environ.copy() - if app_custom_folder != "": - add_custom_dir_to_path(app_custom_folder, new_env) - + def _start_runner_process(self, job, job_clients, snapshot, fl_ctx: FLContext): + job_launcher: JobLauncherSpec = get_job_launcher(job.meta, fl_ctx) if snapshot: restore_snapshot = True else: restore_snapshot = False - command_options = "" - for t in args.set: - command_options += " " + t - - command = ( - sys.executable - + " -m nvflare.private.fed.app.server.runner_process -m " - + args.workspace - + " -s fed_server.json -r " - + app_root - + " -n " - + str(run_number) - + " -p " - + str(cell.get_internal_listener_url()) - + " -u " - + str(cell.get_root_url_for_child()) - + " --host " - + str(server_state.host) - + " --port " - + str(server_state.service_port) - + " --ssid " - + str(server_state.ssid) - + " --ha_mode " - + str(self.server.ha_mode) - + " --set" - + command_options - + " print_conf=True restore_snapshot=" - + str(restore_snapshot) - ) - # use os.setsid to create new process group ID + fl_ctx.set_prop(FLContextKey.SNAPSHOT, restore_snapshot, private=True, sticky=False) + job_handle = job_launcher.launch_job(job.meta, fl_ctx) + self.logger.info(f"Launch job_id: {job.job_id} with job launcher: {type(job_launcher)} ") - process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) + args = fl_ctx.get_prop(FLContextKey.ARGS) - if not job_id: - job_id = "" if not job_clients: job_clients = self.client_manager.clients with self.lock: - self.run_processes[run_number] = { - RunProcessKey.JOB_HANDLE: process, - RunProcessKey.JOB_ID: job_id, + self.run_processes[job.job_id] = { + RunProcessKey.JOB_HANDLE: job_handle, + RunProcessKey.JOB_ID: job.job_id, RunProcessKey.PARTICIPANTS: job_clients, } - threading.Thread(target=self.wait_for_complete, args=[args.workspace, run_number, process]).start() - return process + threading.Thread(target=self.wait_for_complete, args=[args.workspace, job.job_id, job_handle]).start() + return job_handle def get_job_clients(self, client_sites): job_clients = {} @@ -619,7 +563,7 @@ def sync_clients_from_main_process(self): if time.time() - start >= max_wait: self.logger.critical(f"Cannot get participating clients for job {job_id} after {max_wait} seconds") - raise RuntimeError("Exiting job process: Cannot get participating clients for job {job_id}") + raise RuntimeError(f"Exiting job process: Cannot get participating clients for job {job_id}") self.logger.debug("didn't receive clients info - retry in 1 second") time.sleep(1.0) diff --git a/nvflare/private/fed/server/server_engine_internal_spec.py b/nvflare/private/fed/server/server_engine_internal_spec.py index f4faed302e..3d06567917 100644 --- a/nvflare/private/fed/server/server_engine_internal_spec.py +++ b/nvflare/private/fed/server/server_engine_internal_spec.py @@ -18,6 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.fl_constant import MachineStatus +from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import Job from nvflare.apis.job_def_manager_spec import JobDefManagerSpec from nvflare.apis.server_engine_spec import ServerEngineSpec @@ -103,7 +104,7 @@ def delete_job_id(self, job_id: str) -> str: pass @abstractmethod - def start_app_on_server(self, run_number: str, job: Job = None, job_clients=None, snapshot=None) -> str: + def start_app_on_server(self, fl_ctx: FLContext, job: Job = None, job_clients=None, snapshot=None) -> str: """Start the FL app on Server. Returns: diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index 60651700c9..6ffba97a4d 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -540,3 +540,19 @@ def get_scope_prop(scope_name: str, key: str) -> Any: check_str("key", key) data_bus = DataBus() return data_bus.get_data(_scope_prop_key(scope_name, key)) + + +def get_job_launcher(job_meta: dict, fl_ctx: FLContext) -> dict: + engine = fl_ctx.get_engine() + + with engine.new_context() as job_launcher_ctx: + # Remove the potential not cleaned up JOB_LAUNCHER + job_launcher_ctx.remove_prop(FLContextKey.JOB_LAUNCHER) + job_launcher_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + engine.fire_event(EventType.GET_JOB_LAUNCHER, job_launcher_ctx) + + job_launcher = job_launcher_ctx.get_prop(FLContextKey.JOB_LAUNCHER) + if not (job_launcher and isinstance(job_launcher, list)): + raise RuntimeError(f"There's no job launcher can handle this job: {job_meta}.") + + return job_launcher[0]