Skip to content

Commit

Permalink
Support Edge Executor (NVIDIA#3245)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

This PR adds the following features.

- A base class (EdgeTaskExecutor) for developing executors for
processing edge tasks.
- The EdgeTaskDispatcher (ETD), which is installed on the CP and is
responsible for dispatching edge requests to the appropriate CJ. The ETD
determines the right CJ by matching the job's "edge_method" meta property
against the device's capabilities.
- Added some test components for determining active devices from
simulated device activities.
- Added a more elegant way of registering and handling events.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
  • Loading branch information
yanchengnv authored Feb 21, 2025
1 parent 0a8d273 commit 4bb222c
Show file tree
Hide file tree
Showing 29 changed files with 939 additions and 86 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class EventType(object):
AUTHORIZE_COMMAND_CHECK = "_authorize_command_check"
BEFORE_BUILD_COMPONENT = "_before_build_component"
BEFORE_JOB_LAUNCH = "_before_job_launch"
AFTER_JOB_LAUNCH = "_after_job_launch"

TASK_RESULT_RECEIVED = "_task_result_received"
TASK_ASSIGNMENT_SENT = "_task_assignment_sent"
30 changes: 30 additions & 0 deletions nvflare/apis/fl_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

from nvflare.apis.utils.fl_context_utils import generate_log_message
from nvflare.fuel.utils.log_utils import get_obj_logger
Expand All @@ -35,6 +36,7 @@ def __init__(self):
"""
self._name = self.__class__.__name__
self.logger = get_obj_logger(self)
self._event_handlers = {}

@property
def name(self):
Expand Down Expand Up @@ -227,3 +229,31 @@ def _fire_log_event(self, event_type: str, log_tag: str, log_msg: str, fl_ctx: F
dxo = event_data.to_dxo()
fl_ctx.set_prop(key=FLContextKey.EVENT_DATA, value=dxo.to_shareable(), private=True, sticky=False)
self.fire_event(event_type=event_type, fl_ctx=fl_ctx)

def register_event_handler(self, event_types: Union[str, List[str]], handler, **kwargs):
    """Register a handler for one or more event types.

    Each event type maps to a list of ``(handler, kwargs)`` entries kept in
    ``self._event_handlers``. Registering the same handler again for an event
    type is a no-op: the first registration (and its kwargs) is kept.

    Args:
        event_types: a single event type, or a list of event types.
        handler: the callable to register.
        **kwargs: extra keyword arguments stored alongside the handler
            (presumably passed to it when the event fires - the invocation
            side is not part of this method).

    Raises:
        ValueError: if ``event_types`` is neither a str nor a list, or if
            ``handler`` is not callable.
    """
    if isinstance(event_types, str):
        event_types = [event_types]
    elif not isinstance(event_types, list):
        raise ValueError(f"event_types must be string or list of strings but got {type(event_types)}")

    if not callable(handler):
        # Use repr() rather than handler.__name__: objects that are not callable
        # (None, str, int, ...) usually have no __name__, which would turn this
        # validation error into an unrelated AttributeError.
        raise ValueError(f"handler {handler!r} is not callable")

    for e in event_types:
        entries = self._event_handlers.setdefault(e, [])

        # Skip if already registered, either by a super class or by the class itself.
        if not any(handler == h for h, _ in entries):
            entries.append((handler, kwargs))

def get_event_handlers(self):
    """Return the event-handler registry of this component.

    Returns:
        The ``self._event_handlers`` dict itself (not a copy).
    """
    registry = self._event_handlers
    return registry
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class FLContextKey(object):
PROCESS_TYPE = ReservedKey.PROCESS_TYPE
JOB_PROCESS_ARGS = ReservedKey.JOB_PROCESS_ARGS
EVENT_PROCESSED = "__event_processed__"
CELL_MESSAGE = "__cell_message__"


class ProcessType:
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/job_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class JobMetaKey(str, Enum):
STATS_POOL_CONFIG = "stats_pool_config"
FROM_HUB_SITE = "from_hub_site"
CUSTOM_PROPS = "custom_props"
EDGE_METHOD = "edge_method"

def __repr__(self):
# Represent the member by its raw value (self.value).
return self.value
Expand Down
132 changes: 71 additions & 61 deletions nvflare/app_common/executors/ham.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,67 +57,77 @@ def __init__(
self._aggr_lock = threading.Lock()
self._process_error = None

def handle_event(self, event_type: str, fl_ctx: FLContext):
# Single entry point that handles three event types:
#   START_RUN            - look up the aggregator and learner components.
#   TASK_ASSIGNMENT_SENT - record that a task went out to a child client.
#   TASK_RESULT_RECEIVED - accept/aggregate (or reject) a child's result.
# NOTE(review): indentation is not visible in this view; nesting described
# in the comments below is inferred from the statements - confirm.
if event_type == EventType.START_RUN:
engine = fl_ctx.get_engine()

# Resolve the aggregator by its component id; a wrong type is only logged.
aggr = engine.get_component(self.aggregator_id)
if not isinstance(aggr, Aggregator):
self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggr)}")
self.aggregator = aggr

# Resolve the learner by its component id; a wrong type is only logged.
learner = engine.get_component(self.learner_id)
if not isinstance(learner, Executor):
self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got {type(learner)}")
self.learner = learner
elif event_type == EventType.TASK_ASSIGNMENT_SENT:
# the task was sent to a child client
if not self.pending_task_id:
# I don't have a pending task
return

# The peer context identifies the child client the task was sent to.
child_client_ctx = fl_ctx.get_peer_context()
assert isinstance(child_client_ctx, FLContext)
child_client_name = child_client_ctx.get_identity_name()
# Status is updated with None here and with time.time() once a result arrives.
self._update_client_status(child_client_name, None)
task_id = fl_ctx.get_prop(FLContextKey.TASK_ID)

# indicate that this event has been processed by me
fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)
self.log_info(fl_ctx, f"sent task {task_id} to child {child_client_name}")
elif event_type == EventType.TASK_RESULT_RECEIVED:
# received results from a child client
if not self.pending_task_id:
# I don't have a pending task
return

# indicate that this event has been processed by me
fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)

# The result carries its task id in a header; the sender comes from the peer context.
result = fl_ctx.get_prop(FLContextKey.TASK_RESULT)
assert isinstance(result, Shareable)
task_id = result.get_header(ReservedKey.TASK_ID)
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
child_client_name = peer_ctx.get_identity_name()
self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}")

# Results for any task other than the pending one are dropped with a warning.
if task_id != self.pending_task_id:
self.log_warning(
fl_ctx,
f"dropped the received result from child {child_client_name} "
f"for task {task_id} while waiting for task {self.pending_task_id}",
)
return

# Only results with return code OK are aggregated; others are logged as errors.
rc = result.get_return_code(ReturnCode.OK)
if rc == ReturnCode.OK:
self.log_info(fl_ctx, f"accepting result from client {child_client_name}")
self._do_aggregation(result, fl_ctx)
else:
self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}")
# NOTE(review): presumably the next two statements run for both outcomes
# (after the if/else) rather than only in the else branch - confirm.
self.log_info(fl_ctx, f"received result from child {child_client_name}")
self._update_client_status(child_client_name, time.time())
self.register_event_handler(EventType.START_RUN, self._handle_start_run)
self.register_event_handler(EventType.TASK_ASSIGNMENT_SENT, self._handle_task_sent)
self.register_event_handler(EventType.TASK_RESULT_RECEIVED, self._handle_result_received)

def _handle_start_run(self, event_type: str, fl_ctx: FLContext):
    """Look up the aggregator and learner components when the run starts.

    Both components are fetched from the engine by their configured ids.
    A component of the wrong type is reported via ``log_error`` but is still
    stored on ``self``.

    Args:
        event_type: the event being handled (used for debug logging only).
        fl_ctx: the FL context that provides the engine.
    """
    self.log_debug(fl_ctx, f"handling event {event_type}")
    engine = fl_ctx.get_engine()

    aggregator = engine.get_component(self.aggregator_id)
    aggregator_ok = isinstance(aggregator, Aggregator)
    if not aggregator_ok:
        self.log_error(fl_ctx, f"component '{self.aggregator_id}' must be Aggregator but got {type(aggregator)}")
    self.aggregator = aggregator

    learner_component = engine.get_component(self.learner_id)
    learner_ok = isinstance(learner_component, Executor)
    if not learner_ok:
        self.log_error(fl_ctx, f"component '{self.learner_id}' must be Executor but got {type(learner_component)}")
    self.learner = learner_component

def _handle_task_sent(self, event_type: str, fl_ctx: FLContext):
    """Record that a task has been sent to a child client.

    Does nothing beyond debug logging when there is no pending task.
    Otherwise the child's status is updated with ``None`` and the event is
    flagged as processed in the FL context.

    Args:
        event_type: the event being handled (used for debug logging only).
        fl_ctx: the FL context; its peer context identifies the child client.
    """
    self.log_debug(fl_ctx, f"handling event {event_type}")

    has_pending_task = bool(self.pending_task_id)
    if not has_pending_task:
        # nothing is pending on my side, so this event is not for me
        return

    peer_ctx = fl_ctx.get_peer_context()
    assert isinstance(peer_ctx, FLContext)
    child_name = peer_ctx.get_identity_name()
    self._update_client_status(child_name, None)
    sent_task_id = fl_ctx.get_prop(FLContextKey.TASK_ID)

    # flag the event as processed by this component
    fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)
    self.log_info(fl_ctx, f"sent task {sent_task_id} to child {child_name}")

def _handle_result_received(self, event_type: str, fl_ctx: FLContext):
# Handle a task result submitted by a child client: flag the event as
# processed, drop results that are not for the pending task, aggregate
# results whose return code is OK, and log an error for the others.
# NOTE(review): indentation is not visible in this view; nesting described
# in the comments below is inferred from the statements - confirm.
# received results from a child client
self.log_debug(fl_ctx, f"handling event {event_type}")

if not self.pending_task_id:
# I don't have a pending task
return

# indicate that this event has been processed by me
fl_ctx.set_prop(FLContextKey.EVENT_PROCESSED, True, private=True, sticky=False)

# The result carries its task id in a header; the sender comes from the peer context.
result = fl_ctx.get_prop(FLContextKey.TASK_RESULT)
assert isinstance(result, Shareable)
task_id = result.get_header(ReservedKey.TASK_ID)
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
child_client_name = peer_ctx.get_identity_name()
self.log_info(fl_ctx, f"received result for task {task_id} from child {child_client_name}")

# Results for any task other than the pending one are dropped with a warning.
if task_id != self.pending_task_id:
self.log_warning(
fl_ctx,
f"dropped the received result from child {child_client_name} "
f"for task {task_id} while waiting for task {self.pending_task_id}",
)
return

# Only results with return code OK are aggregated; others are logged as errors.
rc = result.get_return_code(ReturnCode.OK)
if rc == ReturnCode.OK:
self.log_info(fl_ctx, f"accepting result from client {child_client_name}")
self._do_aggregation(result, fl_ctx)
else:
self.log_error(fl_ctx, f"Received bad result from client {child_client_name}: {rc=}")
# NOTE(review): presumably the next two statements run for both outcomes
# (after the if/else) rather than only in the else branch - confirm.
self.log_info(fl_ctx, f"received result from child {child_client_name}")
self._update_client_status(child_client_name, time.time())

def _do_aggregation(self, result: Shareable, fl_ctx: FLContext):
with self._aggr_lock:
Expand Down
13 changes: 13 additions & 0 deletions nvflare/edge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions nvflare/edge/aggregators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
36 changes: 36 additions & 0 deletions nvflare/edge/aggregators/edge_survey_aggr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator


class EdgeSurveyAggregator(Aggregator):
    """Aggregator that sums the ``num_devices`` values of accepted results."""

    def __init__(self):
        Aggregator.__init__(self)
        # running total of devices reported since the last reset
        self.num_devices = 0

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Add the contribution's ``num_devices`` (if any) to the running total.

        Args:
            shareable: the contribution; its "num_devices" entry is read.
            fl_ctx: the FL context, used for logging.

        Returns:
            Always True.
        """
        self.log_info(fl_ctx, f"accepting: {shareable}")
        reported = shareable.get("num_devices")
        if reported:
            self.num_devices = self.num_devices + reported
        return True

    def reset(self, fl_ctx: FLContext):
        """Reset the running total to zero."""
        self.num_devices = 0

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Return the running total as a Shareable with key "num_devices"."""
        total = self.num_devices
        self.log_info(fl_ctx, f"aggregating final result: {total}")
        return Shareable({"num_devices": total})
36 changes: 36 additions & 0 deletions nvflare/edge/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.fuel.f3.cellnet.defs import ReturnCode as CellReturnCode


class Status(CellReturnCode):
    """Cell return codes extended with the NO_TASK and NO_JOB values."""

    NO_TASK = "no_task"
    NO_JOB = "no_job"


class EdgeProtoKey:
    """String constants for the "status" and "data" keys."""

    STATUS = "status"
    DATA = "data"


class EdgeContextKey:
    """String constants for edge-related context property names."""

    JOB_ID = "__edge_job_id__"
    EDGE_CAPABILITIES = "__edge_capabilities__"
    REQUEST_FROM_EDGE = "__request_from_edge__"
    REPLY_TO_EDGE = "__reply_to_edge__"


class EventType:
    """String constants naming the edge request events."""

    EDGE_REQUEST_RECEIVED = "_edge_request_received"
    EDGE_JOB_REQUEST_RECEIVED = "_edge_job_request_received"
EDGE_JOB_REQUEST_RECEIVED = "_edge_job_request_received"
13 changes: 13 additions & 0 deletions nvflare/edge/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
59 changes: 59 additions & 0 deletions nvflare/edge/controllers/edge_survey_ctl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal


class EdgeSurveyController(Controller):
    """Controller that repeatedly broadcasts a "survey" task and sums the reported device counts.

    Args:
        num_rounds: number of survey rounds to run.
        timeout: timeout passed to each survey task.
    """

    def __init__(self, num_rounds: int, timeout: int):
        Controller.__init__(self)
        self.num_rounds = num_rounds
        self.timeout = timeout

    def start_controller(self, fl_ctx: FLContext):
        """No setup is needed."""
        pass

    def stop_controller(self, fl_ctx: FLContext):
        """No cleanup is needed."""
        pass

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run the survey rounds.

        Each round broadcasts a "survey" task, waits for it to finish, then
        sums the "num_devices" values found in the client results and logs
        the total.

        Args:
            abort_signal: signal passed through to ``broadcast_and_wait``.
            fl_ctx: the FL context.
        """
        for r in range(self.num_rounds):
            task = Task(
                name="survey",
                data=Shareable(),
                timeout=self.timeout,
            )

            self.broadcast_and_wait(
                task=task,
                min_responses=2,
                wait_time_after_min_received=0,
                fl_ctx=fl_ctx,
                abort_signal=abort_signal,
            )

            total_devices = 0
            for ct in task.client_tasks:
                assert isinstance(ct, ClientTask)
                result = ct.result
                if not isinstance(result, Shareable):
                    # A client task may carry no Shareable result (presumably when
                    # the client did not answer before the task ended - confirm).
                    # Skip it rather than abort the whole survey on an assertion.
                    self.log_warning(
                        fl_ctx,
                        f"no valid result from client {ct.client.name} (got {type(result)}); skipped",
                    )
                    continue

                self.log_info(fl_ctx, f"result from client {ct.client.name}: {result}")
                count = result.get("num_devices")
                if count:
                    total_devices += count

            self.log_info(fl_ctx, f"total devices in round {r}: {total_devices}")
Loading

0 comments on commit 4bb222c

Please sign in to comment.