Skip to content

Commit

Permalink
add survey test case
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Feb 20, 2025
1 parent 99c91c8 commit 4c28400
Show file tree
Hide file tree
Showing 21 changed files with 817 additions and 70 deletions.
31 changes: 31 additions & 0 deletions nvflare/apis/edge_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellReturnCode


class Status(CellReturnCode):
NO_TASK = "no_task"
NO_JOB = "no_job"


class EdgeProtoKey:
STATUS = "status"
DATA = "data"


class EdgeContextKey:
JOB_ID = "__edge_job_id__"
EDGE_CAPABILITIES = "__edge_capabilities__"
REQUEST_FROM_EDGE = "__request_from_edge__"
REPLY_TO_EDGE = "__reply_to_edge__"
3 changes: 3 additions & 0 deletions nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class EventType(object):
AUTHORIZE_COMMAND_CHECK = "_authorize_command_check"
BEFORE_BUILD_COMPONENT = "_before_build_component"
BEFORE_JOB_LAUNCH = "_before_job_launch"
AFTER_JOB_LAUNCH = "_after_job_launch"

TASK_RESULT_RECEIVED = "_task_result_received"
TASK_ASSIGNMENT_SENT = "_task_assignment_sent"
EDGE_REQUEST_RECEIVED = "_edge_request_received"
EDGE_JOB_REQUEST_RECEIVED = "_edge_job_request_received"
30 changes: 30 additions & 0 deletions nvflare/apis/fl_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

