From d445a09efb7adb21644803cc30f8b18bcd1d1645 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 20 Sep 2023 16:54:08 -0400 Subject: [PATCH] Added ability to send the client job heartbeat calls to server. (#2016) * Added ability to send the client job heartbeat calls to server. * codestyle fixes. * Changed to use aux_message for sending the client job heartbeat calls. * Made the thread daemon=true. --- nvflare/apis/fl_constant.py | 1 + nvflare/private/fed/client/client_runner.py | 23 +++++++++++++++++++++ nvflare/private/fed/server/server_runner.py | 11 ++++++++++ 3 files changed, 35 insertions(+) diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 0bdb60a4c5..f2445ef55a 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -162,6 +162,7 @@ class ReservedTopic(object): END_RUN = "__end_run__" ABORT_ASK = "__abort_task__" AUX_COMMAND = "__aux_command__" + JOB_HEART_BEAT = "__job_heartbeat__" class AdminCommandNames(object): diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index 041f70ae74..15a684d47e 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -22,6 +22,7 @@ from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.apis.utils.fl_context_utils import add_job_audit_event +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.private.defs import SpecialTaskName, TaskConstant from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment from nvflare.private.privacy_manager import Scope @@ -392,6 +393,9 @@ def _check_stop_conditions(self, fl_ctx: FLContext) -> bool: return False def _try_run(self): + heartbeat_thread = threading.Thread(target=self.send_job_heartbeat, args=[], daemon=True) + heartbeat_thread.start() + while not self.asked_to_stop: with self.engine.new_context() as fl_ctx: if self._check_stop_conditions(fl_ctx): @@ -404,6 +408,25 @@ def _try_run(self): time.sleep(task_fetch_interval) + def send_job_heartbeat(self, interval=30.0): + wait_times = int(interval / 2) + request = Shareable() + while not self.asked_to_stop: + with self.engine.new_context() as fl_ctx: + self.engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=ReservedTopic.JOB_HEART_BEAT, + request=request, + timeout=0, + fl_ctx=fl_ctx, + optional=True, + ) + + for i in range(wait_times): + time.sleep(2) + if self.asked_to_stop: + break + def fetch_and_run_one_task(self, fl_ctx) -> (float, bool): """Fetches and runs a task. diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index dd1651c47b..2d17182511 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -99,6 +99,13 @@ def __init__(self, config: ServerRunnerConfig, job_id: str, engine: ServerEngine self.status = "init" self.turn_to_cold = False + self._register_aux_message_handler(engine) + + def _register_aux_message_handler(self, engine): + engine.register_aux_message_handler( + topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat + ) + def _execute_run(self): while self.current_wf_index < len(self.config.workflows): wf = self.config.workflows[self.current_wf_index] @@ -489,6 +496,10 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul "Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)), ) + def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + self.log_info(fl_ctx, "received client job_heartbeat aux request") + return make_reply(ReturnCode.OK) + def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False): self.status = "done" self.abort_signal.trigger(value=True)