diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 64ace70972..944b202b22 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -90,6 +90,7 @@ class EventType(object): AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" BEFORE_JOB_LAUNCH = "_before_job_launch" + AFTER_JOB_LAUNCH = "_after_job_launch" TASK_RESULT_RECEIVED = "_task_result_received" TASK_ASSIGNMENT_SENT = "_task_assignment_sent" diff --git a/nvflare/apis/fl_component.py b/nvflare/apis/fl_component.py index c126bd51ec..cf59e82a46 100644 --- a/nvflare/apis/fl_component.py +++ b/nvflare/apis/fl_component.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union from nvflare.apis.utils.fl_context_utils import generate_log_message from nvflare.fuel.utils.log_utils import get_obj_logger @@ -35,6 +36,7 @@ def __init__(self): """ self._name = self.__class__.__name__ self.logger = get_obj_logger(self) + self._event_handlers = {} @property def name(self): @@ -227,3 +229,31 @@ def _fire_log_event(self, event_type: str, log_tag: str, log_msg: str, fl_ctx: F dxo = event_data.to_dxo() fl_ctx.set_prop(key=FLContextKey.EVENT_DATA, value=dxo.to_shareable(), private=True, sticky=False) self.fire_event(event_type=event_type, fl_ctx=fl_ctx) + + def register_event_handler(self, event_types: Union[str, List[str]], handler, **kwargs): + if isinstance(event_types, str): + event_types = [event_types] + elif not isinstance(event_types, list): + raise ValueError(f"event_types must be string or list of strings but got {type(event_types)}") + + if not callable(handler): + raise ValueError(f"handler {handler.__name__} is not callable") + + for e in event_types: + entries = self._event_handlers.get(e) + if not entries: + entries = [] + self._event_handlers[e] = entries + + already_registered = False + for h, _ in entries: + if handler == h: + # already registered: either by a super class or by the class itself. + already_registered = True + break + + if not already_registered: + entries.append((handler, kwargs)) + + def get_event_handlers(self): + return self._event_handlers diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index c53e3b3c90..89a3c3e127 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -193,6 +193,7 @@ class FLContextKey(object): PROCESS_TYPE = ReservedKey.PROCESS_TYPE JOB_PROCESS_ARGS = ReservedKey.JOB_PROCESS_ARGS EVENT_PROCESSED = "__event_processed__" + CELL_MESSAGE = "__cell_message__" class ProcessType: diff --git a/nvflare/apis/job_def.py b/nvflare/apis/job_def.py index a36720c825..44e781a3d9 100644 --- a/nvflare/apis/job_def.py +++ b/nvflare/apis/job_def.py @@ -74,6 +74,7 @@ class JobMetaKey(str, Enum): STATS_POOL_CONFIG = "stats_pool_config" FROM_HUB_SITE = "from_hub_site" CUSTOM_PROPS = "custom_props" + EDGE_METHOD = "edge_method" def __repr__(self): return self.value diff --git a/nvflare/app_common/executors/ham.py b/nvflare/app_common/executors/ham.py index 7cb63ef042..40b1477e0b 100644 --- a/nvflare/app_common/executors/ham.py +++ b/nvflare/app_common/executors/ham.py @@ -57,67 +57,77 @@ def __init__( self._aggr_lock = threading.Lock() self._process_error = None - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - engine = fl_ctx.get_engine() - - aggr = engine.get_component(self.aggregator_id) - if not isinstance(aggr, Aggregator): - self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggr)}") - self.aggregator = aggr - - learner = engine.get_component(self.learner_id) - if not isinstance(learner, Executor): - self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got {type(learner)}") - self.learner = learner - elif event_type == EventType.TASK_ASSIGNMENT_SENT: - # the task was sent to a child client - if not self.pending_task_id: - # I don't have a pending task - return - - child_client_ctx = fl_ctx.get_peer_context() - assert isinstance(child_client_ctx, FLContext) - child_client_name = child_client_ctx.get_identity_name() - self._update_client_status(child_client_name, None) - task_id = fl_ctx.get_prop(FLContextKey.TASK_ID) - - # indicate that this event has been processed by me - fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) - self.log_info(fl_ctx, f"sent task {task_id} to child {child_client_name}") - elif event_type == EventType.TASK_RESULT_RECEIVED: - # received results from a child client - if not self.pending_task_id: - # I don't have a pending task - return - - # indicate that this event has been processed by me - fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) - - result = fl_ctx.get_prop(FLContextKey.TASK_RESULT) - assert isinstance(result, Shareable) - task_id = result.get_header(ReservedKey.TASK_ID) - peer_ctx = fl_ctx.get_peer_context() - assert isinstance(peer_ctx, FLContext) - child_client_name = peer_ctx.get_identity_name() - self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}") - - if task_id != self.pending_task_id: - self.log_warning( - fl_ctx, - f"dropped the received result from child {child_client_name} " - f"for task {task_id} while waiting for task {self.pending_task_id}", - ) - return - - rc = result.get_return_code(ReturnCode.OK) - if rc == ReturnCode.OK: - self.log_info(fl_ctx, f"accepting result from client {child_client_name}") - self._do_aggregation(result, fl_ctx) - else: - self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}") - self.log_info(fl_ctx, f"received result from child {child_client_name}") - self._update_client_status(child_client_name, time.time()) + self.register_event_handler(EventType.START_RUN, self._handle_start_run) + self.register_event_handler(EventType.TASK_ASSIGNMENT_SENT, self._handle_task_sent) + self.register_event_handler(EventType.TASK_RESULT_RECEIVED, self._handle_result_received) + + def _handle_start_run(self, event_type: str, fl_ctx: FLContext): + self.log_debug(fl_ctx, f"handling event {event_type}") + engine = fl_ctx.get_engine() + + aggr = engine.get_component(self.aggregator_id) + if not isinstance(aggr, Aggregator): + self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggr)}") + self.aggregator = aggr + + learner = engine.get_component(self.learner_id) + if not isinstance(learner, Executor): + self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got {type(learner)}") + self.learner = learner + + def _handle_task_sent(self, event_type: str, fl_ctx: FLContext): + # the task was sent to a child client + self.log_debug(fl_ctx, f"handling event {event_type}") + + if not self.pending_task_id: + # I don't have a pending task + return + + child_client_ctx = fl_ctx.get_peer_context() + assert isinstance(child_client_ctx, FLContext) + child_client_name = child_client_ctx.get_identity_name() + self._update_client_status(child_client_name, None) + task_id = fl_ctx.get_prop(FLContextKey.TASK_ID) + + # indicate that this event has been processed by me + fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) + self.log_info(fl_ctx, f"sent task {task_id} to child {child_client_name}") + + def _handle_result_received(self, event_type: str, fl_ctx: FLContext): + # received results from a child client + self.log_debug(fl_ctx, f"handling event {event_type}") + + if not self.pending_task_id: + # I don't have a pending task + return + + # indicate that this event has been processed by me + fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False) + + result = fl_ctx.get_prop(FLContextKey.TASK_RESULT) + assert isinstance(result, Shareable) + task_id = result.get_header(ReservedKey.TASK_ID) + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + child_client_name = peer_ctx.get_identity_name() + self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}") + + if task_id != self.pending_task_id: + self.log_warning( + fl_ctx, + f"dropped the received result from child {child_client_name} " + f"for task {task_id} while waiting for task {self.pending_task_id}", + ) + return + + rc = result.get_return_code(ReturnCode.OK) + if rc == ReturnCode.OK: + self.log_info(fl_ctx, f"accepting result from client {child_client_name}") + self._do_aggregation(result, fl_ctx) + else: + self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}") + self.log_info(fl_ctx, f"received result from child {child_client_name}") + self._update_client_status(child_client_name, time.time()) def _do_aggregation(self, result: Shareable, fl_ctx: FLContext): with self._aggr_lock: diff --git a/nvflare/edge/__init__.py b/nvflare/edge/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/aggregators/__init__.py b/nvflare/edge/aggregators/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/aggregators/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/aggregators/edge_survey_aggr.py b/nvflare/edge/aggregators/edge_survey_aggr.py new file mode 100644 index 0000000000..e86cb2857c --- /dev/null +++ b/nvflare/edge/aggregators/edge_survey_aggr.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.aggregator import Aggregator + + +class EdgeSurveyAggregator(Aggregator): + def __init__(self): + Aggregator.__init__(self) + self.num_devices = 0 + + def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool: + self.log_info(fl_ctx, f"accepting: {shareable}") + num_devices = shareable.get("num_devices") + if num_devices: + self.num_devices += num_devices + return True + + def reset(self, fl_ctx: FLContext): + self.num_devices = 0 + + def aggregate(self, fl_ctx: FLContext) -> Shareable: + self.log_info(fl_ctx, f"aggregating final result: {self.num_devices}") + return Shareable({"num_devices": self.num_devices}) diff --git a/nvflare/edge/constants.py b/nvflare/edge/constants.py new file mode 100644 index 0000000000..6105c865f6 --- /dev/null +++ b/nvflare/edge/constants.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellReturnCode + + +class Status(CellReturnCode): + NO_TASK = "no_task" + NO_JOB = "no_job" + + +class EdgeProtoKey: + STATUS = "status" + DATA = "data" + + +class EdgeContextKey: + JOB_ID = "__edge_job_id__" + EDGE_CAPABILITIES = "__edge_capabilities__" + REQUEST_FROM_EDGE = "__request_from_edge__" + REPLY_TO_EDGE = "__reply_to_edge__" + + +class EventType: + EDGE_REQUEST_RECEIVED = "_edge_request_received" + EDGE_JOB_REQUEST_RECEIVED = "_edge_job_request_received" diff --git a/nvflare/edge/controllers/__init__.py b/nvflare/edge/controllers/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/controllers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/controllers/edge_survey_ctl.py b/nvflare/edge/controllers/edge_survey_ctl.py new file mode 100644 index 0000000000..1103e1ebca --- /dev/null +++ b/nvflare/edge/controllers/edge_survey_ctl.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.controller_spec import ClientTask, Task +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + + +class EdgeSurveyController(Controller): + def __init__(self, num_rounds: int, timeout: int): + Controller.__init__(self) + self.num_rounds = num_rounds + self.timeout = timeout + + def start_controller(self, fl_ctx: FLContext): + pass + + def stop_controller(self, fl_ctx: FLContext): + pass + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + for r in range(self.num_rounds): + task = Task( + name="survey", + data=Shareable(), + timeout=self.timeout, + ) + + self.broadcast_and_wait( + task=task, + min_responses=2, + wait_time_after_min_received=0, + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + total_devices = 0 + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + result = ct.result + assert isinstance(result, Shareable) + self.log_info(fl_ctx, f"result from client {ct.client.name}: {result}") + count = result.get("num_devices") + if count: + total_devices += count + + self.log_info(fl_ctx, f"total devices in round {r}: {total_devices}") diff --git a/nvflare/edge/executors/__init__.py b/nvflare/edge/executors/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/executors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/executors/edge_survey_executor.py b/nvflare/edge/executors/edge_survey_executor.py new file mode 100644 index 0000000000..d3c42f2dca --- /dev/null +++ b/nvflare/edge/executors/edge_survey_executor.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from typing import Any + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.edge.executors.ete import EdgeTaskExecutor + + +class EdgeSurveyExecutor(EdgeTaskExecutor): + """This executor is for test purpose only. It is to be used as the "learner" for the + HierarchicalAggregationManager. + """ + + def __init__(self, timeout=10.0): + EdgeTaskExecutor.__init__(self) + self.timeout = timeout + self.num_devices = 0 + self.start_time = None + + def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext): + self.num_devices = 0 + self.start_time = time.time() + + def is_task_done(self, fl_ctx: FLContext) -> bool: + return time.time() - self.start_time > self.timeout + + def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any: + assert isinstance(request, dict) + self.log_info(fl_ctx, f"received edge request: {request}") + self.num_devices += 1 + return {"status": "tryAgain", "comment": f"received {request}"} + + def get_task_result(self, fl_ctx: FLContext) -> Shareable: + result = make_reply(ReturnCode.OK) + result["num_devices"] = self.num_devices + return result diff --git a/nvflare/edge/executors/ete.py b/nvflare/edge/executors/ete.py new file mode 100644 index 0000000000..be0b84be3c --- /dev/null +++ b/nvflare/edge/executors/ete.py @@ -0,0 +1,152 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from abc import abstractmethod +from typing import Any + +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.security.logging import secure_format_exception + + +class EdgeTaskExecutor(Executor): + """This is the base class for executors to handling requests from edge devices. + Subclasses must implement the required abstract methods defined here. + """ + + def __init__(self): + """Constructor of EdgeTaskExecutor""" + Executor.__init__(self) + self.current_task = None + + self.register_event_handler(EdgeEventType.EDGE_REQUEST_RECEIVED, self._handle_edge_request) + + @abstractmethod + def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any: + """This is called to process an edge request sent from the edge device. + + Args: + request: the request from edge device + fl_ctx: FLContext object + + Returns: reply to the edge device + + """ + pass + + def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext): + """This method is called when a task assignment is received from the controller. + Subclass can implement this method to prepare for task processing. + + Args: + task_name: name of the task + task_data: task data + fl_ctx: FLContext object + + Returns: None + + """ + pass + + @abstractmethod + def is_task_done(self, fl_ctx: FLContext) -> bool: + """This is called by the base class to determine whether the task processing is done. + Subclass must implement this method. + + Args: + fl_ctx: FLContext object + + Returns: whether task is done. + + """ + pass + + @abstractmethod + def get_task_result(self, fl_ctx: FLContext) -> Shareable: + """This is called by the base class to get the final result of the task. + Base class will send the result to the controller. + + Args: + fl_ctx: FLContext object + + Returns: a Shareable object that is the task result + + """ + pass + + def _handle_edge_request(self, event_type: str, fl_ctx: FLContext): + if not self.current_task: + self.logger.debug(f"received edge event {event_type} but I don't have pending task") + return + + try: + msg = fl_ctx.get_prop(FLContextKey.CELL_MESSAGE) + assert isinstance(msg, CellMessage) + self.log_debug(fl_ctx, f"received edge request: {msg.payload}") + reply = self.process_edge_request(request=msg.payload, fl_ctx=fl_ctx) + fl_ctx.set_prop(FLContextKey.TASK_RESULT, reply, private=True, sticky=False) + except Exception as ex: + self.logger.error(f"exception from self.process_edge_request: {secure_format_exception(ex)}") + fl_ctx.set_prop(FLContextKey.EXCEPTIONS, ex, private=True, sticky=False) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.current_task = shareable + result = self._execute(task_name, shareable, fl_ctx, abort_signal) + self.current_task = None + return result + + def _execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + try: + self.task_received(task_name, shareable, fl_ctx) + except Exception as ex: + self.log_error(fl_ctx, f"exception from self.task_received: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + start_time = time.time() + while True: + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + try: + task_done = self.is_task_done(fl_ctx) + except Exception as ex: + self.log_error(fl_ctx, f"exception from self.is_task_done: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + if task_done: + break + + time.sleep(0.2) + + self.log_debug(fl_ctx, f"task done after {time.time() - start_time} seconds") + try: + result = self.get_task_result(fl_ctx) + + if not isinstance(result, Shareable): + self.log_error( + fl_ctx, + f"bad result from self.get_task_result: expect Shareable but got {type(result)}", + ) + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + except Exception as ex: + self.log_error(fl_ctx, f"exception from self.get_task_result: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + return result diff --git a/nvflare/lighter/tree_prov.py b/nvflare/edge/tree_prov.py similarity index 88% rename from nvflare/lighter/tree_prov.py rename to nvflare/edge/tree_prov.py index 15bba28aba..5b60f8ced7 100644 --- a/nvflare/lighter/tree_prov.py +++ b/nvflare/edge/tree_prov.py @@ -18,14 +18,15 @@ """ import argparse +import json +import os.path +from nvflare.lighter.entity import Participant, ParticipantType, Project from nvflare.lighter.impl.cert import CertBuilder from nvflare.lighter.impl.signature import SignatureBuilder from nvflare.lighter.impl.static_file import StaticFileBuilder from nvflare.lighter.impl.workspace import WorkspaceBuilder - -from .entity import Participant, ParticipantType, Project -from .provisioner import Provisioner +from nvflare.lighter.provisioner import Provisioner def _new_participant(name: str, ptype: str, props: dict) -> Participant: @@ -46,20 +47,33 @@ class Stats: num_non_leaf_clients = 0 -class _Node: - +class PortManager: last_port_number = 9000 + @classmethod + def get_port(cls): + cls.last_port_number += 1 + return cls.last_port_number + + +class _Node: def __init__(self): self.name = None self.client_name = None self.parent = None self.children = [] - self.port = _Node.last_port_number - _Node.last_port_number += 1 - - -def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clients: int, project: Project): + self.port = PortManager.get_port() + + +def _build_tree( + depth: int, + width: int, + max_depth: int, + parent: _Node, + num_clients: int, + project: Project, + lcp_map: dict, +): """Build relay hierarchy and client hierarchy, recursively. Relays are organized hierarchically. Attach a client to each relay. Such clients are non-leaf clients (a.k.a @@ -82,7 +96,7 @@ def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clien """ if depth == max_depth: - # the parent is a leaf node - add leaf clients + # the parent is a leaf node - add leaf clients (LCPs) Stats.num_leaf_relays += 1 for i in range(num_clients): name = _make_client_name(parent.name) + str(i + 1) @@ -92,6 +106,8 @@ def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clien project.add_participant(client) Stats.num_clients += 1 Stats.num_leaf_clients += 1 + + lcp_map[name] = {"host": "localhost", "port": PortManager.get_port()} return if depth > 0: @@ -128,7 +144,7 @@ def _build_tree(depth: int, width: int, max_depth: int, parent: _Node, num_clien Stats.num_non_leaf_clients += 1 # depth-first recursion - _build_tree(depth + 1, width, max_depth, child, num_clients, project) + _build_tree(depth + 1, width, max_depth, child, num_clients, project, lcp_map) def main(): @@ -206,7 +222,8 @@ def main(): # add relays and clients root_relay = _Node() root_relay.name = "R" - _build_tree(0, args.width, args.depth, root_relay, args.clients, project) + lcp_map = {} + _build_tree(0, args.width, args.depth, root_relay, args.clients, project, lcp_map) total_sites = Stats.num_clients + Stats.num_relays + 1 @@ -226,7 +243,12 @@ def main(): "admin@nvidia.com", ParticipantType.ADMIN, props={"role": "project_admin", "connect_to": "localhost"} ) project.add_participant(admin) - provisioner.provision(project) + ctx = provisioner.provision(project) + location = ctx.get_result_location() + lcp_map_file_name = os.path.join(location, "lcp_map.json") + with open(lcp_map_file_name, "wt") as f: + json.dump(lcp_map, f, indent=4) + print(f"Generated LCP Map: {lcp_map_file_name}") if __name__ == "__main__": diff --git a/nvflare/edge/widgets/__init__.py b/nvflare/edge/widgets/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nvflare/edge/widgets/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/edge/widgets/etd.py b/nvflare/edge/widgets/etd.py new file mode 100644 index 0000000000..c1aa1616a9 --- /dev/null +++ b/nvflare/edge/widgets/etd.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +from random import randrange + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.job_def import JobMetaKey +from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.edge.constants import Status as EdgeStatus +from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey +from nvflare.fuel.f3.cellnet.utils import new_cell_message +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.widgets.widget import Widget + + +class EdgeTaskDispatcher(Widget): + """Edge Task Dispatcher (ETD) is to be used to dispatch a received edge request to a running job (CJ). + ETD must be installed on CP before the CP is started. + """ + + def __init__(self, request_timeout: float = 2.0): + Widget.__init__(self) + self.request_timeout = request_timeout + self.edge_jobs = {} # edge_method => list of job_ids + self.lock = threading.Lock() + + self.register_event_handler( + EventType.AFTER_JOB_LAUNCH, + self._handle_job_launched, + ) + self.register_event_handler( + [EventType.JOB_COMPLETED, EventType.JOB_CANCELLED, EventType.JOB_ABORTED], + self._handle_job_done, + ) + self.register_event_handler( + EdgeEventType.EDGE_JOB_REQUEST_RECEIVED, + self._handle_edge_job_request, + ) + self.register_event_handler( + EdgeEventType.EDGE_REQUEST_RECEIVED, + self._handle_edge_request, + ) + self.logger.info("EdgeTaskDispatcher created!") + + def _add_job(self, job_meta: dict): + with self.lock: + edge_method = job_meta.get(JobMetaKey.EDGE_METHOD) + if not edge_method: + # this is not an edge job + return + + jobs = self.edge_jobs.get(edge_method) + if not jobs: + jobs = [] + self.edge_jobs[edge_method] = jobs + + job_id = job_meta.get(JobMetaKey.JOB_ID) + jobs.append(job_id) + + def _remove_job(self, job_meta: dict): + with self.lock: + job_id = job_meta.get(JobMetaKey.JOB_ID) + edge_method = job_meta.get(JobMetaKey.EDGE_METHOD) + if not edge_method: + # this is not an edge job + self.logger.info(f"no edge_method in job {job_id}") + return + + jobs = self.edge_jobs.get(edge_method) + if not jobs: + self.logger.info("no edge jobs pending") + return + + assert isinstance(jobs, list) + job_id = job_meta.get(JobMetaKey.JOB_ID) + self.logger.info(f"current jobs for {edge_method}: {jobs}") + if job_id in jobs: + jobs.remove(job_id) + if not jobs: + # no more jobs for this edge method + self.edge_jobs.pop(edge_method) + + def _match_job(self, caps: list): + with self.lock: + for edge_method, jobs in self.edge_jobs.items(): + if edge_method in caps: + # pick one randomly + i = randrange(len(jobs)) + return jobs[i] + + # no job matched + return None + + def _find_job(self, job_id: str): + with self.lock: + for jobs in self.edge_jobs.values(): + if job_id in jobs: + return True + return False + + def _handle_job_launched(self, event_type: str, fl_ctx: FLContext): + self.logger.info(f"handling event {event_type}") + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + if not job_meta: + self.logger.error(f"missing {FLContextKey.JOB_META} from fl_ctx for event {event_type}") + else: + self._add_job(job_meta) + + def _handle_job_done(self, event_type: str, fl_ctx: FLContext): + self.logger.info(f"handling event {event_type}") + job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) + if not job_meta: + self.logger.error(f"missing {FLContextKey.JOB_META} from fl_ctx for event {event_type}") + else: + self._remove_job(job_meta) + + def _handle_edge_job_request(self, event_type: str, fl_ctx: FLContext): + self.logger.info(f"handling event {event_type}") + edge_capabilities = fl_ctx.get_prop(EdgeContextKey.EDGE_CAPABILITIES) + if not edge_capabilities: + self.logger.error(f"missing {EdgeContextKey.EDGE_CAPABILITIES} from fl_ctx for event {event_type}") + self._set_edge_reply(EdgeStatus.INVALID_REQUEST, None, fl_ctx) + return + + # find job for the caps + job_id = self._match_job(edge_capabilities) + if job_id: + status = EdgeStatus.OK + else: + status = EdgeStatus.NO_JOB + self._set_edge_reply(status, job_id, fl_ctx) + + @staticmethod + def _set_edge_reply(status, data, fl_ctx: FLContext): + fl_ctx.set_prop( + key=EdgeContextKey.REPLY_TO_EDGE, + value={EdgeProtoKey.STATUS: status, EdgeProtoKey.DATA: data}, + private=True, + sticky=False, + ) + + def _handle_edge_request(self, event_type: str, fl_ctx: FLContext): + # try to find the job + job_id = fl_ctx.get_prop(EdgeContextKey.JOB_ID) + if not job_id: + self.logger.error(f"handling event {event_type}: missing {EdgeContextKey.JOB_ID} from fl_ctx") + self._set_edge_reply(EdgeStatus.INVALID_REQUEST, None, fl_ctx) + return + + if not self._find_job(job_id): + self._set_edge_reply(EdgeStatus.NO_JOB, None, fl_ctx) + return + + # send edge request data to CJ + edge_req_data = fl_ctx.get_prop(EdgeContextKey.REQUEST_FROM_EDGE) + self.logger.info(f"Sending edge request to CJ {job_id}: {edge_req_data}") + engine = fl_ctx.get_engine() + reply = engine.send_to_job( + job_id=job_id, + channel=CellChannel.EDGE_REQUEST, + topic="request", + msg=new_cell_message({}, edge_req_data), + timeout=self.request_timeout, + ) + + assert isinstance(reply, CellMessage) + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + reply_data = reply.payload + self.logger.debug(f"got edge result from CJ: {rc=} {reply_data=}") + self._set_edge_reply(rc, reply_data, fl_ctx) diff --git a/nvflare/edge/widgets/etg.py b/nvflare/edge/widgets/etg.py new file mode 100644 index 0000000000..f65981e578 --- /dev/null +++ b/nvflare/edge/widgets/etg.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Edge Task Generator - for test only +Randomly generate edge tasks +""" +import threading +import time +import uuid + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.apis.signal import Signal +from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.edge.constants import Status +from nvflare.widgets.widget import Widget + + +class EdgeTaskGenerator(Widget): + def __init__(self): + Widget.__init__(self) + self.generator = None + self.engine = None + self.job_id = None + self.abort_signal = Signal() + self.logger.info("EdgeTaskGenerator created!") + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.SYSTEM_START: + # start the generator + self.logger.info("Starting generator ...") + self.engine = fl_ctx.get_engine() + self.generator = threading.Thread(target=self._generate_tasks, daemon=True) + self.generator.start() + elif event_type == EventType.SYSTEM_END: + self.abort_signal.trigger(True) + + @staticmethod + def _make_task(): + return { + "device_id": str(uuid.uuid4()), + "request_type": "getTask", + } + + def _generate_tasks(self): + caps = ["xgb", "llm"] + while True: + if self.abort_signal.triggered: + self.logger.info("received abort signal - exiting") + return + + with self.engine.new_context() as fl_ctx: + assert isinstance(fl_ctx, FLContext) + if not self.job_id: + fl_ctx.set_prop(EdgeContextKey.EDGE_CAPABILITIES, caps, private=True, sticky=False) + self.fire_event(EdgeEventType.EDGE_JOB_REQUEST_RECEIVED, fl_ctx) + result = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) + if result: + assert isinstance(result, dict) + status = result[EdgeProtoKey.STATUS] + job_id = result[EdgeProtoKey.DATA] + self.logger.info(f"job reply from ETD: {status=} {job_id=}") + if job_id: + self.job_id = job_id + else: + self.logger.error(f"no result from ETD for event {EdgeEventType.EDGE_JOB_REQUEST_RECEIVED}") + else: + task = self._make_task() + fl_ctx.set_prop(EdgeContextKey.JOB_ID, self.job_id, sticky=False, private=True) + fl_ctx.set_prop(EdgeContextKey.REQUEST_FROM_EDGE, task, sticky=False, private=True) + self.fire_event(EdgeEventType.EDGE_REQUEST_RECEIVED, fl_ctx) + result = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) + if not result: + self.logger.error(f"no result from ETD for event {EdgeEventType.EDGE_REQUEST_RECEIVED}") + else: + status = result[EdgeProtoKey.STATUS] + edge_reply = result[EdgeProtoKey.DATA] + self.logger.info(f"task reply from ETD: {status=} {edge_reply=}") + + if status == Status.NO_JOB: + # job already finished + self.job_id = None + + time.sleep(1.0) diff --git a/nvflare/fuel/f3/cellnet/defs.py b/nvflare/fuel/f3/cellnet/defs.py index 3c6aab6693..d54214b07b 100644 --- a/nvflare/fuel/f3/cellnet/defs.py +++ b/nvflare/fuel/f3/cellnet/defs.py @@ -157,6 +157,7 @@ class CellChannel: MULTI_PROCESS_EXECUTOR = "multi_process_executor" SIMULATOR_RUNNER = "simulator_runner" RETURN_ONLY = "return_only" + EDGE_REQUEST = "edge_request" class CellChannelTopic: diff --git a/nvflare/fuel/f3/cellnet/utils.py b/nvflare/fuel/f3/cellnet/utils.py index 43d2ec16fe..a01a567190 100644 --- a/nvflare/fuel/f3/cellnet/utils.py +++ b/nvflare/fuel/f3/cellnet/utils.py @@ -35,6 +35,13 @@ } +def new_cell_message(headers: dict, payload=None): + msg_headers = {} + if headers: + msg_headers.update(headers) + return Message(msg_headers, payload) + + def make_reply(rc: str, error: str = "", body=None) -> Message: headers = {MessageHeaderKey.RETURN_CODE: rc} if error: diff --git a/nvflare/lighter/ctx.py b/nvflare/lighter/ctx.py index 0c06bf234a..6576f0b243 100644 --- a/nvflare/lighter/ctx.py +++ b/nvflare/lighter/ctx.py @@ -13,6 +13,7 @@ # limitations under the License. import json import os +from typing import Optional import yaml @@ -178,3 +179,12 @@ def warning(self, msg: str): logger.warning(msg) else: print(f"WARNING: {msg}") + + def get_result_location(self) -> Optional[str]: + """Get the directory of the provision result. + This should be called after the provision is done. + + Returns: the name of the directory that holds the provisioned result. + + """ + return self.get(CtxKey.CURRENT_PROD_DIR) diff --git a/nvflare/lighter/provisioner.py b/nvflare/lighter/provisioner.py index 939afc012e..94894b54fe 100644 --- a/nvflare/lighter/provisioner.py +++ b/nvflare/lighter/provisioner.py @@ -62,7 +62,7 @@ def _check_method(logger, method_name: str): elif not callable(getattr(logger, method_name)): raise ValueError(f"invalid logger {type(logger)}: method '{method_name}' is not callable") - def provision(self, project: Project, mode=None, logger=None): + def provision(self, project: Project, mode=None, logger=None) -> ProvisionContext: """Provision a specified project. Args: diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index 3a2af90122..093e0185c6 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -19,7 +19,7 @@ # this import is to let existing scripts import from nvflare.private.defs from nvflare.fuel.f3.cellnet.defs import CellChannel, CellChannelTopic, SSLConstants # noqa: F401 -from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.cellnet.utils import new_cell_message # noqa: F401 from nvflare.fuel.hci.server.constants import ConnProps @@ -181,10 +181,3 @@ def __init__(self, client_name: str): self.client_name = client_name self.nonce = str(uuid.uuid4()) self.reg_start_time = time.time() - - -def new_cell_message(headers: dict, payload=None): - msg_headers = {} - if headers: - msg_headers.update(headers) - return Message(msg_headers, payload) diff --git a/nvflare/private/event.py b/nvflare/private/event.py index faf04442c0..c9f23ec1d1 100644 --- a/nvflare/private/event.py +++ b/nvflare/private/event.py @@ -58,7 +58,19 @@ def fire_event(event: str, handlers: list, ctx: FLContext): ctx.set_prop(key=FLContextKey.EVENT_DATA, value=event_data, private=True, sticky=False) ctx.set_prop(key=FLContextKey.EVENT_ORIGIN, value=event_origin, private=True, sticky=False) ctx.set_prop(key=FLContextKey.EVENT_SCOPE, value=event_scope, private=True, sticky=False) - h.handle_event(event, ctx) + + event_table = h.get_event_handlers() + if event_table: + entries = event_table.get(event) + if entries: + for cb, kwargs in entries: + cb(event, ctx, **kwargs) + else: + # no CB explicitly for this event - call the default handler. + h.handle_event(event, ctx) + else: + # no explicitly defined CBs - call the default handler. + h.handle_event(event, ctx) except Exception as e: h.log_exception( ctx, f'Exception when handling event "{event}": {secure_format_exception(e)}', fire_event=False diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 3543fb16cf..48170ac824 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -401,6 +401,21 @@ def abort_app(self, job_id: str) -> str: return "Abort signal has been sent to the client App." + def send_to_job(self, job_id, channel: str, topic: str, msg: CellMessage, timeout: float) -> CellMessage: + """Send a message to CJ + + Args: + job_id: id of the job + channel: message channel + topic: message topic + msg: the message to be sent + timeout: how long to wait for reply + + Returns: reply from CJ + + """ + return self.client_executor.send_to_job(job_id, channel, topic, msg, timeout) + def abort_task(self, job_id: str) -> str: status = self.client_executor.get_status(job_id) if status == ClientStatus.NOT_STARTED: diff --git a/nvflare/private/fed/client/client_engine_executor_spec.py b/nvflare/private/fed/client/client_engine_executor_spec.py index 062955d505..3d3ce10327 100644 --- a/nvflare/private/fed/client/client_engine_executor_spec.py +++ b/nvflare/private/fed/client/client_engine_executor_spec.py @@ -180,3 +180,12 @@ def abort_app(self, job_id: str, fl_ctx: FLContext): """ pass + + @abstractmethod + def get_cell(self): + """Get communication cell + + Returns: + + """ + pass diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index c1c178d20b..cd8d0c104d 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -25,6 +25,7 @@ from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message @@ -210,7 +211,11 @@ def start_app( fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) job_handle = job_launcher.launch_job(job_meta, fl_ctx) - self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") + self.logger.info(f"Launched job {job_id} with job launcher: {type(job_launcher)} ") + + fl_ctx.set_prop(FLContextKey.JOB_META, job_meta, private=True, sticky=False) + engine = fl_ctx.get_engine() + engine.fire_event(EventType.AFTER_JOB_LAUNCH, fl_ctx) client.multi_gpu = False @@ -419,6 +424,29 @@ def abort_app(self, job_id): self.logger.info("Client worker process is terminated.") + def send_to_job(self, job_id, channel: str, topic: str, msg: CellMessage, timeout: float) -> CellMessage: + """Send a message to CJ + + Args: + job_id: id of the job + channel: message channel + topic: message topic + msg: the message to be sent + timeout: how long to wait for reply + + Returns: reply from CJ + + """ + # send any serializable data to the job cell + return self.client.cell.send_request( + target=self._job_fqcn(job_id), + channel=channel, + topic=topic, + request=msg, + optional=False, + timeout=timeout, + ) + def _terminate_job(self, job_handle, job_id): max_wait = 10.0 done = False @@ -500,6 +528,7 @@ def _wait_child_process_finish( fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.CLIENT_NAME, client.client_name, private=True, sticky=False) engine.fire_event(EventType.JOB_COMPLETED, fl_ctx) + self.logger.info(f"Fired event JOB_COMPLETED {EventType.JOB_COMPLETED}") def get_status(self, job_id): process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED) diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index dbe4c82b12..9ee233744b 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -26,7 +26,12 @@ from nvflare.apis.utils.fl_context_utils import add_job_audit_event from nvflare.apis.utils.reliable_message import ReliableMessage from nvflare.apis.utils.task_utils import apply_filters +from nvflare.edge.constants import EventType as EdgeEventType +from nvflare.edge.constants import Status as EdgeStatus +from nvflare.fuel.f3.cellnet.defs import CellChannel from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.cellnet.utils import make_reply as make_cell_reply +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.private.defs import SpecialTaskName, TaskConstant from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment from nvflare.private.fed.tbi import TBI @@ -157,6 +162,35 @@ def __init__( self.submit_task_result_timeout = self.get_positive_float_var(ConfigVarName.SUBMIT_TASK_RESULT_TIMEOUT, None) self._register_aux_message_handlers(engine) + def set_cell(self, cell): + cell.register_request_cb( + channel=CellChannel.EDGE_REQUEST, + topic="*", + cb=self._receive_edge_request, + ) + + def _receive_edge_request(self, request: CellMessage): + with self.engine.new_context() as fl_ctx: + assert isinstance(fl_ctx, FLContext) + try: + # place the cell message into fl_ctx in case it's needed by process_edge_request. + fl_ctx.set_prop(FLContextKey.CELL_MESSAGE, request, private=True, sticky=False) + self.engine.fire_event(EdgeEventType.EDGE_REQUEST_RECEIVED, fl_ctx) + exception = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + if exception: + return make_cell_reply(EdgeStatus.PROCESS_EXCEPTION) + + reply = fl_ctx.get_prop(FLContextKey.TASK_RESULT) + if not reply: + self.logger.debug("no result for edge request") + return make_cell_reply(EdgeStatus.NO_TASK) + else: + self.logger.info(f"sending back edge result: {reply}") + return make_cell_reply(EdgeStatus.OK, body=reply) + except Exception as ex: + self.log_error(fl_ctx, f"exception from receive_edge_request: {secure_format_exception(ex)}") + return make_cell_reply(EdgeStatus.PROCESS_EXCEPTION) + def find_executor(self, task_name): return self.task_router.route(task_name) diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 08691e81f6..1c9e07cc3e 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -262,6 +262,7 @@ def _create_cell(self, location, scheme): time.sleep(self.cell_check_frequency) self.logger.info(f"Got client_runner after {time.time() - start} seconds") self.client_runner.engine.cell = self.cell + self.client_runner.set_cell(self.cell) else: start = time.time() self.logger.info("Wait for engine to be created.")