forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
99c91c8
commit 4c28400
Showing
21 changed files
with
817 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellReturnCode | ||
|
||
|
||
class Status(CellReturnCode):
    """Return codes of ``CellReturnCode`` plus two additional string codes.

    NOTE(review): the names suggest these are reported when no task / no job
    is available for a requester - confirm against the code that emits them.
    """

    NO_TASK = "no_task"
    NO_JOB = "no_job"
|
||
|
||
class EdgeProtoKey:
    """String constants naming the ``status`` and ``data`` keys.

    NOTE(review): presumably these are the top-level keys of an edge
    request/reply payload - confirm against the code that builds it.
    """

    STATUS = "status"
    DATA = "data"
|
||
|
||
class EdgeContextKey:
    """String constants for edge-related property names.

    Every value is wrapped in double underscores.
    NOTE(review): presumably these are FLContext property keys and the
    underscores avoid clashes with other keys - confirm against usages.
    """

    JOB_ID = "__edge_job_id__"
    EDGE_CAPABILITIES = "__edge_capabilities__"
    REQUEST_FROM_EDGE = "__request_from_edge__"
    REPLY_TO_EDGE = "__reply_to_edge__"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.shareable import Shareable | ||
from nvflare.app_common.abstract.aggregator import Aggregator | ||
|
||
|
||
class EdgeSurveyAggregator(Aggregator):
    """Aggregator that sums the ``num_devices`` values of accepted contributions."""

    def __init__(self):
        """Create the aggregator with a running total of zero."""
        Aggregator.__init__(self)
        self.num_devices = 0

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Add one contribution to the running total.

        Args:
            shareable: contribution; its ``"num_devices"`` entry, if present
                and truthy, is added to the total
            fl_ctx: FLContext object

        Returns: always True
        """
        self.log_info(fl_ctx, f"accepting: {shareable}")
        reported = shareable.get("num_devices")
        if not reported:
            # Missing or zero count: nothing to add, but still accepted.
            return True
        self.num_devices += reported
        return True

    def reset(self, fl_ctx: FLContext):
        """Clear the running total.

        Args:
            fl_ctx: FLContext object
        """
        self.num_devices = 0

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Return the running total as a Shareable.

        Args:
            fl_ctx: FLContext object

        Returns: a Shareable with the single key ``"num_devices"``
        """
        total = self.num_devices
        self.log_info(fl_ctx, f"aggregating final result: {total}")
        return Shareable({"num_devices": total})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import time | ||
