From bfd7a2d0a736c97389b4c2433878b345fb807b70 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Wed, 22 Jan 2025 00:13:12 -0500 Subject: [PATCH 01/11] Added dict_streaming example --- .../app/config/config_fed_client.json | 24 ++ .../app/config/config_fed_server.json | 20 ++ .../dict_streaming/app/custom/__init__.py | 13 + .../app/custom/app_cmd_controller.py | 202 ++++++++++++++ .../app/custom/app_cmd_executor.py | 132 +++++++++ .../jobs/dict_streaming/app/custom/defs.py | 7 + .../streaming/jobs/dict_streaming/meta.json | 10 + .../app/config/config_fed_client.json | 23 ++ .../app/config/config_fed_server.json | 19 ++ .../file_streaming/app/custom/__init__.py | 13 + .../file_streaming/app/custom/controller.py | 47 ++++ .../app/custom/file_streaming.py | 110 ++++++++ .../jobs/file_streaming/app/custom/trainer.py | 45 ++++ .../streaming/jobs/file_streaming/meta.json | 10 + .../statistics/json_stats_file_persistor.py | 4 +- .../streamers/container_retriever.py | 144 ++++++++++ .../streamers/container_streamer.py | 255 ++++++++++++++++++ nvflare/fuel/utils/class_utils.py | 25 +- nvflare/fuel/utils/fobs/__init__.py | 2 + nvflare/fuel/utils/fobs/fobs.py | 48 ++-- nvflare/fuel/utils/wfconf.py | 4 +- nvflare/private/json_configer.py | 4 +- 22 files changed, 1120 insertions(+), 41 deletions(-) create mode 100755 examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json create mode 100755 examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json create mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py create mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py create mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py create mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py create mode 100644 examples/advanced/streaming/jobs/dict_streaming/meta.json create mode 100755 examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json create mode 100755 examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json create mode 100644 examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py create mode 100644 examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py create mode 100644 examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py create mode 100644 examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py create mode 100644 examples/advanced/streaming/jobs/file_streaming/meta.json create mode 100644 nvflare/app_common/streamers/container_retriever.py create mode 100644 nvflare/app_common/streamers/container_streamer.py diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json new file mode 100755 index 0000000000..9f0c4bb890 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json @@ -0,0 +1,24 @@ +{ + "format_version": 2, + "cell_wait_timeout": 5.0, + "executors": [ + { + "tasks": ["*"], + "executor": { + "path": "app_cmd_executor.AppCmdExecutor", + "args": { + "file_retriever_id": "file_retriever" + } + } + } + ], + "components": [ + { + "id": "file_retriever", + "path": "nvflare.app_common.streamers.file_retriever.FileRetriever", + "args": { + "source_dir": "/tmp" + } + } + ] +} \ No newline at end of file diff --git 
a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json new file mode 100755 index 0000000000..7815cb3938 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json @@ -0,0 +1,20 @@ +{ + "format_version": 2, + "components": [ + { + "id": "file_retriever", + "path": "nvflare.app_common.streamers.file_retriever.FileRetriever", + "args": { + "source_dir": "/tmp" + } + } + ], + "workflows": [ + { + "id": "controller", + "path": "app_cmd_controller.AppCommandController", + "args": { + } + } + ] +} \ No newline at end of file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py new file mode 100644 index 0000000000..d5163a4abb --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py @@ -0,0 +1,202 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import time + +from nvflare.apis.controller_spec import Client, ClientTask, Task +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.file_streamer import FileStreamer, StreamContext + +from .defs import STREAM_CHANNEL, TOPIC_INITIAL_FILE + + +class AppCommandController(Controller): + def __init__(self, cmd_timeout=2, task_check_period: float = 0.5): + Controller.__init__(self, task_check_period=task_check_period) + self.cmd_timeout = cmd_timeout + self.app_done = False + self.abort_signal = None + + def start_controller(self, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + engine.register_app_command( + topic="hello", + cmd_func=self.handle_hello, + ) + engine.register_app_command( + topic="avg", + cmd_func=self.handle_avg, + ) + engine.register_app_command( + topic="bye", + cmd_func=self.handle_bye, + ) + engine.register_app_command( + topic="echo", + cmd_func=self.handle_echo, + ) + engine.register_app_command( + topic="stream_file", + cmd_func=self.handle_stream_file_cmd, + ) + engine.register_app_command( + topic="rtr_file", + cmd_func=self.handle_rtr_file_cmd, + ) + + FileStreamer.register_stream_processing( + fl_ctx, STREAM_CHANNEL, "*", stream_status_cb=self._file_received, file_type="echo" + ) + + def _file_received( + self, + stream_ctx: StreamContext, + fl_ctx: FLContext, + file_type: str, + ): + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + peer_name = peer_ctx.get_identity_name() + self.log_info(fl_ctx, f"stream file received from {peer_name}: {stream_ctx=} {file_type=}") + + def stop_controller(self, fl_ctx: FLContext): + self.app_done = True + + def handle_stream_file_cmd(self, topic: str, data, fl_ctx: FLContext) -> dict: + full_file_name = data + result = FileStreamer.stream_file( + channel=STREAM_CHANNEL, + topic=TOPIC_INITIAL_FILE, + targets=[], + file_name=full_file_name, + fl_ctx=fl_ctx, + stream_ctx={"cmd_topic": topic}, + ) + return {"result": result} + + def handle_rtr_file_cmd(self, topic: str, data, fl_ctx: FLContext) -> dict: + self.log_info(fl_ctx, f"handle command: {topic=}") + s = Shareable() + s["file_name"] = data + task = Task(name="rtr_file", data=s, timeout=self.cmd_timeout) + self.broadcast_and_wait( + task=task, + fl_ctx=fl_ctx, + min_responses=2, + abort_signal=self.abort_signal, + ) + client_resps = {} + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + resp = ct.result + if resp is None: + resp = "no answer" + else: + assert isinstance(resp, Shareable) + self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") + resp = resp.get_return_code() + client_resps[ct.client.name] = resp + return {"status": "OK", "data": client_resps} + + def handle_echo(self, topic: str, data, fl_ctx: FLContext) -> dict: + engine = fl_ctx.get_engine() + clients = engine.get_clients() + reqs = {} + for c in clients: + r = Shareable() + r["data"] = c.name + reqs[c.name] = r + replies = engine.multicast_aux_requests( + topic="echo", + target_requests=reqs, + timeout=self.cmd_timeout, + fl_ctx=fl_ctx, + ) + result = {} + if replies: + for k, s in replies.items(): + assert isinstance(s, Shareable) + result[k] = s.get("data", "no data") + return result + + def handle_bye(self, topic: str, data, fl_ctx: FLContext) -> dict: + self.app_done = True + return {"status": "OK"} + + def handle_hello(self, topic: str, data, fl_ctx: FLContext) -> dict: + 
self.log_info(fl_ctx, f"handle command: {topic=}") + s = Shareable() + s["data"] = data + task = Task(name="hello", data=s, timeout=self.cmd_timeout) + self.broadcast_and_wait( + task=task, + fl_ctx=fl_ctx, + min_responses=2, + abort_signal=self.abort_signal, + ) + client_resps = {} + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + resp = ct.result + if resp is None: + resp = "no answer" + else: + self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") + resp = resp.get("data") + if not resp: + resp = "greetings!" + client_resps[ct.client.name] = resp + return {"status": "OK", "data": client_resps} + + def handle_avg(self, topic: str, data, fl_ctx: FLContext) -> dict: + s = Shareable() + s["data"] = data + task = Task(name="avg", data=s, timeout=self.cmd_timeout) + self.broadcast_and_wait( + task=task, + fl_ctx=fl_ctx, + min_responses=2, + abort_signal=self.abort_signal, + ) + client_resps = {} + total = 0.0 + count = 0 + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + resp = ct.result + if resp is None: + resp = 0.0 + else: + self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") + resp = resp.get("data") + if not resp: + resp = 0.0 + else: + total += resp + count += 1 + client_resps[ct.client.name] = resp + client_resps["avg"] = 0.0 if count == 0 else total / count + return {"status": "OK", "data": client_resps} + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + self.abort_signal = abort_signal + while not abort_signal.triggered and not self.app_done: + time.sleep(1.0) + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + pass diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py new file mode 100644 index 0000000000..8b0ae97024 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py @@ -0,0 +1,132 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import random + +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.file_retriever import FileRetriever +from nvflare.app_common.streamers.file_streamer import FileStreamer, StreamContext + +from .defs import STREAM_CHANNEL, TOPIC_ECHO_FILE, TOPIC_INITIAL_FILE + + +class AppCmdExecutor(Executor): + def __init__(self, file_retriever_id=None): + Executor.__init__(self) + self.file_retriever_id = file_retriever_id + self.file_retriever = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + engine.register_aux_message_handler( + topic="echo", + message_handle_func=self._handle_echo, + ) + FileStreamer.register_stream_processing( + fl_ctx, STREAM_CHANNEL, TOPIC_INITIAL_FILE, stream_status_cb=self._file_received, file_type="initial" + ) + FileStreamer.register_stream_processing( + fl_ctx, STREAM_CHANNEL, TOPIC_ECHO_FILE, stream_status_cb=self._file_received, file_type="echo" + ) + + if self.file_retriever_id: + c = engine.get_component(self.file_retriever_id) + if not isinstance(c, FileRetriever): + self.system_panic( + f"invalid file_retriever {self.file_retriever_id}: expect FileRetriever but got {type(c)}", + fl_ctx, + ) + return + self.file_retriever = c + + def _file_received( + self, + stream_ctx: StreamContext, + fl_ctx: FLContext, + file_type: str, + ): + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + peer_name = peer_ctx.get_identity_name() + channel = FileStreamer.get_channel(stream_ctx) + topic = FileStreamer.get_topic(stream_ctx) + rc = FileStreamer.get_rc(stream_ctx) + self.log_info(fl_ctx, f"file received from {peer_name}: {stream_ctx=} {file_type=} {channel=} {topic=} {rc=}") + file_location = FileStreamer.get_file_location(stream_ctx) + if file_type == "initial": + # send the file back to everyone + self.log_info(fl_ctx, f"echo file to all: {file_location}") + streamed = FileStreamer.stream_file( + channel=STREAM_CHANNEL, + topic=TOPIC_ECHO_FILE, + targets="@ALL", + file_name=file_location, + fl_ctx=fl_ctx, + stream_ctx={"file_type": file_type}, + ) + self.log_info(fl_ctx, f"streamed echo file to all sites: {streamed}") + + def _handle_echo(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + data = request.get("data") + s = Shareable() + s["data"] = data + return s + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.log_info(fl_ctx, f"got task {task_name}: {shareable}") + if task_name == "hello": + data = shareable.get("data") + s = Shareable() + s["data"] = data + return s + elif task_name == "avg": + data = shareable.get("data") + + self.log_info(fl_ctx, f"got avg request: {shareable}") + + start = data.get("start", 0) + end = data.get("end", 0) + v = random.randint(start, end) + result = Shareable() + result["data"] = v + return result + elif task_name == "rtr_file": + file_name = shareable.get("file_name") + if not file_name: + self.log_error(fl_ctx, "missing file name in request") + return make_reply(ReturnCode.BAD_TASK_DATA) + if not self.file_retriever: + self.log_error(fl_ctx, "no file retriever") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + assert isinstance(self.file_retriever, 
FileRetriever) + rc, location = self.file_retriever.retrieve_file( + from_site="server", + fl_ctx=fl_ctx, + timeout=10.0, + file_name=file_name, + ) + if rc != ReturnCode.OK: + self.log_error(fl_ctx, f"failed to retrieve file {file_name}: {rc}") + return make_reply(rc) + self.log_info(fl_ctx, f"received file {location}") + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"got unknown task {task_name}") + return make_reply(ReturnCode.TASK_UNKNOWN) diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py new file mode 100644 index 0000000000..ca3d3541fd --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py @@ -0,0 +1,7 @@ +STREAM_CHANNEL = "file_stream" +TOPIC_INITIAL_FILE = "initial_file" +TOPIC_ECHO_FILE = "echo_file" + +FILE_RTR_REQUEST_TOPIC = "rtr_file" +FILE_RTR_STREAM_CHANNEL = "rtr_file_stream" +FILE_RTR_STREAM_TOPIC = "rtr_file_stream" diff --git a/examples/advanced/streaming/jobs/dict_streaming/meta.json b/examples/advanced/streaming/jobs/dict_streaming/meta.json new file mode 100644 index 0000000000..0fcb99272c --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/meta.json @@ -0,0 +1,10 @@ +{ + "name": "file_streaming", + "resource_spec": {}, + "min_clients" : 1, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json new file mode 100755 index 0000000000..5ac09cbb4f --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json @@ -0,0 +1,23 @@ +{ + "format_version": 2, + "executors": [ + { + "tasks": [ + "train" + ], + "executor": { + "path": "trainer.TestTrainer", + "args": {} + } + } + ], + "task_result_filters": [], + "task_data_filters": [], + "components": [ + { + "id": "sender", + "path": "file_streaming.FileSender", + "args": {} + } + ] +} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json new file mode 100755 index 0000000000..1c0be95c54 --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json @@ -0,0 +1,19 @@ +{ + "format_version": 2, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "receiver", + "path": "file_streaming.FileReceiver", + "args": {} + } + ], + "workflows": [ + { + "id": "controller", + "path": "controller.SimpleController", + "args": {} + } + ] +} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py new file mode 100644 index 0000000000..7b13676bb0 --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from nvflare.apis.client import Client +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + +logger = logging.getLogger(__name__) + + +class SimpleController(Controller): + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + logger.info(f"Entering control loop of {self.__class__.__name__}") + engine = fl_ctx.get_engine() + receiver = engine.get_component("receiver") + while not receiver.is_done(): + time.sleep(0.2) + + logger.info("Control flow ends") + + def start_controller(self, fl_ctx: FLContext): + logger.info("Start controller") + + def stop_controller(self, fl_ctx: FLContext): + logger.info("Stop controller") + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + raise RuntimeError(f"Unknown task: {task_name} from client {client.name}.") diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py new file mode 100644 index 0000000000..b23bf6e15f --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import tempfile +import time +from threading import Thread + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.streamers.file_streamer import FileStreamer + +CHANNEL = "_test_channel" +TOPIC = "_test_topic" +SIZE = 100*1024*1024 # 100 MB + + +class FileSender(FLComponent): + + def __init__(self): + super().__init__() + self.seq = 0 + self.aborted = False + self.file_name = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.log_info(fl_ctx, "FileSender is started") + Thread(target=self._sending_file, args=(fl_ctx,), daemon=True).start() + elif event_type == EventType.ABORT_TASK: + self.log_info(fl_ctx, "Sender is aborted") + self.aborted = True + + def _sending_file(self, fl_ctx): + + # Create a temp file to send + tmp = tempfile.NamedTemporaryFile(delete=False) + try: + buf = bytearray(SIZE) + for i in range(len(buf)): + buf[i] = i % 256 + + tmp.write(buf) + finally: + tmp.close() + + self.file_name = tmp.name + + rc, result = FileStreamer.stream_file( + targets=["server"], + stream_ctx=None, + channel=CHANNEL, + topic=TOPIC, + file_name=self.file_name, + fl_ctx=fl_ctx, + optional=False, + secure=False, + ) + + self.log_info(fl_ctx, f"Sending finished with RC: {rc}") + os.remove(self.file_name) + + +class FileReceiver(FLComponent): + + def __init__(self): + super().__init__() + self.done = False + + def is_done(self): + return self.done + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._receive_file(fl_ctx) + self.log_info(fl_ctx, "FileReceiver is started") + + def _receive_file(self, fl_ctx): + FileStreamer.register_stream_processing( + fl_ctx=fl_ctx, + channel=CHANNEL, + topic=TOPIC, + stream_done_cb=self._done_cb, + ) + + def _done_cb(self, stream_ctx: dict, fl_ctx: FLContext): + self.log_info(fl_ctx, "File streaming is done") + self.done = True + + file_name = FileStreamer.get_file_location(stream_ctx) + file_size = FileStreamer.get_file_size(stream_ctx) + size = os.path.getsize(file_name) + + if size == file_size: + self.log_info(fl_ctx, f"File {file_name} has correct size {size} bytes") + else: + self.log_error(fl_ctx, f"File {file_name} sizes mismatch {size} <> {file_size} bytes") + + os.remove(file_name) diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py new file mode 100644 index 0000000000..11d1d18d3b --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time + +from nvflare.apis.dxo import DXO, DataKind +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + + +class TestTrainer(Executor): + def __init__(self): + super().__init__() + self.aborted = False + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.ABORT_TASK: + self.log_info(fl_ctx, "Trainer is aborted") + self.aborted = True + + def execute( + self, + task_name: str, + shareable: Shareable, + fl_ctx: FLContext, + abort_signal: Signal, + ) -> Shareable: + # This is a dummy executor which does nothing + self.log_info(fl_ctx, f"Executor is called with task {task_name}") + dxo = DXO(data_kind=DataKind.WEIGHTS, data={}) + return dxo.to_shareable() diff --git a/examples/advanced/streaming/jobs/file_streaming/meta.json b/examples/advanced/streaming/jobs/file_streaming/meta.json new file mode 100644 index 0000000000..0fcb99272c --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/meta.json @@ -0,0 +1,10 @@ +{ + "name": "file_streaming", + "resource_spec": {}, + "min_clients" : 1, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/nvflare/app_common/statistics/json_stats_file_persistor.py b/nvflare/app_common/statistics/json_stats_file_persistor.py index bf56fe3a6a..af731087df 100644 --- a/nvflare/app_common/statistics/json_stats_file_persistor.py +++ b/nvflare/app_common/statistics/json_stats_file_persistor.py @@ -19,7 +19,7 @@ from nvflare.apis.storage import StorageException from nvflare.app_common.abstract.statistics_writer import StatisticsWriter from nvflare.app_common.utils.json_utils import ObjectEncoder -from nvflare.fuel.utils.class_utils import get_class +from nvflare.fuel.utils.fobs import load_class class JsonStatsFileWriter(StatisticsWriter): @@ -34,7 +34,7 @@ def __init__(self, output_path: str, json_encoder_path: str = ""): self.json_encoder_class = ObjectEncoder else: self.json_encoder_path = json_encoder_path - self.json_encoder_class = get_class(json_encoder_path) + self.json_encoder_class = load_class(json_encoder_path) def save( self, diff --git a/nvflare/app_common/streamers/container_retriever.py b/nvflare/app_common/streamers/container_retriever.py new file mode 100644 index 0000000000..30ad03d69d --- /dev/null +++ b/nvflare/app_common/streamers/container_retriever.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
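+
+# Usage sketch (illustrative; the component id, container name, and site name
+# below are placeholders, not part of this module): the site that owns the data
+# registers a container under a name, and a peer retrieves it by that name.
+#
+#     retriever = ContainerRetriever()   # configured as a component on both sides
+#     retriever.add_container("model_dict", {"w1": 0.1, "w2": 0.2})   # on the owning site
+#     ...
+#     # on the requesting site:
+#     rc, data = retriever.retrieve_container(
+#         from_site="server", fl_ctx=fl_ctx, timeout=10.0, name="model_dict")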
+import os
+from typing import Any
+
+from nvflare.apis.fl_component import FLComponent
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import ReturnCode, Shareable
+from nvflare.apis.streaming import StreamContext
+from .container_streamer import ContainerStreamer
+
+from .file_streamer import FileStreamer
+from .object_retriever import ObjectRetriever
+
+
+class ContainerRetriever(ObjectRetriever):
+    def __init__(
+        self,
+        topic: str = None,
+        stream_msg_optional=False,
+        stream_msg_secure=False,
+        entry_timeout=None,
+    ):
+        ObjectRetriever.__init__(self, topic)
+        self.stream_msg_optional = stream_msg_optional
+        self.stream_msg_secure = stream_msg_secure
+        self.entry_timeout = entry_timeout
+        self.containers = {}
+
+    def add_container(self, name: str, container: Any):
+        """Add a container to the retriever. This must be called on the site that owns the container
+
+        Args:
+            name: name for the container.
+            container: The container to be streamed
+        """
+        self.containers[name] = container
+
+    def register_stream_processing(
+        self,
+        channel: str,
+        topic: str,
+        fl_ctx: FLContext,
+        stream_done_cb,
+        **cb_kwargs,
+    ):
+        """Called on the stream receiving side to register for processing of the streamed container.
+
+        Args:
+            channel: the app channel
+            topic: the app topic
+            fl_ctx: the FLContext object
+            stream_done_cb: the callback to be called when streaming is done
+            **cb_kwargs: the kwargs for the stream_done_cb
+
+        Returns: None
+
+        """
+        ContainerStreamer.register_stream_processing(
+            channel=channel,
+            topic=topic,
+            fl_ctx=fl_ctx,
+            stream_done_cb=stream_done_cb,
+            **cb_kwargs,
+        )
+
+    def validate_request(self, request: Shareable, fl_ctx: FLContext) -> (str, Any):
+        name = request.get("name")
+        if not name:
+            self.log_error(fl_ctx, "bad request: missing container name")
+            return ReturnCode.BAD_REQUEST_DATA, None
+
+        container = self.containers.get(name, None)
+        if not container:
+            self.log_error(fl_ctx, f"bad request: requested container {name} doesn't exist")
+            return ReturnCode.BAD_REQUEST_DATA, None
+
+        return ReturnCode.OK, container
+
+    def retrieve_container(self, from_site: str, fl_ctx: FLContext, timeout: float, name: str) -> (str, Any):
+        """Retrieve a container from the specified site.
+        This method is to be called by the app.
+
+        Args:
+            from_site: the site that has the container to be retrieved
+            fl_ctx: FLContext object
+            timeout: how long to wait for the container
+            name: name of the container
+
+        Returns: a tuple of (ReturnCode, container)
+
+        """
+        return self.retrieve(from_site=from_site, fl_ctx=fl_ctx, timeout=timeout, name=name)
+
+    def do_stream(
+        self, target: str, request: Shareable, fl_ctx: FLContext, stream_ctx: StreamContext, validated_data: Any
+    ):
+        """Stream the container to the peer.
+        Called on the stream sending side.
+
+        Args:
+            target: the receiving site
+            request: data to be sent
+            fl_ctx: FLContext object
+            stream_ctx: the stream context
+            validated_data: the container returned from the validate_request method
+
+        Returns:
+
+        """
+        ContainerStreamer.stream_container(
+            targets=[target],
+            stream_ctx=stream_ctx,
+            channel=self.stream_channel,
+            topic=self.topic,
+            container=validated_data,
+            fl_ctx=fl_ctx,
+            optional=self.stream_msg_optional,
+            secure=self.stream_msg_secure,
+        )
+
+    def get_result(self, stream_ctx: StreamContext) -> (str, Any):
+        """Called on the stream receiving side.
+        Get the final result of the streaming.
+        The result is the retrieved container.
+
+        Args:
+            stream_ctx: the StreamContext
+
+        Returns: a tuple of (ReturnCode, container)
+
+        """
+        return ContainerStreamer.get_rc(stream_ctx), ContainerStreamer.get_container(stream_ctx)
diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py
new file mode 100644
index 0000000000..74588443d2
--- /dev/null
+++ b/nvflare/app_common/streamers/container_streamer.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Tuple
+
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
+from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext
+from nvflare.fuel.utils.log_utils import get_obj_logger
+from nvflare.fuel.utils.validation_utils import check_positive_number
+
+from nvflare.app_common.streamers.streamer_base import StreamerBase
+from nvflare.fuel.utils.fobs import get_class_name, load_class
+
+_PREFIX = "ContainerStreamer."
+
+# Keys for StreamCtx
+_CTX_TYPE = _PREFIX + "type"
+_CTX_SIZE = _PREFIX + "size"
+_CTX_RESULT = _PREFIX + "result"
+
+# Keys for Shareable
+_KEY_ENTRY = _PREFIX + "entry"
+_KEY_LAST = _PREFIX + "last"
+
+
+class _EntryConsumer(ObjectConsumer):
+    def __init__(self, stream_ctx: StreamContext):
+        self.logger = get_obj_logger(self)
+        container_type = stream_ctx.get(_CTX_TYPE)
+        container_class = load_class(container_type)
+        self.container = container_class()
+        self.size = stream_ctx.get(_CTX_SIZE)
+
+    def consume(
+        self,
+        shareable: Shareable,
+        stream_ctx: StreamContext,
+        fl_ctx: FLContext,
+    ) -> Tuple[bool, Shareable]:
+
+        entry = shareable.get(_KEY_ENTRY)
+        try:
+            if isinstance(self.container, dict):
+                key, value = entry
+                self.container[key] = value
+            else:
+                self.container.append(entry)
+        except Exception as ex:
+            error = f"Unable to add entry ({type(entry)}) to container ({type(self.container)}): {ex}"
+            self.logger.error(error)
+            raise ValueError(error)
+
+        last = shareable.get(_KEY_LAST)
+        if last:
+            # Check if all entries are added
+            if self.size != len(self.container):
+                err = f"Container size {len(self.container)} does not match expected size {self.size}"
+                self.logger.error(err)
+                raise ValueError(err)
+            else:
+                stream_ctx[_CTX_RESULT] = self.container
+                return False, make_reply(ReturnCode.OK)
+        else:
+            # continue streaming
+            return True, make_reply(ReturnCode.OK)
+
+    def finalize(self, stream_ctx: StreamContext, fl_ctx: FLContext):
+        self.logger.debug(f"Container streaming is done for container type {type(self.container)}")
+
+
+class _EntryConsumerFactory(ConsumerFactory):
+
+    def get_consumer(self, stream_ctx: StreamContext, fl_ctx: FLContext) -> ObjectConsumer:
+        return _EntryConsumer(stream_ctx)
+
+
+class _EntryProducer(ObjectProducer):
+    def __init__(self, container, entry_timeout):
+        self.logger = get_obj_logger(self)
+        if not container:
+            error = "Can't stream empty container"
+            self.logger.error(error)
+            raise ValueError(error)
+        self.container = container
+        if isinstance(container, dict):
+            self.iterator = iter(container.items())
+        else:
+            self.iterator = iter(container)
+        self.entry_timeout = entry_timeout
+        self.last = False
+        self.next = None
+
+    def produce(
+        self,
+        stream_ctx: StreamContext,
+        fl_ctx: FLContext,
+    ) -> Tuple[Shareable, float]:
+
+        # To check if this is the last entry, need to get one entry ahead
+        if self.next:
+            entry = self.next
+        else:
+            entry = next(self.iterator)
+
+        try:
+            self.next = next(self.iterator)
+            self.last = False
+        except StopIteration:
+            self.last = True
+
+        result = Shareable()
+        result[_KEY_ENTRY] = entry
+        result[_KEY_LAST] = self.last
+        return result, self.entry_timeout
+
+    def process_replies(
+        self,
+        replies: Dict[str, Shareable],
+        stream_ctx: StreamContext,
+        fl_ctx: FLContext,
+    ) -> Any:
+        has_error = False
+        for target, reply in replies.items():
+            rc = reply.get_return_code(ReturnCode.OK)
+            if rc != ReturnCode.OK:
+                self.logger.error(f"error from target {target}: {rc}")
+                has_error = True
+
+        if has_error:
+            # done - failed
+            return False
+        elif self.last:
+            # done - succeeded
+            return True
+        else:
+            # not done yet - continue streaming
+            return None
+
+
+class ContainerStreamer(StreamerBase):
+    @staticmethod
+    def register_stream_processing(
+        fl_ctx: FLContext,
+        channel: str,
+        topic: str,
+        stream_done_cb=None,
+        **cb_kwargs,
+    ):
+        """Register for stream processing on the receiving side.
+
+        Args:
+            fl_ctx: the FLContext object
+            channel: the app channel
+            topic: the app topic
+            stream_done_cb: if specified, the callback to be called when the container is completely received
+            **cb_kwargs: the kwargs for the stream_done_cb
+
+        Returns: None
+
+        Notes: the stream_done_cb must follow stream_done_cb_signature as defined in apis.streaming.
+
+        """
+
+        engine = fl_ctx.get_engine()
+        if not isinstance(engine, StreamableEngine):
+            raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")
+
+        engine.register_stream_processing(
+            channel=channel,
+            topic=topic,
+            factory=_EntryConsumerFactory(),
+            stream_done_cb=stream_done_cb,
+            **cb_kwargs,
+        )
+
+    @staticmethod
+    def stream_container(
+        channel: str,
+        topic: str,
+        stream_ctx: StreamContext,
+        targets: List[str],
+        container: Any,
+        fl_ctx: FLContext,
+        entry_timeout=None,
+        optional=False,
+        secure=False,
+    ) -> bool:
+        """Stream a container to one or more targets.
+
+        Args:
+            channel: the app channel
+            topic: the app topic
+            stream_ctx: context data of the stream
+            targets: targets that the container will be sent to
+            container: container to be streamed
+            fl_ctx: a FLContext object
+            entry_timeout: timeout for each entry sent to targets.
+            optional: whether the stream message is optional
+            secure: whether P2P security is required
+
+        Returns: whether the streaming completed successfully
+
+        Notes: this is a blocking call - only returns after the streaming is done.
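+
+        Example (illustrative; channel/topic names and the done_cb are placeholders):
+
+            # the receiving side registers a consumer first
+            ContainerStreamer.register_stream_processing(
+                fl_ctx=fl_ctx, channel="demo_channel", topic="demo_topic", stream_done_cb=done_cb
+            )
+            # the sending side then streams the container entry by entry (blocking call)
+            ok = ContainerStreamer.stream_container(
+                channel="demo_channel", topic="demo_topic", stream_ctx={}, targets=["server"],
+                container={"a": 1, "b": 2}, fl_ctx=fl_ctx,
+            )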
+ """ + if not entry_timeout: + entry_timeout = 60.0 + check_positive_number("chunk_timeout", entry_timeout) + + producer = _EntryProducer(container, entry_timeout) + engine = fl_ctx.get_engine() + + if not isinstance(engine, StreamableEngine): + raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}") + + if not stream_ctx: + stream_ctx = {} + + stream_ctx[_CTX_TYPE] = get_class_name(container) + stream_ctx[_CTX_SIZE] = len(container) + + return engine.stream_objects( + channel=channel, + topic=topic, + stream_ctx=stream_ctx, + targets=targets, + producer=producer, + fl_ctx=fl_ctx, + optional=optional, + secure=secure, + ) + + @staticmethod + def get_container(stream_ctx: StreamContext) -> Any: + """Get the received container + This method is intended to be used by the stream_done_cb() function of the receiving side. + + Args: + stream_ctx: the stream context + + Returns: The received container + + """ + return stream_ctx.get(_CTX_RESULT) diff --git a/nvflare/fuel/utils/class_utils.py b/nvflare/fuel/utils/class_utils.py index e82162ada2..36f52593ea 100644 --- a/nvflare/fuel/utils/class_utils.py +++ b/nvflare/fuel/utils/class_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import importlib import inspect import pkgutil @@ -20,28 +19,13 @@ from nvflare.apis.fl_component import FLComponent from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.utils.components_utils import create_classes_table_static +from nvflare.fuel.utils.fobs import load_class from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception DEPRECATED_PACKAGES = ["nvflare.app_common.pt", "nvflare.app_common.homomorphic_encryption"] -def get_class(class_path): - module_name, class_name = class_path.rsplit(".", 1) - - try: - module_ = importlib.import_module(module_name) - - try: - class_ = getattr(module_, class_name) - except AttributeError: - raise ValueError("Class {} does not exist".format(class_path)) - except AttributeError: - raise ValueError("Module {} does not exist".format(class_path)) - - return class_ - - def instantiate_class(class_path, init_params): """Method for creating an instance for the class. @@ -51,7 +35,7 @@ def instantiate_class(class_path, init_params): arguments. The transform name will be appended to `medical.common.transforms` to make a full name of the transform to be built. """ - c = get_class(class_path) + c = load_class(class_path) try: if init_params: instance = c(**init_params) @@ -80,7 +64,7 @@ def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=T self._class_table = create_classes_table_static() def create_classes_table(self): - class_table: Dict[str, str] = {} + class_table: Dict[str, list[str]] = {} for base in self.base_pkgs: package = importlib.import_module(base) @@ -123,7 +107,8 @@ def get_module_name(self, class_name) -> Optional[str]: """ if class_name not in self._class_table: raise ConfigError( - f"Cannot find class '{class_name}'. Please check its spelling. If the spelling is correct, specify the class using its full path." + f"Cannot find class '{class_name}'. Please check its spelling. If the spelling is correct, " + "specify the class using its full path." 
) modules = self._class_table.get(class_name, None) diff --git a/nvflare/fuel/utils/fobs/__init__.py b/nvflare/fuel/utils/fobs/__init__.py index b9a54ef75b..8d238d32ea 100644 --- a/nvflare/fuel/utils/fobs/__init__.py +++ b/nvflare/fuel/utils/fobs/__init__.py @@ -16,6 +16,8 @@ auto_register_enum_types, deserialize, deserialize_stream, + get_class_name, + load_class, num_decomposers, register, register_data_classes, diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index 7a69ea17e7..fb2293b3f4 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -27,6 +27,8 @@ from nvflare.fuel.utils.fobs.decomposer import DataClassDecomposer, Decomposer, EnumTypeDecomposer __all__ = [ + "get_class_name", + "load_class", "register", "register_data_classes", "register_enum_types", @@ -58,25 +60,41 @@ _data_auto_registration = True -def _get_type_name(cls: Type) -> str: +def get_class_name(cls: Type) -> str: + """Get canonical class path or fully qualified name. The builtins module is removed + so common builtin class can be referenced with its normal name + + Args: + cls: The class type + Returns: + The canonical name + """ module = cls.__module__ if module == "builtins": return cls.__qualname__ return module + "." + cls.__qualname__ -def _load_class(type_name: str): +def load_class(class_path): + """Load class from fully qualified class name + + Args: + class_path: fully qualified class name + Returns: + The class type + """ + try: - if "." in type_name: - module_name, class_name = type_name.rsplit(".", 1) + if "." in class_path: + module_name, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, class_name) else: - return getattr(builtins, type_name) + return getattr(builtins, class_path) except Exception as ex: - raise TypeError(f"Can't load class {type_name}: {ex}") - - + raise TypeError(f"Can't load class {class_path}: {ex}") + + def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: """Register a decomposer. 
It does nothing if decomposer is already registered for the type @@ -91,7 +109,7 @@ def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: else: instance = decomposer - name = _get_type_name(instance.supported_type()) + name = get_class_name(instance.supported_type()) if name in _decomposers: return @@ -105,15 +123,15 @@ def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: class Packer: def __init__(self, manager: DatumManager): self.manager = manager - self.enum_decomposer_name = _get_type_name(EnumTypeDecomposer) - self.data_decomposer_name = _get_type_name(DataClassDecomposer) + self.enum_decomposer_name = get_class_name(EnumTypeDecomposer) + self.data_decomposer_name = get_class_name(DataClassDecomposer) def pack(self, obj: Any) -> dict: if type(obj) in MSGPACK_TYPES: return obj - type_name = _get_type_name(obj.__class__) + type_name = get_class_name(obj.__class__) if type_name not in _decomposers: registered = False if isinstance(obj, Enum): @@ -136,7 +154,7 @@ def pack(self, obj: Any) -> dict: if self.manager: decomposed = self.manager.externalize(decomposed) - return {FOBS_TYPE: type_name, FOBS_DATA: decomposed, FOBS_DECOMPOSER: _get_type_name(type(decomposer))} + return {FOBS_TYPE: type_name, FOBS_DATA: decomposed, FOBS_DECOMPOSER: get_class_name(type(decomposer))} def unpack(self, obj: Any) -> Any: @@ -147,7 +165,7 @@ def unpack(self, obj: Any) -> Any: if type_name not in _decomposers: registered = False decomposer_name = obj.get(FOBS_DECOMPOSER) - cls = _load_class(type_name) + cls = load_class(type_name) if not decomposer_name: # Maintaining backward compatibility with auto enum registration if _enum_auto_registration: @@ -155,7 +173,7 @@ def unpack(self, obj: Any) -> Any: register_enum_types(cls) registered = True else: - decomposer_class = _load_class(decomposer_name) + decomposer_class = load_class(decomposer_name) if decomposer_name == self.enum_decomposer_name or decomposer_name == self.data_decomposer_name: # Generic decomposer's __init__ takes the target class as argument decomposer = decomposer_class(cls) diff --git a/nvflare/fuel/utils/wfconf.py b/nvflare/fuel/utils/wfconf.py index f2356c5469..43e698b5e3 100644 --- a/nvflare/fuel/utils/wfconf.py +++ b/nvflare/fuel/utils/wfconf.py @@ -22,7 +22,7 @@ from nvflare.security.logging import secure_format_exception from .argument_utils import parse_vars -from .class_utils import ModuleScanner, get_class, instantiate_class +from .class_utils import ModuleScanner, load_class, instantiate_class from .dict_utils import extract_first_level_primitive, merge_dict from .json_scanner import JsonObjectProcessor, JsonScanner, Node @@ -362,7 +362,7 @@ def get_class_path(self, config_dict): return class_path def is_configured_subclass(self, config_dict, base_class): - return issubclass(get_class(self.get_class_path(config_dict)), base_class) + return issubclass(load_class(self.get_class_path(config_dict)), base_class) def start_config(self, config_ctx: ConfigContext): pass diff --git a/nvflare/private/json_configer.py b/nvflare/private/json_configer.py index 3a987fde69..e47de9b270 100644 --- a/nvflare/private/json_configer.py +++ b/nvflare/private/json_configer.py @@ -15,7 +15,7 @@ from typing import List, Union from nvflare.fuel.common.excepts import ComponentNotAuthorized, ConfigError -from nvflare.fuel.utils.class_utils import ModuleScanner, get_class +from nvflare.fuel.utils.class_utils import ModuleScanner, load_class from nvflare.fuel.utils.component_builder import ComponentBuilder from 
nvflare.fuel.utils.config_factory import ConfigFactory from nvflare.fuel.utils.config_service import ConfigService @@ -150,7 +150,7 @@ def process_element(self, node: Node): self.process_config_element(self.config_ctx, node) def is_configured_subclass(self, config_dict, base_class): - return issubclass(get_class(self.get_class_path(config_dict)), base_class) + return issubclass(load_class(self.get_class_path(config_dict)), base_class) def start_config(self, config_ctx: ConfigContext): pass From 749214b2cd5346214c29098d283d4a2891165f7b Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 23 Jan 2025 11:42:46 -0500 Subject: [PATCH 02/11] Simplified the example --- .../app/config/config_fed_client.json | 9 +- .../app/config/config_fed_server.json | 10 +- .../app/custom/app_cmd_controller.py | 202 ------------------ .../app/custom/app_cmd_executor.py | 132 ------------ .../jobs/dict_streaming/app/custom/defs.py | 7 - .../app/custom/streaming_controller.py | 89 ++++++++ .../app/custom/streaming_executor.py | 71 ++++++ .../statistics/json_stats_file_persistor.py | 2 +- .../streamers/container_streamer.py | 13 +- .../app_common/streamers/object_retriever.py | 2 +- nvflare/fuel/f3/streaming/byte_streamer.py | 3 + nvflare/fuel/utils/class_loader.py | 53 +++++ nvflare/fuel/utils/class_utils.py | 2 +- nvflare/fuel/utils/fobs/__init__.py | 2 - nvflare/fuel/utils/fobs/fobs.py | 38 +--- 15 files changed, 236 insertions(+), 399 deletions(-) delete mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py delete mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py delete mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py create mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py create mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py create mode 100644 nvflare/fuel/utils/class_loader.py diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json index 9f0c4bb890..c2d85b8b48 100755 --- a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json +++ b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json @@ -5,19 +5,18 @@ { "tasks": ["*"], "executor": { - "path": "app_cmd_executor.AppCmdExecutor", + "path": "streaming_executor.StreamingExecutor", "args": { - "file_retriever_id": "file_retriever" + "dict_retriever_id": "dict_retriever" } } } ], "components": [ { - "id": "file_retriever", - "path": "nvflare.app_common.streamers.file_retriever.FileRetriever", + "id": "dict_retriever", + "path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever", "args": { - "source_dir": "/tmp" } } ] diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json index 7815cb3938..fd847e0175 100755 --- a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json +++ b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json @@ -2,19 +2,19 @@ "format_version": 2, "components": [ { - "id": "file_retriever", - "path": "nvflare.app_common.streamers.file_retriever.FileRetriever", + "id": "dict_retriever", + "path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever", "args": { - 
"source_dir": "/tmp" } } ], "workflows": [ { "id": "controller", - "path": "app_cmd_controller.AppCommandController", + "path": "streaming_controller.StreamingController", "args": { - } + "dict_retriever_id": "dict_retriever" + } } ] } \ No newline at end of file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py deleted file mode 100644 index d5163a4abb..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_controller.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time - -from nvflare.apis.controller_spec import Client, ClientTask, Task -from nvflare.apis.fl_context import FLContext -from nvflare.apis.impl.controller import Controller -from nvflare.apis.shareable import Shareable -from nvflare.apis.signal import Signal -from nvflare.app_common.streamers.file_streamer import FileStreamer, StreamContext - -from .defs import STREAM_CHANNEL, TOPIC_INITIAL_FILE - - -class AppCommandController(Controller): - def __init__(self, cmd_timeout=2, task_check_period: float = 0.5): - Controller.__init__(self, task_check_period=task_check_period) - self.cmd_timeout = cmd_timeout - self.app_done = False - self.abort_signal = None - - def start_controller(self, fl_ctx: FLContext): - engine = fl_ctx.get_engine() - engine.register_app_command( - topic="hello", - cmd_func=self.handle_hello, - ) - engine.register_app_command( - topic="avg", - cmd_func=self.handle_avg, - ) - engine.register_app_command( - topic="bye", - cmd_func=self.handle_bye, - ) - engine.register_app_command( - topic="echo", - cmd_func=self.handle_echo, - ) - engine.register_app_command( - topic="stream_file", - cmd_func=self.handle_stream_file_cmd, - ) - engine.register_app_command( - topic="rtr_file", - cmd_func=self.handle_rtr_file_cmd, - ) - - FileStreamer.register_stream_processing( - fl_ctx, STREAM_CHANNEL, "*", stream_status_cb=self._file_received, file_type="echo" - ) - - def _file_received( - self, - stream_ctx: StreamContext, - fl_ctx: FLContext, - file_type: str, - ): - peer_ctx = fl_ctx.get_peer_context() - assert isinstance(peer_ctx, FLContext) - peer_name = peer_ctx.get_identity_name() - self.log_info(fl_ctx, f"stream file received from {peer_name}: {stream_ctx=} {file_type=}") - - def stop_controller(self, fl_ctx: FLContext): - self.app_done = True - - def handle_stream_file_cmd(self, topic: str, data, fl_ctx: FLContext) -> dict: - full_file_name = data - result = FileStreamer.stream_file( - channel=STREAM_CHANNEL, - topic=TOPIC_INITIAL_FILE, - targets=[], - file_name=full_file_name, - fl_ctx=fl_ctx, - stream_ctx={"cmd_topic": topic}, - ) - return {"result": result} - - def handle_rtr_file_cmd(self, topic: str, data, fl_ctx: FLContext) -> dict: - self.log_info(fl_ctx, f"handle command: {topic=}") - s = Shareable() - s["file_name"] = data - task = Task(name="rtr_file", 
data=s, timeout=self.cmd_timeout) - self.broadcast_and_wait( - task=task, - fl_ctx=fl_ctx, - min_responses=2, - abort_signal=self.abort_signal, - ) - client_resps = {} - for ct in task.client_tasks: - assert isinstance(ct, ClientTask) - resp = ct.result - if resp is None: - resp = "no answer" - else: - assert isinstance(resp, Shareable) - self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") - resp = resp.get_return_code() - client_resps[ct.client.name] = resp - return {"status": "OK", "data": client_resps} - - def handle_echo(self, topic: str, data, fl_ctx: FLContext) -> dict: - engine = fl_ctx.get_engine() - clients = engine.get_clients() - reqs = {} - for c in clients: - r = Shareable() - r["data"] = c.name - reqs[c.name] = r - replies = engine.multicast_aux_requests( - topic="echo", - target_requests=reqs, - timeout=self.cmd_timeout, - fl_ctx=fl_ctx, - ) - result = {} - if replies: - for k, s in replies.items(): - assert isinstance(s, Shareable) - result[k] = s.get("data", "no data") - return result - - def handle_bye(self, topic: str, data, fl_ctx: FLContext) -> dict: - self.app_done = True - return {"status": "OK"} - - def handle_hello(self, topic: str, data, fl_ctx: FLContext) -> dict: - self.log_info(fl_ctx, f"handle command: {topic=}") - s = Shareable() - s["data"] = data - task = Task(name="hello", data=s, timeout=self.cmd_timeout) - self.broadcast_and_wait( - task=task, - fl_ctx=fl_ctx, - min_responses=2, - abort_signal=self.abort_signal, - ) - client_resps = {} - for ct in task.client_tasks: - assert isinstance(ct, ClientTask) - resp = ct.result - if resp is None: - resp = "no answer" - else: - self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") - resp = resp.get("data") - if not resp: - resp = "greetings!" - client_resps[ct.client.name] = resp - return {"status": "OK", "data": client_resps} - - def handle_avg(self, topic: str, data, fl_ctx: FLContext) -> dict: - s = Shareable() - s["data"] = data - task = Task(name="avg", data=s, timeout=self.cmd_timeout) - self.broadcast_and_wait( - task=task, - fl_ctx=fl_ctx, - min_responses=2, - abort_signal=self.abort_signal, - ) - client_resps = {} - total = 0.0 - count = 0 - for ct in task.client_tasks: - assert isinstance(ct, ClientTask) - resp = ct.result - if resp is None: - resp = 0.0 - else: - self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") - resp = resp.get("data") - if not resp: - resp = 0.0 - else: - total += resp - count += 1 - client_resps[ct.client.name] = resp - client_resps["avg"] = 0.0 if count == 0 else total / count - return {"status": "OK", "data": client_resps} - - def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): - self.abort_signal = abort_signal - while not abort_signal.triggered and not self.app_done: - time.sleep(1.0) - - def process_result_of_unknown_task( - self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext - ): - pass diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py deleted file mode 100644 index 8b0ae97024..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/app_cmd_executor.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random - -from nvflare.apis.event_type import EventType -from nvflare.apis.executor import Executor -from nvflare.apis.fl_constant import ReturnCode -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable, make_reply -from nvflare.apis.signal import Signal -from nvflare.app_common.streamers.file_retriever import FileRetriever -from nvflare.app_common.streamers.file_streamer import FileStreamer, StreamContext - -from .defs import STREAM_CHANNEL, TOPIC_ECHO_FILE, TOPIC_INITIAL_FILE - - -class AppCmdExecutor(Executor): - def __init__(self, file_retriever_id=None): - Executor.__init__(self) - self.file_retriever_id = file_retriever_id - self.file_retriever = None - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - engine = fl_ctx.get_engine() - engine.register_aux_message_handler( - topic="echo", - message_handle_func=self._handle_echo, - ) - FileStreamer.register_stream_processing( - fl_ctx, STREAM_CHANNEL, TOPIC_INITIAL_FILE, stream_status_cb=self._file_received, file_type="initial" - ) - FileStreamer.register_stream_processing( - fl_ctx, STREAM_CHANNEL, TOPIC_ECHO_FILE, stream_status_cb=self._file_received, file_type="echo" - ) - - if self.file_retriever_id: - c = engine.get_component(self.file_retriever_id) - if not isinstance(c, FileRetriever): - self.system_panic( - f"invalid file_retriever {self.file_retriever_id}: expect FileRetriever but got {type(c)}", - fl_ctx, - ) - return - self.file_retriever = c - - def _file_received( - self, - stream_ctx: StreamContext, - fl_ctx: FLContext, - file_type: str, - ): - peer_ctx = fl_ctx.get_peer_context() - assert isinstance(peer_ctx, FLContext) - peer_name = peer_ctx.get_identity_name() - channel = FileStreamer.get_channel(stream_ctx) - topic = FileStreamer.get_topic(stream_ctx) - rc = FileStreamer.get_rc(stream_ctx) - self.log_info(fl_ctx, f"file received from {peer_name}: {stream_ctx=} {file_type=} {channel=} {topic=} {rc=}") - file_location = FileStreamer.get_file_location(stream_ctx) - if file_type == "initial": - # send the file back to everyone - self.log_info(fl_ctx, f"echo file to all: {file_location}") - streamed = FileStreamer.stream_file( - channel=STREAM_CHANNEL, - topic=TOPIC_ECHO_FILE, - targets="@ALL", - file_name=file_location, - fl_ctx=fl_ctx, - stream_ctx={"file_type": file_type}, - ) - self.log_info(fl_ctx, f"streamed echo file to all sites: {streamed}") - - def _handle_echo(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: - data = request.get("data") - s = Shareable() - s["data"] = data - return s - - def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: - self.log_info(fl_ctx, f"got task {task_name}: {shareable}") - if task_name == "hello": - data = shareable.get("data") - s = Shareable() - s["data"] = data - return s - elif task_name == "avg": - data = shareable.get("data") - - self.log_info(fl_ctx, f"got avg request: {shareable}") - - start = data.get("start", 0) - end = data.get("end", 0) - v = random.randint(start, end) - result = 
Shareable() - result["data"] = v - return result - elif task_name == "rtr_file": - file_name = shareable.get("file_name") - if not file_name: - self.log_error(fl_ctx, "missing file name in request") - return make_reply(ReturnCode.BAD_TASK_DATA) - if not self.file_retriever: - self.log_error(fl_ctx, "no file retriever") - return make_reply(ReturnCode.SERVICE_UNAVAILABLE) - - assert isinstance(self.file_retriever, FileRetriever) - rc, location = self.file_retriever.retrieve_file( - from_site="server", - fl_ctx=fl_ctx, - timeout=10.0, - file_name=file_name, - ) - if rc != ReturnCode.OK: - self.log_error(fl_ctx, f"failed to retrieve file {file_name}: {rc}") - return make_reply(rc) - self.log_info(fl_ctx, f"received file {location}") - return make_reply(ReturnCode.OK) - else: - self.log_error(fl_ctx, f"got unknown task {task_name}") - return make_reply(ReturnCode.TASK_UNKNOWN) diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py deleted file mode 100644 index ca3d3541fd..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/defs.py +++ /dev/null @@ -1,7 +0,0 @@ -STREAM_CHANNEL = "file_stream" -TOPIC_INITIAL_FILE = "initial_file" -TOPIC_ECHO_FILE = "echo_file" - -FILE_RTR_REQUEST_TOPIC = "rtr_file" -FILE_RTR_STREAM_CHANNEL = "rtr_file_stream" -FILE_RTR_STREAM_TOPIC = "rtr_file_stream" diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py new file mode 100644 index 0000000000..1f1700d1a8 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from random import randbytes + +from nvflare.apis.controller_spec import Client, ClientTask, Task +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.container_retriever import ContainerRetriever + +STREAM_TOPIC = "rtr_file_stream" + + +class StreamingController(Controller): + def __init__(self, dict_retriever_id=None, task_timeout=60, task_check_period: float = 0.5): + Controller.__init__(self, task_check_period=task_check_period) + self.dict_retriever_id = dict_retriever_id + self.dict_retriever = None + self.task_timeout = task_timeout + + def start_controller(self, fl_ctx: FLContext): + model = self._get_test_model() + self.dict_retriever.add_container("model", model) + + def stop_controller(self, fl_ctx: FLContext): + pass + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + s = Shareable() + s["name"] = "model" + task = Task(name="retrieve_dict", data=s, timeout=self.task_timeout) + self.broadcast_and_wait( + task=task, + fl_ctx=fl_ctx, + min_responses=1, + abort_signal=abort_signal, + ) + client_resps = {} + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + resp = ct.result + if resp is None: + resp = "no answer" + else: + assert isinstance(resp, Shareable) + self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") + resp = resp.get_return_code() + client_resps[ct.client.name] = resp + return {"status": "OK", "data": client_resps} + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + pass + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.dict_retriever_id: + c = engine.get_component(self.dict_retriever_id) + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.dict_retriever = c + + @staticmethod + def _get_test_model() -> dict: + model = {} + for i in range(10): + key = f"layer-{i}" + model[key] = randbytes(1024) + + return model diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py new file mode 100644 index 0000000000..8756144b1e --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import random + +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.container_retriever import ContainerRetriever +from nvflare.app_common.streamers.file_retriever import FileRetriever + + +class StreamingExecutor(Executor): + def __init__(self, dict_retriever_id=None): + Executor.__init__(self) + self.dict_retriever_id = dict_retriever_id + self.dict_retriever = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.dict_retriever_id: + c = engine.get_component(self.dict_retriever_id) + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.dict_retriever = c + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.log_info(fl_ctx, f"got task {task_name}: {shareable}") + if task_name == "retrieve_dict": + name = shareable.get("name") + if not name: + self.log_error(fl_ctx, "missing name in request") + return make_reply(ReturnCode.BAD_TASK_DATA) + if not self.dict_retriever: + self.log_error(fl_ctx, "no container retriever") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + assert isinstance(self.dict_retriever, ContainerRetriever) + rc, result = self.dict_retriever.retrieve_container( + from_site="server", + fl_ctx=fl_ctx, + timeout=10.0, + name=name, + ) + if rc != ReturnCode.OK: + self.log_error(fl_ctx, f"failed to retrieve dict {name}: {rc}") + return make_reply(rc) + + self.log_info(fl_ctx, f"received container type: {type(result)} size: {len(result)}") + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"got unknown task {task_name}") + return make_reply(ReturnCode.TASK_UNKNOWN) diff --git a/nvflare/app_common/statistics/json_stats_file_persistor.py b/nvflare/app_common/statistics/json_stats_file_persistor.py index af731087df..96c96f7b32 100644 --- a/nvflare/app_common/statistics/json_stats_file_persistor.py +++ b/nvflare/app_common/statistics/json_stats_file_persistor.py @@ -19,7 +19,7 @@ from nvflare.apis.storage import StorageException from nvflare.app_common.abstract.statistics_writer import StatisticsWriter from nvflare.app_common.utils.json_utils import ObjectEncoder -from nvflare.fuel.utils.fobs import load_class +from nvflare.fuel.utils.class_loader import load_class class JsonStatsFileWriter(StatisticsWriter): diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py index 74588443d2..aea3815d0d 100644 --- a/nvflare/app_common/streamers/container_streamer.py +++ b/nvflare/app_common/streamers/container_streamer.py @@ -17,11 +17,11 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReturnCode, Shareable, make_reply from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext +from nvflare.fuel.utils.class_loader import load_class, get_class_name from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.validation_utils import check_positive_number from nvflare.app_common.streamers.streamer_base import StreamerBase -from nvflare.fuel.utils.fobs import 
get_class_name, load_class _PREFIX = "ContainerStreamer." @@ -89,10 +89,11 @@ def get_consumer(self, stream_ctx: StreamContext, fl_ctx: FLContext) -> ObjectCo class _EntryProducer(ObjectProducer): def __init__(self, container, entry_timeout): self.logger = get_obj_logger(self) - if not self.container: + if not container: error = "Can't stream empty container" self.logger.error(error) raise ValueError(error) + self.container = container if isinstance(container, dict): self.iterator = iter(container.items()) @@ -141,7 +142,7 @@ def process_replies( if has_error: # done - failed return False - elif self.eof: + elif self.last: # done - succeeded return True else: @@ -216,7 +217,7 @@ def stream_container( """ if not entry_timeout: entry_timeout = 60.0 - check_positive_number("chunk_timeout", entry_timeout) + check_positive_number("entry_timeout", entry_timeout) producer = _EntryProducer(container, entry_timeout) engine = fl_ctx.get_engine() @@ -227,7 +228,7 @@ def stream_container( if not stream_ctx: stream_ctx = {} - stream_ctx[_CTX_TYPE] = get_class_name(container) + stream_ctx[_CTX_TYPE] = get_class_name(type(container)) stream_ctx[_CTX_SIZE] = len(container) return engine.stream_objects( @@ -242,7 +243,7 @@ def stream_container( ) @staticmethod - def get_container(stream_ctx: StreamContext) -> Any: + def get_result(stream_ctx: StreamContext) -> Any: """Get the received container This method is intended to be used by the stream_done_cb() function of the receiving side. diff --git a/nvflare/app_common/streamers/object_retriever.py b/nvflare/app_common/streamers/object_retriever.py index 01c8f989fb..9360206747 100644 --- a/nvflare/app_common/streamers/object_retriever.py +++ b/nvflare/app_common/streamers/object_retriever.py @@ -237,7 +237,7 @@ def _handle_stream_done(self, stream_ctx: StreamContext, fl_ctx: FLContext): waiter.result = result waiter.set() - self.log_info(fl_ctx, f"got result for RTR {tx_id}: {waiter.result}") + self.log_info(fl_ctx, f"got result for RTR {tx_id}: {type(waiter.result)}") def _handle_request(self, topic, request: Shareable, fl_ctx: FLContext) -> Shareable: # On request receiving side, which is also stream sending side. diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py index faf160becd..3e678bd5a7 100644 --- a/nvflare/fuel/f3/streaming/byte_streamer.py +++ b/nvflare/fuel/f3/streaming/byte_streamer.py @@ -182,6 +182,9 @@ def send_pending_buffer(self, final=False): def stop(self, error: Optional[StreamError] = None, notify=True): + if self.stopped: + return + self.stopped = True if self.task_future: diff --git a/nvflare/fuel/utils/class_loader.py b/nvflare/fuel/utils/class_loader.py new file mode 100644 index 0000000000..2479f5c662 --- /dev/null +++ b/nvflare/fuel/utils/class_loader.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import builtins +import importlib +from typing import Type + + +# Those functions are extracted from class_utils module to share the code +# with FOBS and to avoid circular imports +def get_class_name(cls: Type) -> str: + """Get canonical class path or fully qualified name. The builtins module is removed + so common builtin class can be referenced with its normal name + + Args: + cls: The class type + Returns: + The canonical name + """ + module = cls.__module__ + if module == "builtins": + return cls.__qualname__ + return module + "." + cls.__qualname__ + + +def load_class(class_path): + """Load class from fully qualified class name + + Args: + class_path: fully qualified class name + Returns: + The class type + """ + + try: + if "." in class_path: + module_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + else: + return getattr(builtins, class_path) + except Exception as ex: + raise TypeError(f"Can't load class {class_path}: {ex}") diff --git a/nvflare/fuel/utils/class_utils.py b/nvflare/fuel/utils/class_utils.py index 36f52593ea..be69d7339f 100644 --- a/nvflare/fuel/utils/class_utils.py +++ b/nvflare/fuel/utils/class_utils.py @@ -18,8 +18,8 @@ from nvflare.apis.fl_component import FLComponent from nvflare.fuel.common.excepts import ConfigError +from nvflare.fuel.utils.class_loader import load_class from nvflare.fuel.utils.components_utils import create_classes_table_static -from nvflare.fuel.utils.fobs import load_class from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception diff --git a/nvflare/fuel/utils/fobs/__init__.py b/nvflare/fuel/utils/fobs/__init__.py index 8d238d32ea..b9a54ef75b 100644 --- a/nvflare/fuel/utils/fobs/__init__.py +++ b/nvflare/fuel/utils/fobs/__init__.py @@ -16,8 +16,6 @@ auto_register_enum_types, deserialize, deserialize_stream, - get_class_name, - load_class, num_decomposers, register, register_data_classes, diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index fb2293b3f4..64ba2de0e7 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -23,12 +23,11 @@ import msgpack +from nvflare.fuel.utils.class_loader import get_class_name, load_class from nvflare.fuel.utils.fobs.datum import DatumManager from nvflare.fuel.utils.fobs.decomposer import DataClassDecomposer, Decomposer, EnumTypeDecomposer __all__ = [ - "get_class_name", - "load_class", "register", "register_data_classes", "register_enum_types", @@ -60,41 +59,6 @@ _data_auto_registration = True -def get_class_name(cls: Type) -> str: - """Get canonical class path or fully qualified name. The builtins module is removed - so common builtin class can be referenced with its normal name - - Args: - cls: The class type - Returns: - The canonical name - """ - module = cls.__module__ - if module == "builtins": - return cls.__qualname__ - return module + "." + cls.__qualname__ - - -def load_class(class_path): - """Load class from fully qualified class name - - Args: - class_path: fully qualified class name - Returns: - The class type - """ - - try: - if "." 
in class_path: - module_name, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, class_name) - else: - return getattr(builtins, class_path) - except Exception as ex: - raise TypeError(f"Can't load class {class_path}: {ex}") - - def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: """Register a decomposer. It does nothing if decomposer is already registered for the type From 0f28d9e4b590e05bd02c0ad64e5c21a50969d6de Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 23 Jan 2025 13:40:05 -0500 Subject: [PATCH 03/11] Added document --- examples/advanced/streaming/README.md | 68 +++++++++++++++++++ .../streamers/container_retriever.py | 2 +- .../streamers/container_streamer.py | 11 ++- nvflare/fuel/utils/wfconf.py | 2 +- 4 files changed, 75 insertions(+), 8 deletions(-) create mode 100644 examples/advanced/streaming/README.md diff --git a/examples/advanced/streaming/README.md b/examples/advanced/streaming/README.md new file mode 100644 index 0000000000..2fc4577772 --- /dev/null +++ b/examples/advanced/streaming/README.md @@ -0,0 +1,68 @@ +# Object Streaming Examples + +## Overview +The examples here demonstrate how to use object streamers to send large files or objects memory-efficiently. + +The object streamer uses less memory because it sends files in chunks (default chunk size is 1MB) and +it sends containers entry by entry. + +For example, if you have a dict with ten 1GB entries, sending the dict without streaming takes an extra 10GB of space, +while streaming only needs about 1GB extra to serialize one entry at a time. + +## Concepts + +### Object Streamer + +ObjectStreamer is a base class to stream an object piece by piece. The `StreamableEngine` built into NVFlare can +stream any implementation of ObjectStreamer. + +The following implementations are included in NVFlare: + +* `FileStreamer`: It can be used to stream a file. +* `ContainerStreamer`: This class can stream a container entry by entry. Currently, dict, list and set are supported. + +The container streamer only streams the top-level entries. All sub-entries of a top-level entry are sent at once with +that entry. + +### Object Retriever + +`ObjectRetriever` is designed to request an object to be streamed from a remote site. It automatically sets up the streaming +on both ends and handles the coordination. + +Currently, the following implementations are available: + +* `FileRetriever`: It's used to retrieve a file from a remote site using FileStreamer. +* `ContainerRetriever`: This class can be used to retrieve a container from a remote site using ContainerStreamer. + +To use ContainerRetriever, the container must be given a name and added on the sending site, + +``` +ContainerRetriever.add("model", model_dict) +``` + +## Example Jobs + +### file_streaming job + +This job uses the FileStreamer object to send a large file from the server to the client. + +It demonstrates the following mechanisms: +1. It uses components to handle the file transfer. No training workflow is used. + Since NVFlare requires an executor, a dummy executor is created. +2. It shows how to use the streamer directly without an object retriever. + +The job creates a temporary file for testing. You can run the job in POC mode or with the simulator as follows, + +``` +nvflare simulator -n 1 -t 1 jobs/file_streaming +``` +### dict_streaming job + +This job demonstrates how to send a dict from the server to the client using an object retriever. + +It creates a task called "retrieve_dict" to tell the client to get ready for the streaming.
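For orientation, here is a minimal sketch of the client-side half of this flow, condensed from `streaming_executor.py` earlier in this series. The helper name `retrieve_named_dict` and the 10-second timeout are illustrative assumptions; `retrieve_container` and its keyword arguments come from the new `ContainerRetriever` class introduced by the patch.

```
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.app_common.streamers.container_retriever import ContainerRetriever


def retrieve_named_dict(retriever: ContainerRetriever, name: str, fl_ctx: FLContext) -> Shareable:
    """Pull the container registered under `name` from the server, entry by entry."""
    rc, result = retriever.retrieve_container(
        from_site="server",  # the site that called add_container(name, ...)
        fl_ctx=fl_ctx,
        timeout=10.0,  # illustrative timeout, matching the example executor
        name=name,
    )
    if rc != ReturnCode.OK:
        # the stream failed or timed out; surface the return code to the caller
        return make_reply(rc)
    # result is the rebuilt dict; a real executor would now use it (e.g. load weights)
    return make_reply(ReturnCode.OK)
```

On the sending side, the matching step is simply `dict_retriever.add_container("model", model_dict)` before the "retrieve_dict" task is broadcast, as done in `streaming_controller.py`.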
+ +The example can be run in simulator like this, +``` +nvflare simulator -n 1 -t 1 jobs/dict_streaming +``` diff --git a/nvflare/app_common/streamers/container_retriever.py b/nvflare/app_common/streamers/container_retriever.py index 30ad03d69d..fdfe0ef161 100644 --- a/nvflare/app_common/streamers/container_retriever.py +++ b/nvflare/app_common/streamers/container_retriever.py @@ -18,8 +18,8 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReturnCode, Shareable from nvflare.apis.streaming import StreamContext -from .container_streamer import ContainerStreamer +from .container_streamer import ContainerStreamer from .file_streamer import FileStreamer from .object_retriever import ObjectRetriever diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py index aea3815d0d..f7928d0032 100644 --- a/nvflare/app_common/streamers/container_streamer.py +++ b/nvflare/app_common/streamers/container_streamer.py @@ -17,12 +17,11 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReturnCode, Shareable, make_reply from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext -from nvflare.fuel.utils.class_loader import load_class, get_class_name +from nvflare.app_common.streamers.streamer_base import StreamerBase +from nvflare.fuel.utils.class_loader import get_class_name, load_class from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.validation_utils import check_positive_number -from nvflare.app_common.streamers.streamer_base import StreamerBase - _PREFIX = "ContainerStreamer." # Keys for StreamCtx @@ -117,13 +116,13 @@ def produce( try: self.next = next(self.iterator) - last = False + self.last = False except StopIteration: - last = True + self.last = True result = Shareable() result[_KEY_ENTRY] = entry - result[_KEY_LAST] = last + result[_KEY_LAST] = self.last return result, self.entry_timeout def process_replies( diff --git a/nvflare/fuel/utils/wfconf.py b/nvflare/fuel/utils/wfconf.py index 43e698b5e3..fc8b8f65af 100644 --- a/nvflare/fuel/utils/wfconf.py +++ b/nvflare/fuel/utils/wfconf.py @@ -22,7 +22,7 @@ from nvflare.security.logging import secure_format_exception from .argument_utils import parse_vars -from .class_utils import ModuleScanner, load_class, instantiate_class +from .class_utils import ModuleScanner, instantiate_class, load_class from .dict_utils import extract_first_level_primitive, merge_dict from .json_scanner import JsonObjectProcessor, JsonScanner, Node From 3e631186e048aa93e960e123aebb66611c621449 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 23 Jan 2025 13:48:20 -0500 Subject: [PATCH 04/11] Moved load_class to class_loader --- nvflare/fuel/utils/wfconf.py | 3 ++- nvflare/private/json_configer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/nvflare/fuel/utils/wfconf.py b/nvflare/fuel/utils/wfconf.py index fc8b8f65af..c283d18dd6 100644 --- a/nvflare/fuel/utils/wfconf.py +++ b/nvflare/fuel/utils/wfconf.py @@ -22,7 +22,8 @@ from nvflare.security.logging import secure_format_exception from .argument_utils import parse_vars -from .class_utils import ModuleScanner, instantiate_class, load_class +from .class_loader import load_class +from .class_utils import ModuleScanner, instantiate_class from .dict_utils import extract_first_level_primitive, merge_dict from .json_scanner import JsonObjectProcessor, JsonScanner, Node diff --git 
a/nvflare/private/json_configer.py b/nvflare/private/json_configer.py index e47de9b270..aca538658b 100644 --- a/nvflare/private/json_configer.py +++ b/nvflare/private/json_configer.py @@ -15,7 +15,8 @@ from typing import List, Union from nvflare.fuel.common.excepts import ComponentNotAuthorized, ConfigError -from nvflare.fuel.utils.class_utils import ModuleScanner, load_class +from nvflare.fuel.utils.class_loader import load_class +from nvflare.fuel.utils.class_utils import ModuleScanner from nvflare.fuel.utils.component_builder import ComponentBuilder from nvflare.fuel.utils.config_factory import ConfigFactory from nvflare.fuel.utils.config_service import ConfigService From 1c5eb1f4ec5128b69e343edeb81abb3a7c3b1985 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 23 Jan 2025 13:52:45 -0500 Subject: [PATCH 05/11] Fixed example formatting --- .../streaming/jobs/file_streaming/app/custom/file_streaming.py | 2 +- .../streaming/jobs/file_streaming/app/custom/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py index b23bf6e15f..0495efb6c2 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py @@ -23,7 +23,7 @@ CHANNEL = "_test_channel" TOPIC = "_test_topic" -SIZE = 100*1024*1024 # 100 MB +SIZE = 100 * 1024 * 1024 # 100 MB class FileSender(FLComponent): diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py index 11d1d18d3b..5888991a1d 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py @@ -31,7 +31,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.ABORT_TASK: self.log_info(fl_ctx, "Trainer is aborted") self.aborted = True - + def execute( self, task_name: str, From 33e6c5151dd1248f6cd11b5d79052c4f9e2d0d5d Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 23 Jan 2025 16:51:49 -0500 Subject: [PATCH 06/11] Simplified container producer --- .../app/custom/streaming_executor.py | 2 -- .../file_streaming/app/custom/controller.py | 2 ++ .../app/custom/file_streaming.py | 3 +-- .../jobs/file_streaming/app/custom/trainer.py | 3 --- .../streamers/container_retriever.py | 3 --- .../app_common/streamers/container_streamer.py | 18 ++++++++---------- nvflare/fuel/utils/fobs/fobs.py | 1 - 7 files changed, 11 insertions(+), 21 deletions(-) diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py index 8756144b1e..a238f82aff 100644 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import random from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor @@ -20,7 +19,6 @@ from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.streamers.container_retriever import ContainerRetriever -from nvflare.app_common.streamers.file_retriever import FileRetriever class StreamingExecutor(Executor): diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py index 7b13676bb0..206346e8f2 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py @@ -29,6 +29,8 @@ class SimpleController(Controller): def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): logger.info(f"Entering control loop of {self.__class__.__name__}") engine = fl_ctx.get_engine() + + # Wait till receiver is done. Otherwise, the job ends. receiver = engine.get_component("receiver") while not receiver.is_done(): time.sleep(0.2) diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py index 0495efb6c2..9b49b230e5 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py @@ -13,7 +13,6 @@ # limitations under the License. import os import tempfile -import time from threading import Thread from nvflare.apis.event_type import EventType @@ -59,7 +58,7 @@ def _sending_file(self, fl_ctx): rc, result = FileStreamer.stream_file( targets=["server"], - stream_ctx=None, + stream_ctx={}, channel=CHANNEL, topic=TOPIC, file_name=self.file_name, diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py index 5888991a1d..216d69f793 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import time - from nvflare.apis.dxo import DXO, DataKind from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor diff --git a/nvflare/app_common/streamers/container_retriever.py b/nvflare/app_common/streamers/container_retriever.py index fdfe0ef161..a2781b536e 100644 --- a/nvflare/app_common/streamers/container_retriever.py +++ b/nvflare/app_common/streamers/container_retriever.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os from typing import Any -from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReturnCode, Shareable from nvflare.apis.streaming import StreamContext from .container_streamer import ContainerStreamer -from .file_streamer import FileStreamer from .object_retriever import ObjectRetriever diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py index f7928d0032..263ea4fa06 100644 --- a/nvflare/app_common/streamers/container_streamer.py +++ b/nvflare/app_common/streamers/container_streamer.py @@ -98,9 +98,10 @@ def __init__(self, container, entry_timeout): self.iterator = iter(container.items()) else: self.iterator = iter(container) - self.entry_timeout = entry_timeout + self.size = len(container) + self.count = 0 self.last = False - self.next = None + self.entry_timeout = entry_timeout def produce( self, @@ -108,16 +109,13 @@ def produce( fl_ctx: FLContext, ) -> Tuple[Shareable, float]: - # To check if this is the last entry, need to get one entry ahead - if self.next: - entry = self.next - else: - entry = next(self.iterator) - try: - self.next = next(self.iterator) - self.last = False + entry = next(self.iterator) + self.count += 1 + self.last = self.count >= self.size except StopIteration: + self.logger.error(f"Producer called too many times {self.count}/{self.size}") + entry = None self.last = True result = Shareable() diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index 64ba2de0e7..aaca8c8a7c 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import builtins import importlib import inspect import logging From e63c82934b675070840e5c21cd62e5f8cd2d3659 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 23 Jan 2025 17:14:45 -0500 Subject: [PATCH 07/11] Fixed a doc error --- examples/advanced/streaming/README.md | 2 +- nvflare/app_common/streamers/container_streamer.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/advanced/streaming/README.md b/examples/advanced/streaming/README.md index 2fc4577772..6ab160c7d5 100644 --- a/examples/advanced/streaming/README.md +++ b/examples/advanced/streaming/README.md @@ -37,7 +37,7 @@ Currently, following implementations are available, To use ContainerRetriever, the container must be given a name and added on the sending site, ``` -ContainerRetriever.add("model", model_dict) +ContainerRetriever.add_container("model", model_dict) ``` ## Example Jobs diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py index 263ea4fa06..13c1364ece 100644 --- a/nvflare/app_common/streamers/container_streamer.py +++ b/nvflare/app_common/streamers/container_streamer.py @@ -54,6 +54,8 @@ def consume( if isinstance(self.container, dict): key, value = entry self.container[key] = value + elif isinstance(self.container, set): + self.container.add(entry) else: self.container.append(entry) except Exception as ex: From f0361b7fc8759e3a1a2b0afb6e6253de7069024e Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Wed, 29 Jan 2025 14:22:12 -0500 Subject: [PATCH 08/11] End producer on StopIteration --- nvflare/app_common/streamers/container_streamer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py index 13c1364ece..569ea0b257 100644 --- a/nvflare/app_common/streamers/container_streamer.py +++ b/nvflare/app_common/streamers/container_streamer.py @@ -117,8 +117,8 @@ def produce( self.last = self.count >= self.size except StopIteration: self.logger.error(f"Producer called too many times {self.count}/{self.size}") - entry = None self.last = True + return None, 0.0 result = Shareable() result[_KEY_ENTRY] = entry From 00cdb0b7184610670e71c86d397232b2c41af058 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Wed, 29 Jan 2025 15:19:01 -0500 Subject: [PATCH 09/11] Pinned bitsandbytes version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index b772acd3b7..bd4e1ccd84 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,7 +66,7 @@ app_opt = %(MONITORING)s pytorch_lightning xgboost - bitsandbytes + bitsandbytes==0.42.0 app_opt_mac = %(PT)s %(SKLEARN)s From 344d606149ba1eabecceb32b6e747be712d8287f Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Wed, 29 Jan 2025 15:31:05 -0500 Subject: [PATCH 10/11] Trying different version of bitsandbytes --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index bd4e1ccd84..c4b8afabdd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,7 +66,7 @@ app_opt = %(MONITORING)s pytorch_lightning xgboost - bitsandbytes==0.42.0 + bitsandbytes==0.44.0 app_opt_mac = %(PT)s %(SKLEARN)s From 6cd490c147f4bf463fdde7edff7501d30b4d0379 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 30 Jan 2025 05:25:43 -0500 Subject: [PATCH 11/11] Undo setup.cfg changes --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index c4b8afabdd..b772acd3b7 100644 --- a/setup.cfg 
+++ b/setup.cfg @@ -66,7 +66,7 @@ app_opt = %(MONITORING)s pytorch_lightning xgboost - bitsandbytes==0.44.0 + bitsandbytes app_opt_mac = %(PT)s %(SKLEARN)s
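Stepping back from the version pinning above: the `_EntryProducer` changes in patches 06 and 08 reduce container streaming to plain iterator bookkeeping. The toy snippet below is not part of the patch; it uses only the standard library and invented names to mirror that produce/consume logic outside NVFlare.

```
# Toy illustration of the entry-by-entry streaming logic that _EntryProducer and
# _EntryConsumer converge on: iterate the container, count entries, flag the last
# one, and rebuild the container of the same kind on the receiving side.
def produce_entries(container):
    """Yield (entry, is_last) pairs for a dict, list, or set."""
    items = iter(container.items()) if isinstance(container, dict) else iter(container)
    size = len(container)
    count = 0
    for entry in items:
        count += 1
        yield entry, count >= size  # is_last becomes True on the final entry


def consume_entries(entries, make=dict):
    """Rebuild a container of type `make` from streamed (entry, is_last) pairs."""
    result = make()
    for entry, _is_last in entries:
        if isinstance(result, dict):
            key, value = entry
            result[key] = value
        elif isinstance(result, set):
            result.add(entry)
        else:
            result.append(entry)
    return result


# Example: round-trip a small dict, analogous to the dict_streaming job's model.
model = {"layer-0": b"\x00" * 4, "layer-1": b"\x01" * 4}
assert consume_entries(produce_entries(model)) == model
```

Each produced entry carries an is_last flag, which is what `_KEY_LAST` conveys in the real Shareable, and the consumer rebuilds dicts, sets, and lists the same way `_EntryConsumer.consume` does after the set-support fix in patch 07.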