Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Feb 10, 2025
1 parent 8b7141a commit 2ad0c1b
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 42 deletions.
7 changes: 4 additions & 3 deletions nvflare/app_common/tie/cli_applet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@


class CLIApplet(Applet, ABC):
def __init__(self):
def __init__(self, stop_method="kill"):
"""Constructor of CLIApplet, which runs the applet as a subprocess started with CLI command."""
Applet.__init__(self)
self.stop_method = stop_method
self._proc_mgr = None
self._start_error = False

Expand Down Expand Up @@ -55,7 +56,7 @@ def start(self, app_ctx: dict):

fl_ctx = app_ctx.get(Constant.APP_CTX_FL_CONTEXT)
try:
self._proc_mgr = start_process(cmd_desc, fl_ctx)
self._proc_mgr = start_process(cmd_desc, fl_ctx, stop_method=self.stop_method)
except Exception as ex:
self.logger.error(f"exception starting applet '{cmd_desc.cmd}': {secure_format_exception(ex)}")
self._start_error = True
Expand Down Expand Up @@ -89,8 +90,8 @@ def stop(self, timeout=0.0) -> int:
return rc
time.sleep(0.1)

self.logger.info(f"about to stop process manager: {type(mgr)}")
rc = mgr.stop()

self.logger.info(f"applet stopped: {rc=}")

if rc is None:
Expand Down
6 changes: 5 additions & 1 deletion nvflare/app_common/tie/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
# configure all clients
if not self._configure_clients(abort_signal, fl_ctx):
self.system_panic("failed to configure all clients", fl_ctx)
abort_signal.trigger(True)
return

# configure and start the connector
Expand All @@ -441,13 +442,15 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
error = f"failed to start connector: {secure_format_exception(ex)}"
self.log_error(fl_ctx, error)
self.system_panic(error, fl_ctx)
abort_signal.trigger(True)
return

self.connector.monitor(fl_ctx, self._app_stopped)

# start all clients
if not self._start_clients(abort_signal, fl_ctx):
self.system_panic("failed to start all clients", fl_ctx)
abort_signal.trigger(True)
return

# monitor client health
Expand All @@ -456,7 +459,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
while not abort_signal.triggered:
done = self._check_job_status(fl_ctx)
if done:
break
self.connector.stop(fl_ctx)
return
time.sleep(self.job_status_check_interval)