from nvflare.apis.utils.fl_context_utils import generate_log_message
from nvflare.fuel.utils.log_utils import get_obj_logger
Expand All @@ -35,6 +36,7 @@ def __init__(self):
"""
self._name = self.__class__.__name__
self.logger = get_obj_logger(self)
self._event_handlers = {}

@property
def name(self):
Expand Down Expand Up @@ -227,3 +229,31 @@ def _fire_log_event(self, event_type: str, log_tag: str, log_msg: str, fl_ctx: F
dxo = event_data.to_dxo()
fl_ctx.set_prop(key=FLContextKey.EVENT_DATA, value=dxo.to_shareable(), private=True, sticky=False)
self.fire_event(event_type=event_type, fl_ctx=fl_ctx)

def register_event_handler(self, event_types: Union[str, List[str]], handler, **kwargs):
if isinstance(event_types, str):
event_types = [event_types]
elif not isinstance(event_types, list):
raise ValueError(f"event_types must be string or list of strings")

if not callable(handler):
raise ValueError(f"handler {handler.__name__} is not callable")

for e in event_types:
entries = self._event_handlers.get(e)
if not entries:
entries = []
self._event_handlers[e] = entries

already_registered = False
for h, _ in entries:
if handler == h:
# already registered: either by a super class or by the class itself.
already_registered = True
break

if not already_registered:
entries.append((handler, kwargs))

def get_event_handlers(self):
return self._event_handlers
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class FLContextKey(object):
PROCESS_TYPE = ReservedKey.PROCESS_TYPE
JOB_PROCESS_ARGS = ReservedKey.JOB_PROCESS_ARGS
EVENT_PROCESSED = "__event_processed__"
CELL_MESSAGE = "__cell_message__"


class ProcessType:
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/job_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class JobMetaKey(str, Enum):
STATS_POOL_CONFIG = "stats_pool_config"
FROM_HUB_SITE = "from_hub_site"
CUSTOM_PROPS = "custom_props"
EDGE_METHOD = "edge_method"

def __repr__(self):
return self.value
Expand Down
36 changes: 36 additions & 0 deletions nvflare/app_common/aggregators/edge_survey_aggr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator


class EdgeSurveyAggregator(Aggregator):
def __init__(self):
Aggregator.__init__(self)
self.num_devices = 0

def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
self.log_info(fl_ctx, f"accepting: {shareable}")
num_devices = shareable.get("num_devices")
if num_devices:
self.num_devices += num_devices
return True

def reset(self, fl_ctx: FLContext):
self.num_devices = 0

def aggregate(self, fl_ctx: FLContext) -> Shareable:
self.log_info(fl_ctx, f"aggregating final result: {self.num_devices}")
return Shareable({"num_devices": self.num_devices})
45 changes: 45 additions & 0 deletions nvflare/app_common/executors/edge_survey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.app_common.executors.ete import EdgeTaskExecutor


class EdgeSurvey(EdgeTaskExecutor):
def __init__(self, timeout=10.0):
EdgeTaskExecutor.__init__(self)
self.timeout = timeout
self.num_devices = 0
self.start_time = None

def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext):
self.num_devices = 0
self.start_time = time.time()

def is_task_done(self, fl_ctx: FLContext) -> bool:
return time.time() - self.start_time > self.timeout

def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any:
assert isinstance(request, dict)
self.log_info(fl_ctx, f"received edge request: {request}")
self.num_devices += 1
return {"status": "tryAgain", "comment": f"received {request}"}

def get_task_result(self, fl_ctx: FLContext) -> Shareable:
result = make_reply(ReturnCode.OK)
result["num_devices"] = self.num_devices
return result
153 changes: 153 additions & 0 deletions nvflare/app_common/executors/ete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from abc import abstractmethod
from typing import Any

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.security.logging import secure_format_exception


class EdgeTaskExecutor(Executor):
"""This is the base class for executors to handling requests from edge devices.
Subclasses must implement the required abstract methods defined here.
"""

def __init__(self):
"""Constructor of EdgeTaskExecutor"""
Executor.__init__(self)
self.current_task = None

self.register_event_handler(EventType.EDGE_REQUEST_RECEIVED, self._handle_edge_request)

@abstractmethod
def process_edge_request(self, request: Any, fl_ctx: FLContext) -> Any:
"""This is called to process an edge request sent from the edge device.
Args:
request: the request from edge device
fl_ctx: FLContext object
Returns: reply to the edge device
"""
pass

def task_received(self, task_name: str, task_data: Shareable, fl_ctx: FLContext):
"""This method is called when a task assignment is received from the controller.
Subclass can implement this method to prepare for task processing.
Args:
task_name: name of the task
task_data: task data
fl_ctx: FLContext object
Returns: None
"""
pass

@abstractmethod
def is_task_done(self, fl_ctx: FLContext) -> bool:
"""This is called by the base class to determine whether the task processing is done.
Subclass must implement this method.
Args:
fl_ctx: FLContext object
Returns: whether task is done.
"""
pass

@abstractmethod
def get_task_result(self, fl_ctx: FLContext) -> Shareable:
"""This is called by the base class to get the final result of the task.
Base class will send the result to the controller.
Args:
fl_ctx: FLContext object
Returns: a Shareable object that is the task result
"""
pass

def _handle_edge_request(self, event_type: str, fl_ctx: FLContext):
self.log_debug(fl_ctx, f"handling event {event_type}")
if not self.current_task:
self.logger.debug(f"received edge request but I don't have pending task")
return

try:
msg = fl_ctx.get_prop(FLContextKey.CELL_MESSAGE)
assert isinstance(msg, CellMessage)
self.log_info(fl_ctx, f"received edge request: {msg.payload}")
reply = self.process_edge_request(request=msg.payload, fl_ctx=fl_ctx)
fl_ctx.set_prop(FLContextKey.TASK_RESULT, reply, private=True, sticky=False)
except Exception as ex:
self.logger.error(f"exception from self.process_edge_request: {secure_format_exception(ex)}")
fl_ctx.set_prop(FLContextKey.EXCEPTIONS, ex, private=True, sticky=False)

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
self.current_task = shareable
result = self._execute(task_name, shareable, fl_ctx, abort_signal)
self.current_task = None
return result

def _execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
try:
self.task_received(task_name, shareable, fl_ctx)
except Exception as ex:
self.log_error(fl_ctx, f"exception from self.task_received: {secure_format_exception(ex)}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

start_time = time.time()
while True:
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

try:
task_done = self.is_task_done(fl_ctx)
except Exception as ex:
self.log_error(fl_ctx, f"exception from self.is_task_done: {secure_format_exception(ex)}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

if task_done:
break

time.sleep(0.2)

self.log_debug(fl_ctx, f"task done after {time.time()-start_time} seconds")
try:
result = self.get_task_result(fl_ctx)

if not isinstance(result, Shareable):
self.log_error(
fl_ctx,
f"bad result from self.get_task_result: expect Shareable but got {type(result)}",
)
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

except Exception as ex:
self.log_error(fl_ctx, f"exception from self.get_task_result: {secure_format_exception(ex)}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

return result
Loading

0 comments on commit 4c28400

Please sign in to comment.