From 4bb222c2a5df08166c168a24b7daeccea226c26b Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Fri, 21 Feb 2025 15:20:51 -0500 Subject: [PATCH] Support Edge Executor (#3245) Fixes # . ### Description This PR adds the following features. - A base class (EdgeTaskExecutor) for developing executors that process edge tasks. - The EdgeTaskDispatcher (ETD), which is installed on the CP and is responsible for dispatching edge requests to the appropriate CJ. The ETD determines the right CJ by matching the job's "edge_method" meta property against the device's capabilities. - Some test components for determining active devices from simulated device activities. - A more elegant way of registering and handling events. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --- nvflare/apis/event_type.py | 1 + nvflare/apis/fl_component.py | 30 +++ nvflare/apis/fl_constant.py | 1 + nvflare/apis/job_def.py | 1 + nvflare/app_common/executors/ham.py | 132 +++++++------ nvflare/edge/__init__.py | 13 ++ nvflare/edge/aggregators/__init__.py | 13 ++ nvflare/edge/aggregators/edge_survey_aggr.py | 36 ++++ nvflare/edge/constants.py | 36 ++++ nvflare/edge/controllers/__init__.py | 13 ++ nvflare/edge/controllers/edge_survey_ctl.py | 59 ++++++ nvflare/edge/executors/__init__.py | 13 ++ .../edge/executors/edge_survey_executor.py | 49 +++++ nvflare/edge/executors/ete.py | 152 +++++++++++++++ nvflare/{lighter => edge}/tree_prov.py | 50 +++-- nvflare/edge/widgets/__init__.py | 13 ++ nvflare/edge/widgets/etd.py | 184 ++++++++++++++++++ nvflare/edge/widgets/etg.py | 96 +++++++++ nvflare/fuel/f3/cellnet/defs.py | 1 + nvflare/fuel/f3/cellnet/utils.py | 7 + nvflare/lighter/ctx.py | 10 + nvflare/lighter/provisioner.py | 2 +- nvflare/private/defs.py | 9 +- nvflare/private/event.py | 14 +- nvflare/private/fed/client/client_engine.py | 15 ++ .../fed/client/client_engine_executor_spec.py | 9 + nvflare/private/fed/client/client_executor.py | 31 ++- nvflare/private/fed/client/client_runner.py | 34 ++++ nvflare/private/fed/client/fed_client_base.py | 1 + 29 files changed, 939 insertions(+), 86 deletions(-) create mode 100644 nvflare/edge/__init__.py create mode 100644 nvflare/edge/aggregators/__init__.py create mode 100644 nvflare/edge/aggregators/edge_survey_aggr.py create mode 100644 nvflare/edge/constants.py create mode 100644 nvflare/edge/controllers/__init__.py create mode 100644 nvflare/edge/controllers/edge_survey_ctl.py create mode 100644 nvflare/edge/executors/__init__.py create mode 100644 nvflare/edge/executors/edge_survey_executor.py create mode 100644 nvflare/edge/executors/ete.py rename nvflare/{lighter => edge}/tree_prov.py (88%) create mode 100644 nvflare/edge/widgets/__init__.py create mode 100644 nvflare/edge/widgets/etd.py create mode 100644 nvflare/edge/widgets/etg.py diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 64ace70972..944b202b22 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -90,6 +90,7 @@ class EventType(object): AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" BEFORE_JOB_LAUNCH = "_before_job_launch" + AFTER_JOB_LAUNCH = 
"_after_job_launch" TASK_RESULT_RECEIVED = "_task_result_received" TASK_ASSIGNMENT_SENT = "_task_assignment_sent" diff --git a/nvflare/apis/fl_component.py b/nvflare/apis/fl_component.py index c126bd51ec..cf59e82a46 100644 --- a/nvflare/apis/fl_component.py +++ b/nvflare/apis/fl_component.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union from nvflare.apis.utils.fl_context_utils import generate_log_message from nvflare.fuel.utils.log_utils import get_obj_logger @@ -35,6 +36,7 @@ def __init__(self): """ self._name = self.__class__.__name__ self.logger = get_obj_logger(self) + self._event_handlers = {} @property def name(self): @@ -227,3 +229,31 @@ def _fire_log_event(self, event_type: str, log_tag: str, log_msg: str, fl_ctx: F dxo = event_data.to_dxo() fl_ctx.set_prop(key=FLContextKey.EVENT_DATA, value=dxo.to_shareable(), private=True, sticky=False) self.fire_event(event_type=event_type, fl_ctx=fl_ctx) + + def register_event_handler(self, event_types: Union[str, List[str]], handler, **kwargs): + if isinstance(event_types, str): + event_types = [event_types] + elif not isinstance(event_types, list): + raise ValueError(f"event_types must be string or list of strings but got {type(event_types)}") + + if not callable(handler): + raise ValueError(f"handler {handler.__name__} is not callable") + + for e in event_types: + entries = self._event_handlers.get(e) + if not entries: + entries = [] + self._event_handlers[e] = entries + + already_registered = False + for h, _ in entries: + if handler == h: + # already registered: either by a super class or by the class itself. + already_registered = True + break + + if not already_registered: + entries.append((handler, kwargs)) + + def get_event_handlers(self): + return self._event_handlers diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index c53e3b3c90..89a3c3e127 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -193,6 +193,7 @@ class FLContextKey(object): PROCESS_TYPE = ReservedKey.PROCESS_TYPE JOB_PROCESS_ARGS = ReservedKey.JOB_PROCESS_ARGS EVENT_PROCESSED = "__event_processed__" + CELL_MESSAGE = "__cell_message__" class ProcessType: diff --git a/nvflare/apis/job_def.py b/nvflare/apis/job_def.py index a36720c825..44e781a3d9 100644 --- a/nvflare/apis/job_def.py +++ b/nvflare/apis/job_def.py @@ -74,6 +74,7 @@ class JobMetaKey(str, Enum): STATS_POOL_CONFIG = "stats_pool_config" FROM_HUB_SITE = "from_hub_site" CUSTOM_PROPS = "custom_props" + EDGE_METHOD = "edge_method" def __repr__(self): return self.value diff --git a/nvflare/app_common/executors/ham.py b/nvflare/app_common/executors/ham.py index 7cb63ef042..40b1477e0b 100644 --- a/nvflare/app_common/executors/ham.py +++ b/nvflare/app_common/executors/ham.py @@ -57,67 +57,77 @@ def __init__( self._aggr_lock = threading.Lock() self._process_error = None - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - engine = fl_ctx.get_engine() - - aggr = engine.get_component(self.aggregator_id) - if not isinstance(aggr, Aggregator): - self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggr)}") - self.aggregator = aggr - - learner = engine.get_component(self.learner_id) - if not isinstance(learner, Executor): - self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got 
{type(learner)}") - self.learner = learner - elif event_type == EventType.TASK_ASSIGNMENT_SENT: - # the task was sent to a child client - if not self.pending_task_id: - # I don't have a pending task - return - - child_client_ctx = fl_ctx.get_peer_context() - assert isinstance(child_client_ctx, FLContext) - child_client_name = child_client_ctx.get_identity_name() - self._update_client_status(child_client_name, None) - task_id = fl_ctx.get_prop(FLContextKey.TASK_ID) - - # indicate that this event has been processed by me - fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) - self.log_info(fl_ctx, f"sent task {task_id} to child {child_client_name}") - elif event_type == EventType.TASK_RESULT_RECEIVED: - # received results from a child client - if not self.pending_task_id: - # I don't have a pending task - return - - # indicate that this event has been processed by me - fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) - - result = fl_ctx.get_prop(FLContextKey.TASK_RESULT) - assert isinstance(result, Shareable) - task_id = result.get_header(ReservedKey.TASK_ID) - peer_ctx = fl_ctx.get_peer_context() - assert isinstance(peer_ctx, FLContext) - child_client_name = peer_ctx.get_identity_name() - self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}") - - if task_id != self.pending_task_id: - self.log_warning( - fl_ctx, - f"dropped the received result from child {child_client_name} " - f"for task {task_id} while waiting for task {self.pending_task_id}", - ) - return - - rc = result.get_return_code(ReturnCode.OK) - if rc == ReturnCode.OK: - self.log_info(fl_ctx, f"accepting result from client {child_client_name}") - self._do_aggregation(result, fl_ctx) - else: - self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}") - self.log_info(fl_ctx, f"received result from child {child_client_name}") - self._update_client_status(child_client_name, time.time()) + self.register_event_handler(EventType.START_RUN, self._handle_start_run) + self.register_event_handler(EventType.TASK_ASSIGNMENT_SENT, self._handle_task_sent) + self.register_event_handler(EventType.TASK_RESULT_RECEIVED, self._handle_result_received) + + def _handle_start_run(self, event_type: str, fl_ctx: FLContext): + self.log_debug(fl_ctx, f"handling event {event_type}") + engine = fl_ctx.get_engine() + + aggr = engine.get_component(self.aggregator_id) + if not isinstance(aggr, Aggregator): + self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggr)}") + self.aggregator = aggr + + learner = engine.get_component(self.learner_id) + if not isinstance(learner, Executor): + self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got {type(learner)}") + self.learner = learner + + def _handle_task_sent(self, event_type: str, fl_ctx: FLContext): + # the task was sent to a child client + self.log_debug(fl_ctx, f"handling event {event_type}") + + if not self.pending_task_id: + # I don't have a pending task + return + + child_client_ctx = fl_ctx.get_peer_context() + assert isinstance(child_client_ctx, FLContext) + child_client_name = child_client_ctx.get_identity_name() + self._update_client_status(child_client_name, None) + task_id = fl_ctx.get_prop(FLContextKey.TASK_ID) + + # indicate that this event has been processed by me + fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) + self.log_info(fl_ctx, f"sent task {task_id} to child 
{child_client_name}") + + def _handle_result_received(self, event_type: str, fl_ctx: FLContext): + # received results from a child client + self.log_debug(fl_ctx, f"handling event {event_type}") + + if not self.pending_task_id: + # I don't have a pending task + return + + # indicate that this event has been processed by me + fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) + + result = fl_ctx.get_prop(FLContextKey.TASK_RESULT) + assert isinstance(result, Shareable) + task_id = result.get_header(ReservedKey.TASK_ID) + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + child_client_name = peer_ctx.get_identity_name() + self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}") + + if task_id != self.pending_task_id: + self.log_warning( + fl_ctx, + f"dropped the received result from child {child_client_name} " + f"for task {task_id} while waiting for task {self.pending_task_id}", + ) + return + + rc = result.get_return_code(ReturnCode.OK) + if rc == ReturnCode.OK: + self.log_info(fl_ctx, f"accepting result from client {child_client_name}") + self._do_aggregation(result, fl_ctx) + else: + self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}") + self.log_info(fl_ctx, f"received result from child {child_client_name}") + self._update_client_status(child_client_name, time.time()) def _do_aggregation(self, result: Shareable, fl_ctx: FLContext): with self._aggr_lock: diff --git a/nvflare/edge/__init__.py b/nvflare/edge/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/aggregators/__init__.py b/nvflare/edge/aggregators/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/aggregators/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/aggregators/edge_survey_aggr.py b/nvflare/edge/aggregators/edge_survey_aggr.py new file mode 100644 index 0000000000..e86cb2857c --- /dev/null +++ b/nvflare/edge/aggregators/edge_survey_aggr.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.aggregator import Aggregator + + +class EdgeSurveyAggregator(Aggregator): + def __init__(self): + Aggregator.__init__(self) + self.num_devices = 0 + + def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool: + self.log_info(fl_ctx, f"accepting: {shareable}") + num_devices = shareable.get("num_devices") + if num_devices: + self.num_devices += num_devices + return True + + def reset(self, fl_ctx: FLContext): + self.num_devices = 0 + + def aggregate(self, fl_ctx: FLContext) -> Shareable: + self.log_info(fl_ctx, f"aggregating final result: {self.num_devices}") + return Shareable({"num_devices": self.num_devices}) diff --git a/nvflare/edge/constants.py b/nvflare/edge/constants.py new file mode 100644 index 0000000000..6105c865f6 --- /dev/null +++ b/nvflare/edge/constants.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellReturnCode + + +class Status(CellReturnCode): + NO_TASK = "no_task" + NO_JOB = "no_job" + + +class EdgeProtoKey: + STATUS = "status" + DATA = "data" + + +class EdgeContextKey: + JOB_ID = "__edge_job_id__" + EDGE_CAPABILITIES = "__edge_capabilities__" + REQUEST_FROM_EDGE = "__request_from_edge__" + REPLY_TO_EDGE = "__reply_to_edge__" + + +class EventType: + EDGE_REQUEST_RECEIVED = "_edge_request_received" + EDGE_JOB_REQUEST_RECEIVED = "_edge_job_request_received" diff --git a/nvflare/edge/controllers/__init__.py b/nvflare/edge/controllers/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/controllers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
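The refactored `ham.py` above showcases the new `FLComponent.register_event_handler()` API added in `fl_component.py`: instead of branching on `event_type` inside `handle_event()`, a component binds a callback per event type, and the updated dispatch logic in `nvflare/private/event.py` invokes the registered callbacks and falls back to `handle_event()` for events that have no registered callback. A minimal sketch of the pattern is shown below; the component and handler names are made up for illustration.

```python
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext


class RunMonitor(FLComponent):
    """Hypothetical component illustrating register_event_handler()."""

    def __init__(self):
        FLComponent.__init__(self)
        # bind one callback to a single event type ...
        self.register_event_handler(EventType.START_RUN, self._on_start_run)
        # ... and one callback to several event types; extra kwargs given here
        # are passed back to the callback on every dispatch.
        self.register_event_handler(
            [EventType.ABOUT_TO_END_RUN, EventType.END_RUN],
            self._on_run_ending,
            note="run is ending",
        )

    def _on_start_run(self, event_type: str, fl_ctx: FLContext):
        self.log_info(fl_ctx, f"got event {event_type}")

    def _on_run_ending(self, event_type: str, fl_ctx: FLContext, note: str):
        self.log_info(fl_ctx, f"got event {event_type}: {note}")

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        # still invoked for any event that has no registered callback
        pass
```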
diff --git a/nvflare/edge/controllers/edge_survey_ctl.py b/nvflare/edge/controllers/edge_survey_ctl.py new file mode 100644 index 0000000000..1103e1ebca --- /dev/null +++ b/nvflare/edge/controllers/edge_survey_ctl.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.controller_spec import ClientTask, Task +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + + +class EdgeSurveyController(Controller): + def __init__(self, num_rounds: int, timeout: int): + Controller.__init__(self) + self.num_rounds = num_rounds + self.timeout = timeout + + def start_controller(self, fl_ctx: FLContext): + pass + + def stop_controller(self, fl_ctx: FLContext): + pass + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + for r in range(self.num_rounds): + task = Task( + name="survey", + data=Shareable(), + timeout=self.timeout, + ) + + self.broadcast_and_wait( + task=task, + min_responses=2, + wait_time_after_min_received=0, + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + total_devices = 0 + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + result = ct.result + assert isinstance(result, Shareable) + self.log_info(fl_ctx, f"result from client {ct.client.name}: {result}") + count = result.get("num_devices") + if count: + total_devices += count + + self.log_info(fl_ctx, f"total devices in round {r}: {total_devices}") diff --git a/nvflare/edge/executors/__init__.py b/nvflare/edge/executors/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/executors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/executors/edge_survey_executor.py b/nvflare/edge/executors/edge_survey_executor.py new file mode 100644 index 0000000000..d3c42f2dca --- /dev/null +++ b/nvflare/edge/executors/edge_survey_executor.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from typing import Any + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.edge.executors.ete import EdgeTaskExecutor + + +class EdgeSurveyExecutor(EdgeTaskExecutor): + """This executor is for test purposes only. It is to be used as the "learner" for the + HierarchicalAggregationManager. + """ + + def __init__(self, timeout=10.0): + EdgeTaskExecutor.__init__(self) + self.timeout = timeout + self.num_devices = 0 + self.start_time = None + + def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext): + self.num_devices = 0 + self.start_time = time.time() + + def is_task_done(self, fl_ctx: FLContext) -> bool: + return time.time() - self.start_time > self.timeout + + def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any: + assert isinstance(request, dict) + self.log_info(fl_ctx, f"received edge request: {request}") + self.num_devices += 1 + return {"status": "tryAgain", "comment": f"received {request}"} + + def get_task_result(self, fl_ctx: FLContext) -> Shareable: + result = make_reply(ReturnCode.OK) + result["num_devices"] = self.num_devices + return result diff --git a/nvflare/edge/executors/ete.py b/nvflare/edge/executors/ete.py new file mode 100644 index 0000000000..be0b84be3c --- /dev/null +++ b/nvflare/edge/executors/ete.py @@ -0,0 +1,152 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from abc import abstractmethod +from typing import Any + +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.security.logging import secure_format_exception + + +class EdgeTaskExecutor(Executor): + """This is the base class for executors for handling requests from edge devices. + Subclasses must implement the required abstract methods defined here. + """ + + def __init__(self): + """Constructor of EdgeTaskExecutor""" + Executor.__init__(self) + self.current_task = None + + self.register_event_handler(EdgeEventType.EDGE_REQUEST_RECEIVED, self._handle_edge_request) + + @abstractmethod + def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any: + """This is called to process an edge request sent from the edge device. 
+ + Args: + request: the request from edge device + fl_ctx: FLContext object + + Returns: reply to the edge device + + """ + pass + + def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext): + """This method is called when a task assignment is received from the controller. + Subclass can implement this method to prepare for task processing. + + Args: + task_name: name of the task + task_data: task data + fl_ctx: FLContext object + + Returns: None + + """ + pass + + @abstractmethod + def is_task_done(self, fl_ctx: FLContext) -> bool: + """This is called by the base class to determine whether the task processing is done. + Subclass must implement this method. + + Args: + fl_ctx: FLContext object + + Returns: whether task is done. + + """ + pass + + @abstractmethod + def get_task_result(self, fl_ctx: FLContext) -> Shareable: + """This is called by the base class to get the final result of the task. + Base class will send the result to the controller. + + Args: + fl_ctx: FLContext object + + Returns: a Shareable object that is the task result + + """ + pass + + def _handle_edge_request(self, event_type: str, fl_ctx: FLContext): + if not self.current_task: + self.logger.debug(f"received edge event {event_type} but I don't have pending task") + return + + try: + msg = fl_ctx.get_prop(FLContextKey.CELL_MESSAGE) + assert isinstance(msg, CellMessage) + self.log_debug(fl_ctx, f"received edge request: {msg.payload}") + reply = self.process_edge_request(request=msg.payload, fl_ctx=fl_ctx) + fl_ctx.set_prop(FLContextKey.TASK_RESULT, reply, private=True, sticky=False) + except Exception as ex: + self.logger.error(f"exception from self.process_edge_request: {secure_format_exception(ex)}") + fl_ctx.set_prop(FLContextKey.EXCEPTIONS, ex, private=True, sticky=False) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.current_task = shareable + result = self._execute(task_name, shareable, fl_ctx, abort_signal) + self.current_task = None + return result + + def _execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + try: + self.task_received(task_name, shareable, fl_ctx) + except Exception as ex: + self.log_error(fl_ctx, f"exception from self.task_received: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + start_time = time.time() + while True: + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + try: + task_done = self.is_task_done(fl_ctx) + except Exception as ex: + self.log_error(fl_ctx, f"exception from self.is_task_done: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + if task_done: + break + + time.sleep(0.2) + + self.log_debug(fl_ctx, f"task done after {time.time() - start_time} seconds") + try: + result = self.get_task_result(fl_ctx) + + if not isinstance(result, Shareable): + self.log_error( + fl_ctx, + f"bad result from self.get_task_result: expect Shareable but got {type(result)}", + ) + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + except Exception as ex: + self.log_error(fl_ctx, f"exception from self.get_task_result: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + return result diff --git a/nvflare/lighter/tree_prov.py b/nvflare/edge/tree_prov.py similarity index 88% rename from nvflare/lighter/tree_prov.py rename to nvflare/edge/tree_prov.py index 15bba28aba..5b60f8ced7 100644 --- 
a/nvflare/lighter/tree_prov.py +++ b/nvflare/edge/tree_prov.py @@ -18,14 +18,15 @@ """ import argparse +import json +import os.path +from nvflare.lighter.entity import Participant, ParticipantType, Project from nvflare.lighter.impl.cert import CertBuilder from nvflare.lighter.impl.signature import SignatureBuilder from nvflare.lighter.impl.static_file import StaticFileBuilder from nvflare.lighter.impl.workspace import WorkspaceBuilder - -from .entity import Participant, ParticipantType, Project -from .provisioner import Provisioner +from nvflare.lighter.provisioner import Provisioner def _new_participant(name: str, ptype: str, props: dict) -> Participant: @@ -46,20 +47,33 @@ class Stats: num_non_leaf_clients = 0 -class _Node: - +class PortManager: last_port_number = 9000 + @classmethod + def get_port(cls): + cls.last_port_number += 1 + return cls.last_port_number + + +class _Node: def __init__(self): self.name = None self.client_name = None self.parent = None self.children = [] - self.port = _Node.last_port_number - _Node.last_port_number += 1 - - -def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clients: int, project: Project): + self.port = PortManager.get_port() + + +def _build_tree( + depth: int, + width: int, + max_depth: int, + parent: _Node, + num_clients: int, + project: Project, + lcp_map: dict, +): """Build relay hierarchy and client hierarchy, recursively. Relays are organized hierarchically. Attach a client to each relay. Such clients are non-leaf clients (a.k.a @@ -82,7 +96,7 @@ def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clien """ if depth == max_depth: - # the parent is a leaf node - add leaf clients + # the parent is a leaf node - add leaf clients (LCPs) Stats.num_leaf_relays += 1 for i in range(num_clients): name = _make_client_name(parent.name) + str(i + 1) @@ -92,6 +106,8 @@ def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clien project.add_participant(client) Stats.num_clients += 1 Stats.num_leaf_clients += 1 + + lcp_map[name] = {"host": "localhost", "port": PortManager.get_port()} return if depth > 0: @@ -128,7 +144,7 @@ def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clien Stats.num_non_leaf_clients += 1 # depth-first recursion - _build_tree(depth + 1, width, max_depth, child, num_clients, project) + _build_tree(depth + 1, width, max_depth, child, num_clients, project, lcp_map) def main(): @@ -206,7 +222,8 @@ def main(): # add relays and clients root_relay = _Node() root_relay.name = "R" - _build_tree(0, args.width, args.depth, root_relay, args.clients, project) + lcp_map = {} + _build_tree(0, args.width, args.depth, root_relay, args.clients, project, lcp_map) total_sites = Stats.num_clients + Stats.num_relays + 1 @@ -226,7 +243,12 @@ def main(): "admin@nvidia.com", ParticipantType.ADMIN, props={"role": "project_admin", "connect_to": "localhost"} ) project.add_participant(admin) - provisioner.provision(project) + ctx = provisioner.provision(project) + location = ctx.get_result_location() + lcp_map_file_name = os.path.join(location, "lcp_map.json") + with open(lcp_map_file_name, "wt") as f: + json.dump(lcp_map, f, indent=4) + print(f"Generated LCP Map: {lcp_map_file_name}") if __name__ == "__main__": diff --git a/nvflare/edge/widgets/__init__.py b/nvflare/edge/widgets/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/widgets/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/widgets/etd.py b/nvflare/edge/widgets/etd.py new file mode 100644 index 0000000000..c1aa1616a9 --- /dev/null +++ b/nvflare/edge/widgets/etd.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +from random import randrange + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey +from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.edge.constants import Status as EdgeStatus +from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey +from nvflare.fuel.f3.cellnet.utils import new_cell_message +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.widgets.widget import Widget + + +class EdgeTaskDispatcher(Widget): + """Edge Task Dispatcher (ETD) is to be used to dispatch a received edge request to a running job (CJ). + ETD must be installed on CP before the CP is started. 
+ """ + + def __init__(self, request_timeout: float = 2.0): + Widget.__init__(self) + self.request_timeout = request_timeout + self.edge_jobs = {} # edge_method => list of job_ids + self.lock = threading.Lock() + + self.register_event_handler( + EventType.AFTER_JOB_LAUNCH, + self._handle_job_launched, + ) + self.register_event_handler( + [EventType.JOB_COMPLETED, EventType.JOB_CANCELLED, EventType.JOB_ABORTED], + self._handle_job_done, + ) + self.register_event_handler( + EdgeEventType.EDGE_JOB_REQUEST_RECEIVED, + self._handle_edge_job_request, + ) + self.register_event_handler( + EdgeEventType.EDGE_REQUEST_RECEIVED, + self._handle_edge_request, + ) + self.logger.info("EdgeTaskDispatcher created!") + + def _add_job(self, job_meta: dict): + with self.lock: + edge_method = job_meta.get(JobMetaKey.EDGE_METHOD) + if not edge_method: + # this is not an edge job + return + + jobs = self.edge_jobs.get(edge_method) + if not jobs: + jobs = [] + self.edge_jobs[edge_method] = jobs + + job_id = job_meta.get(JobMetaKey.JOB_ID) + jobs.append(job_id) + + def _remove_job(self, job_meta: dict): + with self.lock: + job_id = job_meta.get(JobMetaKey.JOB_ID) + edge_method = job_meta.get(JobMetaKey.EDGE_METHOD) + if not edge_method: + # this is not an edge job + self.logger.info(f"no edge_method in job {job_id}") + return + + jobs = self.edge_jobs.get(edge_method) + if not jobs: + self.logger.info("no edge jobs pending") + return + + assert isinstance(jobs, list) + job_id = job_meta.get(JobMetaKey.JOB_ID) + self.logger.info(f"current jobs for {edge_method}: {jobs}") + if job_id in jobs: + jobs.remove(job_id) + if not jobs: + # no more jobs for this edge method + self.edge_jobs.pop(edge_method) + + def _match_job(self, caps: list): + with self.lock: + for edge_method, jobs in self.edge_jobs.items(): + if edge_method in caps: + # pick one randomly + i = randrange(len(jobs)) + return jobs[i] + + # no job matched + return None + + def _find_job(self, job_id: str): + with self.lock: + for jobs in self.edge_jobs.values(): + if job_id in jobs: + return True + return False + + def _handle_job_launched(self, event_type: str, fl_ctx: FLContext): + self.logger.info(f"handling event {event_type}") + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + if not job_meta: + self.logger.error(f"missing {FLContextKey.JOB_META} from fl_ctx for event {event_type}") + else: + self._add_job(job_meta) + + def _handle_job_done(self, event_type: str, fl_ctx: FLContext): + self.logger.info(f"handling event {event_type}") + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + if not job_meta: + self.logger.error(f"missing {FLContextKey.JOB_META} from fl_ctx for event {event_type}") + else: + self._remove_job(job_meta) + + def _handle_edge_job_request(self, event_type: str, fl_ctx: FLContext): + self.logger.info(f"handling event {event_type}") + edge_capabilities = fl_ctx.get_prop(EdgeContextKey.EDGE_CAPABILITIES) + if not edge_capabilities: + self.logger.error(f"missing {EdgeContextKey.EDGE_CAPABILITIES} from fl_ctx for event {event_type}") + self._set_edge_reply(EdgeStatus.INVALID_REQUEST, None, fl_ctx) + return + + # find job for the caps + job_id = self._match_job(edge_capabilities) + if job_id: + status = EdgeStatus.OK + else: + status = EdgeStatus.NO_JOB + self._set_edge_reply(status, job_id, fl_ctx) + + @staticmethod + def _set_edge_reply(status, data, fl_ctx: FLContext): + fl_ctx.set_prop( + key=EdgeContextKey.REPLY_TO_EDGE, + value={EdgeProtoKey.STATUS: status, EdgeProtoKey.DATA: data}, + private=True, + sticky=False, + ) + + def 
_handle_edge_request(self, event_type: str, fl_ctx: FLContext): + # try to find the job + job_id = fl_ctx.get_prop(EdgeContextKey.JOB_ID) + if not job_id: + self.logger.error(f"handling event {event_type}: missing {EdgeContextKey.JOB_ID} from fl_ctx") + self._set_edge_reply(EdgeStatus.INVALID_REQUEST, None, fl_ctx) + return + + if not self._find_job(job_id): + self._set_edge_reply(EdgeStatus.NO_JOB, None, fl_ctx) + return + + # send edge request data to CJ + edge_req_data = fl_ctx.get_prop(EdgeContextKey.REQUEST_FROM_EDGE) + self.logger.info(f"Sending edge request to CJ {job_id}: {edge_req_data}") + engine = fl_ctx.get_engine() + reply = engine.send_to_job( + job_id=job_id, + channel=CellChannel.EDGE_REQUEST, + topic="request", + msg=new_cell_message({}, edge_req_data), + timeout=self.request_timeout, + ) + + assert isinstance(reply, CellMessage) + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + reply_data = reply.payload + self.logger.debug(f"got edge result from CJ: {rc=} {reply_data=}") + self._set_edge_reply(rc, reply_data, fl_ctx) diff --git a/nvflare/edge/widgets/etg.py b/nvflare/edge/widgets/etg.py new file mode 100644 index 0000000000..f65981e578 --- /dev/null +++ b/nvflare/edge/widgets/etg.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
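The dispatcher above only tracks jobs whose meta carries the new `edge_method` property (`JobMetaKey.EDGE_METHOD`); `_match_job()` compares that value against the capabilities reported by a device when it asks for a job. A hedged sketch of what an edge job's meta (e.g. in `meta.json`) might look like follows; the job name and values are purely illustrative.

```python
# Illustrative edge-job meta: "edge_method" must match one of the capabilities
# reported by the requesting devices (the test generator below advertises
# ["xgb", "llm"]); if no running job matches, the ETD replies with NO_JOB.
job_meta = {
    "name": "edge_survey",
    "edge_method": "xgb",  # matched by EdgeTaskDispatcher._match_job()
    "min_clients": 1,
    "deploy_map": {"app": ["@ALL"]},
}
```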
+ +"""Edge Task Generator - for test only +Randomly generate edge tasks +""" +import threading +import time +import uuid + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.apis.signal import Signal +from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.edge.constants import Status +from nvflare.widgets.widget import Widget + + +class EdgeTaskGenerator(Widget): + def __init__(self): + Widget.__init__(self) + self.generator = None + self.engine = None + self.job_id = None + self.abort_signal = Signal() + self.logger.info("EdgeTaskGenerator created!") + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.SYSTEM_START: + # start the generator + self.logger.info("Starting generator ...") + self.engine = fl_ctx.get_engine() + self.generator = threading.Thread(target=self._generate_tasks, daemon=True) + self.generator.start() + elif event_type == EventType.SYSTEM_END: + self.abort_signal.trigger(True) + + @staticmethod + def _make_task(): + return { + "device_id": str(uuid.uuid4()), + "request_type": "getTask", + } + + def _generate_tasks(self): + caps = ["xgb", "llm"] + while True: + if self.abort_signal.triggered: + self.logger.info("received abort signal - exiting") + return + + with self.engine.new_context() as fl_ctx: + assert isinstance(fl_ctx, FLContext) + if not self.job_id: + fl_ctx.set_prop(EdgeContextKey.EDGE_CAPABILITIES, caps, private=True, sticky=False) + self.fire_event(EdgeEventType.EDGE_JOB_REQUEST_RECEIVED, fl_ctx) + result = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) + if result: + assert isinstance(result, dict) + status = result[EdgeProtoKey.STATUS] + job_id = result[EdgeProtoKey.DATA] + self.logger.info(f"job reply from ETD: {status=} {job_id=}") + if job_id: + self.job_id = job_id + else: + self.logger.error(f"no result from ETD for event {EdgeEventType.EDGE_JOB_REQUEST_RECEIVED}") + else: + task = self._make_task() + fl_ctx.set_prop(EdgeContextKey.JOB_ID, self.job_id, sticky=False, private=True) + fl_ctx.set_prop(EdgeContextKey.REQUEST_FROM_EDGE, task, sticky=False, private=True) + self.fire_event(EdgeEventType.EDGE_REQUEST_RECEIVED, fl_ctx) + result = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) + if not result: + self.logger.error(f"no result from ETD for event {EdgeEventType.EDGE_REQUEST_RECEIVED}") + else: + status = result[EdgeProtoKey.STATUS] + edge_reply = result[EdgeProtoKey.DATA] + self.logger.info(f"task reply from ETD: {status=} {edge_reply=}") + + if status == Status.NO_JOB: + # job already finished + self.job_id = None + + time.sleep(1.0) diff --git a/nvflare/fuel/f3/cellnet/defs.py b/nvflare/fuel/f3/cellnet/defs.py index 3c6aab6693..d54214b07b 100644 --- a/nvflare/fuel/f3/cellnet/defs.py +++ b/nvflare/fuel/f3/cellnet/defs.py @@ -157,6 +157,7 @@ class CellChannel: MULTI_PROCESS_EXECUTOR = "multi_process_executor" SIMULATOR_RUNNER = "simulator_runner" RETURN_ONLY = "return_only" + EDGE_REQUEST = "edge_request" class CellChannelTopic: diff --git a/nvflare/fuel/f3/cellnet/utils.py b/nvflare/fuel/f3/cellnet/utils.py index 43d2ec16fe..a01a567190 100644 --- a/nvflare/fuel/f3/cellnet/utils.py +++ b/nvflare/fuel/f3/cellnet/utils.py @@ -35,6 +35,13 @@ } +def new_cell_message(headers: dict, payload=None): + msg_headers = {} + if headers: + msg_headers.update(headers) + return Message(msg_headers, payload) + + def make_reply(rc: str, error: str = "", body=None) -> Message: headers = 
{MessageHeaderKey.RETURN_CODE: rc} if error: diff --git a/nvflare/lighter/ctx.py b/nvflare/lighter/ctx.py index 0c06bf234a..6576f0b243 100644 --- a/nvflare/lighter/ctx.py +++ b/nvflare/lighter/ctx.py @@ -13,6 +13,7 @@ # limitations under the License. import json import os +from typing import Optional import yaml @@ -178,3 +179,12 @@ def warning(self, msg: str): logger.warning(msg) else: print(f"WARNING: {msg}") + + def get_result_location(self) -> Optional[str]: + """Get the directory of the provision result. + This should be called after the provision is done. + + Returns: the name of the directory that holds the provisioned result. + + """ + return self.get(CtxKey.CURRENT_PROD_DIR) diff --git a/nvflare/lighter/provisioner.py b/nvflare/lighter/provisioner.py index 939afc012e..94894b54fe 100644 --- a/nvflare/lighter/provisioner.py +++ b/nvflare/lighter/provisioner.py @@ -62,7 +62,7 @@ def _check_method(logger, method_name: str): elif not callable(getattr(logger, method_name)): raise ValueError(f"invalid logger {type(logger)}: method '{method_name}' is not callable") - def provision(self, project: Project, mode=None, logger=None): + def provision(self, project: Project, mode=None, logger=None) -> ProvisionContext: """Provision a specified project. Args: diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index 3a2af90122..093e0185c6 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -19,7 +19,7 @@ # this import is to let existing scripts import from nvflare.private.defs from nvflare.fuel.f3.cellnet.defs import CellChannel, CellChannelTopic, SSLConstants # noqa: F401 -from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.cellnet.utils import new_cell_message # noqa: F401 from nvflare.fuel.hci.server.constants import ConnProps @@ -181,10 +181,3 @@ def __init__(self, client_name: str): self.client_name = client_name self.nonce = str(uuid.uuid4()) self.reg_start_time = time.time() - - -def new_cell_message(headers: dict, payload=None): - msg_headers = {} - if headers: - msg_headers.update(headers) - return Message(msg_headers, payload) diff --git a/nvflare/private/event.py b/nvflare/private/event.py index faf04442c0..c9f23ec1d1 100644 --- a/nvflare/private/event.py +++ b/nvflare/private/event.py @@ -58,7 +58,19 @@ def fire_event(event: str, handlers: list, ctx: FLContext): ctx.set_prop(key=FLContextKey.EVENT_DATA, value=event_data, private=True, sticky=False) ctx.set_prop(key=FLContextKey.EVENT_ORIGIN, value=event_origin, private=True, sticky=False) ctx.set_prop(key=FLContextKey.EVENT_SCOPE, value=event_scope, private=True, sticky=False) - h.handle_event(event, ctx) + + event_table = h.get_event_handlers() + if event_table: + entries = event_table.get(event) + if entries: + for cb, kwargs in entries: + cb(event, ctx, **kwargs) + else: + # no CB explicitly for this event - call the default handler. + h.handle_event(event, ctx) + else: + # no explicitly defined CBs - call the default handler. + h.handle_event(event, ctx) except Exception as e: h.log_exception( ctx, f'Exception when handling event "{event}": {secure_format_exception(e)}', fire_event=False diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 3543fb16cf..48170ac824 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -401,6 +401,21 @@ def abort_app(self, job_id: str) -> str: return "Abort signal has been sent to the client App." 
+ def send_to_job(self, job_id, channel: str, topic: str, msg: CellMessage, timeout: float) -> CellMessage: + """Send a message to CJ + + Args: + job_id: id of the job + channel: message channel + topic: message topic + msg: the message to be sent + timeout: how long to wait for reply + + Returns: reply from CJ + + """ + return self.client_executor.send_to_job(job_id, channel, topic, msg, timeout) + def abort_task(self, job_id: str) -> str: status = self.client_executor.get_status(job_id) if status == ClientStatus.NOT_STARTED: diff --git a/nvflare/private/fed/client/client_engine_executor_spec.py b/nvflare/private/fed/client/client_engine_executor_spec.py index 062955d505..3d3ce10327 100644 --- a/nvflare/private/fed/client/client_engine_executor_spec.py +++ b/nvflare/private/fed/client/client_engine_executor_spec.py @@ -180,3 +180,12 @@ def abort_app(self, job_id: str, fl_ctx: FLContext): """ pass + + @abstractmethod + def get_cell(self): + """Get communication cell + + Returns: + + """ + pass diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index c1c178d20b..cd8d0c104d 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -25,6 +25,7 @@ from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message @@ -210,7 +211,11 @@ def start_app( fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) job_handle = job_launcher.launch_job(job_meta, fl_ctx) - self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") + self.logger.info(f"Launched job {job_id} with job launcher: {type(job_launcher)} ") + + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + engine = fl_ctx.get_engine() + engine.fire_event(EventType.AFTER_JOB_LAUNCH, fl_ctx) client.multi_gpu = False @@ -419,6 +424,29 @@ def abort_app(self, job_id): self.logger.info("Client worker process is terminated.") + def send_to_job(self, job_id, channel: str, topic: str, msg: CellMessage, timeout: float) -> CellMessage: + """Send a message to CJ + + Args: + job_id: id of the job + channel: message channel + topic: message topic + msg: the message to be sent + timeout: how long to wait for reply + + Returns: reply from CJ + + """ + # send any serializable data to the job cell + return self.client.cell.send_request( + target=self._job_fqcn(job_id), + channel=channel, + topic=topic, + request=msg, + optional=False, + timeout=timeout, + ) + def _terminate_job(self, job_handle, job_id): max_wait = 10.0 done = False @@ -500,6 +528,7 @@ def _wait_child_process_finish( fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.CLIENT_NAME, client.client_name, private=True, sticky=False) engine.fire_event(EventType.JOB_COMPLETED, fl_ctx) + self.logger.info(f"Fired event JOB_COMPLETED {EventType.JOB_COMPLETED}") def get_status(self, job_id): process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED) diff --git a/nvflare/private/fed/client/client_runner.py 
b/nvflare/private/fed/client/client_runner.py index dbe4c82b12..9ee233744b 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -26,7 +26,12 @@ from nvflare.apis.utils.fl_context_utils import add_job_audit_event from nvflare.apis.utils.reliable_message import ReliableMessage from nvflare.apis.utils.task_utils import apply_filters +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.edge.constants import Status as EdgeStatus +from nvflare.fuel.f3.cellnet.defs import CellChannel from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.cellnet.utils import make_reply as make_cell_reply +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.private.defs import SpecialTaskName, TaskConstant from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment from nvflare.private.fed.tbi import TBI @@ -157,6 +162,35 @@ def __init__( self.submit_task_result_timeout = self.get_positive_float_var(ConfigVarName.SUBMIT_TASK_RESULT_TIMEOUT, None) self._register_aux_message_handlers(engine) + def set_cell(self, cell): + cell.register_request_cb( + channel=CellChannel.EDGE_REQUEST, + topic="*", + cb=self._receive_edge_request, + ) + + def _receive_edge_request(self, request: CellMessage): + with self.engine.new_context() as fl_ctx: + assert isinstance(fl_ctx, FLContext) + try: + # place the cell message into fl_ctx in case it's needed by process_edge_request. + fl_ctx.set_prop(FLContextKey.CELL_MESSAGE, request, private=True, sticky=False) + self.engine.fire_event(EdgeEventType.EDGE_REQUEST_RECEIVED, fl_ctx) + exception = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + if exception: + return make_cell_reply(EdgeStatus.PROCESS_EXCEPTION) + + reply = fl_ctx.get_prop(FLContextKey.TASK_RESULT) + if not reply: + self.logger.debug("no result for edge request") + return make_cell_reply(EdgeStatus.NO_TASK) + else: + self.logger.info(f"sending back edge result: {reply}") + return make_cell_reply(EdgeStatus.OK, body=reply) + except Exception as ex: + self.log_error(fl_ctx, f"exception from receive_edge_request: {secure_format_exception(ex)}") + return make_cell_reply(EdgeStatus.PROCESS_EXCEPTION) + def find_executor(self, task_name): return self.task_router.route(task_name) diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 08691e81f6..1c9e07cc3e 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -262,6 +262,7 @@ def _create_cell(self, location, scheme): time.sleep(self.cell_check_frequency) self.logger.info(f"Got client_runner after {time.time() - start} seconds") self.client_runner.engine.cell = self.cell + self.client_runner.set_cell(self.cell) else: start = time.time() self.logger.info("Wait for engine to be created.")
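Taken together, these pieces define a simple contract for CJ-side edge executors: the controller's task arrives through `execute()`, `task_received()` lets the subclass initialize per-task state, each device request relayed by the CP's EdgeTaskDispatcher over the `edge_request` channel is turned into a `process_edge_request()` call, `is_task_done()` is polled until it returns True, and `get_task_result()` produces the Shareable that goes back to the parent (typically aggregated by the HierarchicalAggregationManager). A minimal skeleton of a custom subclass is sketched below; the class name, payload keys, and collection window are hypothetical.

```python
import time
from typing import Any

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.edge.executors.ete import EdgeTaskExecutor


class CollectingEdgeExecutor(EdgeTaskExecutor):
    """Hypothetical edge executor that collects one update per device request."""

    def __init__(self, collect_time: float = 30.0):
        EdgeTaskExecutor.__init__(self)
        self.collect_time = collect_time
        self.updates = []
        self.start_time = None

    def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext):
        # called once per task assignment, before device requests are processed
        self.updates = []
        self.start_time = time.time()

    def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any:
        # called for every device request routed to this CJ while the task is pending;
        # the request/reply payload layout is up to the application.
        if isinstance(request, dict) and "update" in request:
            self.updates.append(request["update"])
        return {"status": "OK", "received": len(self.updates)}

    def is_task_done(self, fl_ctx: FLContext) -> bool:
        # polled by the base class about every 0.2 seconds
        return time.time() - self.start_time >= self.collect_time

    def get_task_result(self, fl_ctx: FLContext) -> Shareable:
        # returned to the controller side and aggregated up the hierarchy
        result = make_reply(ReturnCode.OK)
        result["num_updates"] = len(self.updates)
        return result
```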