def _app_stopped(self, rc, fl_ctx: FLContext):
Expand Down
9 changes: 5 additions & 4 deletions nvflare/app_common/tie/process_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(


class ProcessManager:
def __init__(self, cmd_desc: CommandDescriptor):
def __init__(self, cmd_desc: CommandDescriptor, stop_method="kill"):
"""Constructor of ProcessManager.
ProcessManager provides methods for managing the lifecycle of a subprocess (start, stop, poll), as well
as the handling of log file to be used by the subprocess.
Expand All @@ -96,6 +96,7 @@ def __init__(self, cmd_desc: CommandDescriptor):
check_object_type("cmd_desc", cmd_desc, CommandDescriptor)
self.process = None
self.cmd_desc = cmd_desc
self.stop_method = stop_method
self.log_file = None
self.msg_prefix = None
self.file_lock = threading.Lock()
Expand Down Expand Up @@ -143,7 +144,6 @@ def start(
env=env,
stdout=subprocess.PIPE,
)

log_writer = threading.Thread(target=self._write_log, daemon=True)
log_writer.start()

Expand Down Expand Up @@ -216,17 +216,18 @@ def stop(self) -> int:
return rc


def start_process(cmd_desc: CommandDescriptor, fl_ctx: FLContext) -> ProcessManager:
def start_process(cmd_desc: CommandDescriptor, fl_ctx: FLContext, stop_method="kill") -> ProcessManager:
"""Convenience function for starting a subprocess.
Args:
cmd_desc: the CommandDescriptor the describes the command to be executed
fl_ctx: FLContext object
stop_method: how to stop the process
Returns: a ProcessManager object.
"""
mgr = ProcessManager(cmd_desc)
mgr = ProcessManager(cmd_desc, stop_method)
mgr.start(fl_ctx)
return mgr

Expand Down
111 changes: 77 additions & 34 deletions nvflare/app_opt/flower/applet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
import threading
import time

import tomli
import tomli_w

from nvflare.apis.fl_context import FLContext
from nvflare.apis.workspace import Workspace
from nvflare.app_common.tie.applet import Applet
from nvflare.app_common.tie.cli_applet import CLIApplet
from nvflare.app_common.tie.defs import Constant as TieConstant
from nvflare.app_common.tie.process_mgr import CommandDescriptor, ProcessManager, run_command, start_process, StopMethod
from nvflare.app_common.tie.process_mgr import CommandDescriptor, ProcessManager, run_command, start_process
from nvflare.app_opt.flower.defs import Constant
from nvflare.fuel.utils.grpc_utils import create_channel
from nvflare.security.logging import secure_format_exception
Expand Down Expand Up @@ -72,14 +75,12 @@ def get_command(self, ctx: dict) -> CommandDescriptor:
f"--clientappio-api-address {clientapp_api_addr}"
)

# use app_dir as the cwd for flower's client app.
# this is necessary for client_api to be used with the flower client app for metrics logging
# client_api expects config info from the "config" folder in the cwd!
self.logger.info(f"starting flower client app: {cmd}")
return CommandDescriptor(
cmd=cmd,
cwd=app_dir,
env=self.extra_env,
log_file_name="client_app_log.txt",
stdout_msg_prefix="FLWR-CA",
stop_method=StopMethod.TERMINATE,
cmd=cmd, cwd=app_dir, env=self.extra_env, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA"
)


Expand All @@ -106,7 +107,6 @@ def __init__(
self.last_check_status = None
self.last_check_time = None
self.last_check_stopped = False
self.exec_api_addr = None
self.flower_app_dir = None
self.flower_run_finished = False
self.flower_run_stopped = False # have we issued 'flwr stop'?
Expand All @@ -122,22 +122,77 @@ def _start_process(self, name: str, cmd_desc: CommandDescriptor, fl_ctx: FLConte
self.logger.error(f"exception starting applet: {secure_format_exception(ex)}")
self._start_error = True

def _modify_flower_app_config(self, exec_api_addr: str):
"""Currently the exec-api-address must be specified in pyproject.toml to be able to submit to the
superlink with "flwr run" command.
Args:
exec_api_addr:
Returns:
"""
config_file = os.path.join(self.flower_app_dir, "pyproject.toml")
if not os.path.isfile(config_file):
raise RuntimeError(f"invalid flower app: missing {config_file}")

with open(config_file, mode="rb") as fp:
config = tomli.load(fp)

# add or modify address
tool = config.get("tool")
if not tool:
tool = {}
config["tool"] = tool

flwr = tool.get("flwr")
if not flwr:
flwr = {}
tool["flwr"] = flwr

fed = flwr.get("federations")
if not fed:
fed = {}
flwr["federations"] = fed

default_mode = fed.get("default")
if not default_mode:
default_mode = "local-poc"
fed["default"] = default_mode

mode_config = fed.get(default_mode)
if not mode_config:
mode_config = {}
fed[default_mode] = mode_config

mode_config["address"] = exec_api_addr
mode_config["insecure"] = True

# recreate the app config
with open(config_file, mode="wb") as fp:
tomli_w.dump(config, fp)

def start(self, app_ctx: dict):
"""Start the applet.
We start the superlink, and wait for it to become ready.
We then use "flwr run" command to submit the app to superlink.
Flower requires two processes for server application:
superlink: this process is responsible for client communication
server_app: this process performs server side of training.
We start the superlink first, and wait for it to become ready, then start the server app.
Each process will have its own log file in the job's run dir. The superlink's log file is named
"superlink_log.txt". The server app's log file is named "server_app_log.txt".
Args:
app_ctx: the run context of the applet.
Returns:
"""
# try to start superlink
# try to start superlink first
serverapp_api_addr = app_ctx.get(Constant.APP_CTX_SERVERAPP_API_ADDR)
fleet_api_addr = app_ctx.get(Constant.APP_CTX_FLEET_API_ADDR)
self.exec_api_addr = app_ctx.get(Constant.APP_CTX_EXEC_API_ADDR)
exec_api_addr = app_ctx.get(Constant.APP_CTX_EXEC_API_ADDR)
fl_ctx = app_ctx.get(Constant.APP_CTX_FL_CONTEXT)
if not isinstance(fl_ctx, FLContext):
self.logger.error(f"expect APP_CTX_FL_CONTEXT to be FLContext but got {type(fl_ctx)}")
Expand All @@ -149,14 +204,11 @@ def start(self, app_ctx: dict):
self.logger.error(f"expect workspace to be Workspace but got {type(ws)}")
raise RuntimeError("invalid workspace")

job_id = fl_ctx.get_job_id()
custom_dir = ws.get_app_custom_dir(job_id)
app_dir = ws.get_app_dir(job_id)
if not os.path.isabs(custom_dir):
custom_dir = os.path.relpath(custom_dir, app_dir)

custom_dir = ws.get_app_custom_dir(fl_ctx.get_job_id())
self.flower_app_dir = custom_dir

self._modify_flower_app_config(exec_api_addr)

db_arg = ""
if self.database:
db_arg = f"--database {self.database}"
Expand All @@ -171,16 +223,10 @@ def start(self, app_ctx: dict):
f"flower-superlink --insecure --fleet-api-type grpc-adapter {db_arg} "
f"--serverappio-api-address {serverapp_api_addr} "
f"--fleet-api-address {fleet_api_addr} "
f"--exec-api-address {self.exec_api_addr}"
f"--exec-api-address {exec_api_addr}"
)

cmd_desc = CommandDescriptor(
cmd=superlink_cmd,
cwd=app_dir,
log_file_name="superlink_log.txt",
stdout_msg_prefix="FLWR-SL",
stop_method=StopMethod.TERMINATE,
)
cmd_desc = CommandDescriptor(cmd=superlink_cmd, log_file_name="superlink_log.txt", stdout_msg_prefix="FLWR-SL")

self._superlink_process_mgr = self._start_process(name="superlink", cmd_desc=cmd_desc, fl_ctx=fl_ctx)
if not self._superlink_process_mgr:
Expand All @@ -197,9 +243,9 @@ def start(self, app_ctx: dict):
)
self.logger.info(f"superlink is ready for server app in {time.time() - start_time} seconds")

# submit the app using "flwr run" command
# flwr_run_cmd = f"flwr run --format json -c 'address={exec_api_addr}' {flower_app_dir}"
flwr_run_cmd = f"flwr run {self._flwr_cmd_option()} {self.flower_app_dir}"
# submitting the server app using "flwr run" command
# flwr_run_cmd = f"flwr run --format json -c 'address={exec_api_addr}' {custom_dir}"
flwr_run_cmd = f"flwr run --format json {self.flower_app_dir}"
run_info = self._run_flower_command(flwr_run_cmd)
run_id = run_info.get("run-id")
if not run_id:
Expand All @@ -208,9 +254,6 @@ def start(self, app_ctx: dict):
self.logger.info(f"submitted Flower App and got run id {run_id}")
self.run_id = run_id

def _flwr_cmd_option(self):
return f"--format json --federation-config 'address=\"{self.exec_api_addr}\"'"

def _run_flower_command(self, command: str):
self.logger.info(f"running flower command: {command}")
cmd_desc = CommandDescriptor(cmd=command)
Expand Down Expand Up @@ -265,7 +308,7 @@ def stop(self, timeout=0.0) -> int:
# stop the server app
# we may not be able to issue 'flwr stop' more than once!
self.flower_run_stopped = True
flwr_stop_cmd = f"flwr stop {self._flwr_cmd_option()} {self.run_id} {self.flower_app_dir}"
flwr_stop_cmd = f"flwr stop --format json {self.run_id} {self.flower_app_dir}"
try:
self._run_flower_command(flwr_stop_cmd)
except Exception as ex:
Expand Down Expand Up @@ -293,7 +336,7 @@ def _is_process_stopped(p: ProcessManager):

def _check_flower_run_status(self):
# check whether the app is finished
flwr_ls_cmd = f"flwr ls {self._flwr_cmd_option()} {self.flower_app_dir}"
flwr_ls_cmd = f"flwr ls --format json {self.flower_app_dir}"
try:
run_info = self._run_flower_command(flwr_ls_cmd)
except Exception as ex:
Expand Down
30 changes: 30 additions & 0 deletions nvflare/app_opt/flower/connectors/grpc_client_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

import flwr.proto.grpcadapter_pb2 as pb2
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterServicer

Expand Down Expand Up @@ -49,6 +52,8 @@ def __init__(
self.internal_server_addr = None
self._training_stopped = False
self._client_name = None
self._stopping = False
self._exit_waiter = threading.Event()

def initialize(self, fl_ctx: FLContext):
super().initialize(fl_ctx)
Expand All @@ -64,6 +69,14 @@ def _start_client(self, server_addr: str, fl_ctx: FLContext):

def _stop_client(self):
self._training_stopped = True

# do not stop the applet until should-exit is sent
if not self._exit_waiter.wait(timeout=2.0):
self.logger.warning(f"did not send should-exit before shutting down supernode")

# give 1 sec for the supernode to quite gracefully
self.logger.info("about to stop applet")
time.sleep(1.0)
self.stop_applet(self.client_shutdown_timeout)

def _is_stopped(self) -> (bool, int):
Expand All @@ -74,6 +87,10 @@ def _is_stopped(self) -> (bool, int):
if self._training_stopped:
return True, 0

if self._stopping:
self.stop(fl_ctx=None)
return True, 0

return False, 0

def start(self, fl_ctx: FLContext):
Expand Down Expand Up @@ -127,13 +144,26 @@ def SendReceive(self, request: pb2.MessageContainer, context):
"""
try:
if self.stopped:
self._stopping = True
self._exit_waiter.set()
self.logger.info("asked supernode to exit_1!")
return reply_should_exit()

reply = self._send_flower_request(msg_container_to_shareable(request))
rc = reply.get_return_code()
if rc == ReturnCode.OK:
return shareable_to_msg_container(reply)
else:
# server side already ended
self.logger.warning(f"Flower server has stopped with RC {rc}")
self._stopping = True
self._exit_waiter.set()
self.logger.info("asked supernode to exit_2!")
return reply_should_exit()
except Exception as ex:
self._abort(reason=f"_send_flower_request exception: {secure_format_exception(ex)}")
self._stopping = True
self._exit_waiter.set()
self.logger.info("asked supernode to exit_3!")
return reply_should_exit()

0 comments on commit 2ad0c1b

Please sign in to comment.