from typing import Any | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply | ||
from nvflare.app_common.executors.ete import EdgeTaskExecutor | ||
|
||
|
||
class EdgeSurvey(EdgeTaskExecutor):
    """Executor that counts the edge requests processed during a time window.

    Receiving a task opens a window of ``timeout`` seconds. Each call to
    ``process_edge_request`` increments ``num_devices`` by one (requests are
    counted, not de-duplicated by sender). Once the window has elapsed the
    count is returned as the task result.
    """

    def __init__(self, timeout=10.0):
        """Constructor of EdgeSurvey

        Args:
            timeout: length of the survey window, in seconds
        """
        EdgeTaskExecutor.__init__(self)
        self.timeout = timeout
        self.num_devices = 0
        # Wall-clock time at which the current task was received; kept as a
        # public attribute for informational use only.
        self.start_time = None
        # Monotonic timestamp used for the timeout check, so that system clock
        # adjustments cannot shorten or extend the window.
        self._started_at = None

    def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext):
        """Reset the counter and start the survey window.

        Args:
            task_name: name of the task
            task_data: task data
            fl_ctx: FLContext object

        Returns: None
        """
        self.num_devices = 0
        self.start_time = time.time()
        self._started_at = time.monotonic()

    def is_task_done(self, fl_ctx: FLContext) -> bool:
        """Report whether the survey window has elapsed.

        Args:
            fl_ctx: FLContext object

        Returns: True once more than ``timeout`` seconds have passed since the
            task was received

        Raises:
            RuntimeError: if called before ``task_received``
        """
        if self._started_at is None:
            raise RuntimeError("is_task_done called before task_received")
        return time.monotonic() - self._started_at > self.timeout

    def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any:
        """Count one edge request and tell the device to try again.

        Args:
            request: the request from edge device; must be a dict
            fl_ctx: FLContext object

        Returns: reply to the edge device

        Raises:
            TypeError: if request is not a dict
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(request, dict):
            raise TypeError(f"expect request to be dict but got {type(request)}")
        self.log_info(fl_ctx, f"received edge request: {request}")
        self.num_devices += 1
        return {"status": "tryAgain", "comment": f"received {request}"}

    def get_task_result(self, fl_ctx: FLContext) -> Shareable:
        """Build the task result carrying the request count.

        Args:
            fl_ctx: FLContext object

        Returns: an OK reply with ``"num_devices"`` set to the current count
        """
        result = make_reply(ReturnCode.OK)
        result["num_devices"] = self.num_devices
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import time | ||
from abc import abstractmethod | ||
from typing import Any | ||
|
||
from nvflare.apis.event_type import EventType | ||
from nvflare.apis.executor import Executor | ||
from nvflare.apis.fl_constant import FLContextKey | ||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply | ||
from nvflare.apis.signal import Signal | ||
from nvflare.fuel.f3.message import Message as CellMessage | ||
from nvflare.security.logging import secure_format_exception | ||
|
||
|
||
class EdgeTaskExecutor(Executor):
    """This is the base class for executors that handle requests from edge devices.
    Subclasses must implement the required abstract methods defined here.

    While a task is being executed, EDGE_REQUEST_RECEIVED events are routed to
    ``process_edge_request``; ``execute`` polls ``is_task_done`` until the
    subclass reports completion and then returns ``get_task_result``.
    """

    def __init__(self):
        """Constructor of EdgeTaskExecutor"""
        Executor.__init__(self)
        # The task currently being executed; None when no task is pending.
        self.current_task = None

        self.register_event_handler(EventType.EDGE_REQUEST_RECEIVED, self._handle_edge_request)

    @abstractmethod
    def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any:
        """This is called to process an edge request sent from the edge device.

        Args:
            request: the request from edge device
            fl_ctx: FLContext object

        Returns: reply to the edge device
        """
        pass

    def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext):
        """This method is called when a task assignment is received from the controller.
        Subclass can implement this method to prepare for task processing.

        Args:
            task_name: name of the task
            task_data: task data
            fl_ctx: FLContext object

        Returns: None
        """
        pass

    @abstractmethod
    def is_task_done(self, fl_ctx: FLContext) -> bool:
        """This is called by the base class to determine whether the task processing is done.
        Subclass must implement this method.

        Args:
            fl_ctx: FLContext object

        Returns: whether task is done.
        """
        pass

    @abstractmethod
    def get_task_result(self, fl_ctx: FLContext) -> Shareable:
        """This is called by the base class to get the final result of the task.
        Base class will send the result to the controller.

        Args:
            fl_ctx: FLContext object

        Returns: a Shareable object that is the task result
        """
        pass

    def _handle_edge_request(self, event_type: str, fl_ctx: FLContext):
        """Event handler for EDGE_REQUEST_RECEIVED.

        Reads the cell message from the FLContext, passes its payload to
        ``process_edge_request`` and stores the reply in the FLContext under
        TASK_RESULT. Any exception is logged and stored under EXCEPTIONS.
        Requests arriving while no task is pending are ignored.

        Args:
            event_type: the event being handled
            fl_ctx: FLContext object

        Returns: None
        """
        self.log_debug(fl_ctx, f"handling event {event_type}")
        # Compare against None instead of testing truthiness: the task is a
        # Shareable, which is used dict-like in this code base, so an empty
        # task payload would otherwise be mistaken for "no pending task".
        if self.current_task is None:
            self.logger.debug("received edge request but I don't have pending task")
            return

        try:
            msg = fl_ctx.get_prop(FLContextKey.CELL_MESSAGE)
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not isinstance(msg, CellMessage):
                raise TypeError(f"expect CellMessage but got {type(msg)}")
            self.log_info(fl_ctx, f"received edge request: {msg.payload}")
            reply = self.process_edge_request(request=msg.payload, fl_ctx=fl_ctx)
            fl_ctx.set_prop(FLContextKey.TASK_RESULT, reply, private=True, sticky=False)
        except Exception as ex:
            self.logger.error(f"exception from self.process_edge_request: {secure_format_exception(ex)}")
            fl_ctx.set_prop(FLContextKey.EXCEPTIONS, ex, private=True, sticky=False)

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Execute the task, accepting edge requests for its duration.

        Args:
            task_name: name of the task
            shareable: task data
            fl_ctx: FLContext object
            abort_signal: signal that is triggered when the task should be aborted

        Returns: the task result, or an error/abort reply
        """
        self.current_task = shareable
        try:
            return self._execute(task_name, shareable, fl_ctx, abort_signal)
        finally:
            # Always clear the pending task, even on an unexpected exception,
            # so later edge requests are not processed against a finished task.
            self.current_task = None

    def _execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Run the task life cycle: notify, poll until done, collect the result.

        Exceptions raised by the subclass hooks are logged and turned into an
        EXECUTION_EXCEPTION reply; a triggered abort signal yields TASK_ABORTED.

        Args:
            task_name: name of the task
            shareable: task data
            fl_ctx: FLContext object
            abort_signal: signal that is triggered when the task should be aborted

        Returns: the Shareable from ``get_task_result``, or an error/abort reply
        """
        try:
            self.task_received(task_name, shareable, fl_ctx)
        except Exception as ex:
            self.log_error(fl_ctx, f"exception from self.task_received: {secure_format_exception(ex)}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        # Monotonic clock: the elapsed time is unaffected by system clock changes.
        start_time = time.monotonic()
        while True:
            if abort_signal.triggered:
                return make_reply(ReturnCode.TASK_ABORTED)

            try:
                task_done = self.is_task_done(fl_ctx)
            except Exception as ex:
                self.log_error(fl_ctx, f"exception from self.is_task_done: {secure_format_exception(ex)}")
                return make_reply(ReturnCode.EXECUTION_EXCEPTION)

            if task_done:
                break

            # Polling interval between completion/abort checks.
            time.sleep(0.2)

        self.log_debug(fl_ctx, f"task done after {time.monotonic()-start_time} seconds")
        try:
            result = self.get_task_result(fl_ctx)

            if not isinstance(result, Shareable):
                self.log_error(
                    fl_ctx,
                    f"bad result from self.get_task_result: expect Shareable but got {type(result)}",
                )
                return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        except Exception as ex:
            self.log_error(fl_ctx, f"exception from self.get_task_result: {secure_format_exception(ex)}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

        return result
Oops, something went wrong.