diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py
index a13c8a5570..050bbd1a6e 100644
--- a/nvflare/fuel/f3/cellnet/cell.py
+++ b/nvflare/fuel/f3/cellnet/cell.py
@@ -19,6 +19,7 @@
 import uuid
 from typing import Dict, List, Union
 
+from nvflare.apis.signal import Signal
 from nvflare.fuel.f3.cellnet.core_cell import CoreCell, TargetMessage
 from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey, MessageType, ReturnCode
 from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, make_reply
@@ -26,6 +27,7 @@
 from nvflare.fuel.f3.stream_cell import StreamCell
 from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
 from nvflare.fuel.f3.streaming.stream_types import StreamFuture
+from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait
 from nvflare.security.logging import secure_format_exception
 
 CHANNELS_TO_EXCLUDE = (
@@ -147,6 +149,7 @@ def _broadcast_request(
         timeout=None,
         secure=False,
         optional=False,
+        abort_signal: Signal = None,
     ) -> Dict[str, Message]:
         """
         Send a message over a channel to specified destination cell(s), and wait for reply
@@ -159,6 +162,7 @@
             timeout: how long to wait for replies
             secure: End-end encryption
             optional: whether the message is optional
+            abort_signal: signal to abort the message
 
         Returns: a dict of: cell_id => reply message
 
@@ -181,6 +185,7 @@
                 req = Message(copy.deepcopy(request.headers), request.payload)
                 target_argument["request"] = TargetMessage(t, channel, topic, req).message
                 target_argument["target"] = t
+                target_argument["abort_signal"] = abort_signal
                 target_argument.update(fixed_dict)
                 f = executor.submit(self._send_one_request, **target_argument)
                 future_to_target[f] = t
@@ -232,26 +237,43 @@ def _get_result(self, req_id):
         waiter = self.requests_dict.pop(req_id)
         return waiter.result
 
-    def _future_wait(self, future, timeout):
+    def _check_error(self, future):
+        if future.error:
+            # must return a negative number
+            return -1
+        else:
+            return WaiterRC.OK
+
+    def _future_wait(self, future, timeout, abort_signal: Signal):
         # future could have an error!
         last_progress = 0
-        while not future.waiter.wait(timeout):
-            if future.error:
-                return False
-            current_progress = future.get_progress()
-            if last_progress == current_progress:
-                return False
+        while True:
+            rc = conditional_wait(future.waiter, timeout, abort_signal, condition_cb=self._check_error, future=future)
+            if rc == WaiterRC.IS_SET:
+                # waiter has been set!
+                break
+            elif rc == WaiterRC.TIMEOUT:
+                # timed out: check whether any progress has been made during this time
+                current_progress = future.get_progress()
+                if last_progress == current_progress:
+                    # no progress in timeout secs: consider this to be a failure
+                    return False
+                else:
+                    # good progress
+                    self.logger.debug(f"{current_progress=}")
+                    last_progress = current_progress
             else:
-                self.logger.debug(f"{current_progress=}")
-                last_progress = current_progress
+                # error condition: aborted or future error
+                return False
+
         if future.error:
             return False
         else:
             return True
 
-    def _encode_message(self, msg: Message):
+    def _encode_message(self, msg: Message) -> int:
         try:
-            encode_payload(msg, StreamHeaderKey.PAYLOAD_ENCODING)
+            return encode_payload(msg, StreamHeaderKey.PAYLOAD_ENCODING)
         except BaseException as exc:
             self.logger.error(f"Can't encode {msg=} {exc=}")
             raise exc
@@ -265,6 +287,7 @@ def _send_request(
         timeout=10.0,
         secure=False,
         optional=False,
+        abort_signal: Signal = None,
     ):
         """Stream one request to the target
 
@@ -276,12 +299,13 @@ def _send_request(
             timeout: how long to wait
             secure: is P2P security to be applied
             optional: is the message optional
+            abort_signal: signal to abort the message
 
         Returns: reply data
 
         """
         self._encode_message(request)
-        return self._send_one_request(channel, target, topic, request, timeout, secure, optional)
+        return self._send_one_request(channel, target, topic, request, timeout, secure, optional, abort_signal)
 
     def _send_one_request(
         self,
@@ -292,6 +316,7 @@ def _send_one_request(
         timeout=10.0,
         secure=False,
         optional=False,
+        abort_signal=None,
     ):
         req_id = str(uuid.uuid4())
         request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id})
@@ -312,7 +337,7 @@
         # Three stages, sending, waiting for receiving first byte, receiving
         # sending with progress timeout
         self.logger.debug(f"{req_id=}: entering sending wait {timeout=}")
-        sending_complete = self._future_wait(future, timeout)
+        sending_complete = self._future_wait(future, timeout, abort_signal)
         if not sending_complete:
             self.logger.debug(f"{req_id=}: sending timeout {timeout=}")
             return self._get_result(req_id)
@@ -321,15 +346,17 @@
 
         # waiting for receiving first byte
         self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}")
-        if not waiter.in_receiving.wait(timeout):
-            self.logger.debug(f"{req_id=}: remote processing timeout {timeout=}")
+
+        waiter_rc = conditional_wait(waiter.in_receiving, timeout, abort_signal)
+        if waiter_rc != WaiterRC.IS_SET:
+            self.logger.debug(f"{req_id=}: remote processing timeout {timeout=} {waiter_rc=}")
             return self._get_result(req_id)
         self.logger.debug(f"{req_id=}: in receiving")
 
         # receiving with progress timeout
         r_future = waiter.receiving_future
         self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}")
-        receiving_complete = self._future_wait(r_future, timeout)
+        receiving_complete = self._future_wait(r_future, timeout, abort_signal)
         if not receiving_complete:
             self.logger.info(f"{req_id=}: receiving timeout {timeout=}")
             return self._get_result(req_id)
diff --git a/nvflare/fuel/f3/cellnet/utils.py b/nvflare/fuel/f3/cellnet/utils.py
index 415790be1a..43d2ec16fe 100644
--- a/nvflare/fuel/f3/cellnet/utils.py
+++ b/nvflare/fuel/f3/cellnet/utils.py
@@ -83,7 +83,22 @@ def format_log_message(fqcn: str, message: Message, log: str) -> str:
     return " ".join(context) + f"] {log}"
 
 
-def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
+def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING) -> int:
+    """Encode the payload of the specified message.
+
+    Args:
+        message: the message to be encoded
+        encoding_key: the key name of the encoding property in the message header. If the encoding property is not
+            set in the message header, then it means that the message payload has not been encoded. If the property is
+            already set, then the message payload is already encoded, and no processing is done.
+            If encoding is needed, we will determine the encoding scheme based on the data type of the payload:
+            - If the payload is None, encoding scheme is NONE
+            - If the payload data type is like bytes, encoding scheme is BYTES
+            - Otherwise, encoding scheme is FOBS, and the payload is serialized with FOBS.
+
+    Returns: the encoded payload size.
+
+    """
     encoding = message.get_header(encoding_key)
     if not encoding:
         if message.payload is None:
@@ -97,6 +112,7 @@ def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCOD
 
     size = buffer_len(message.payload)
     message.set_header(MessageHeaderKey.PAYLOAD_LEN, size)
+    return size
 
 
 def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py
index de82d38bb6..faf160becd 100644
--- a/nvflare/fuel/f3/streaming/byte_streamer.py
+++ b/nvflare/fuel/f3/streaming/byte_streamer.py
@@ -27,7 +27,7 @@
     StreamDataType,
     StreamHeaderKey,
 )
-from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
+from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture, StreamTaskSpec
 from nvflare.fuel.f3.streaming.stream_utils import (
     ONE_MB,
     gen_stream_id,
@@ -49,7 +49,7 @@
 log = logging.getLogger(__name__)
 
 
-class TxTask:
+class TxTask(StreamTaskSpec):
     def __init__(
         self,
         cell: CoreCell,
@@ -84,7 +84,7 @@ def __init__(
         self.optional = optional
         self.stopped = False
 
-        self.stream_future = StreamFuture(self.sid)
+        self.stream_future = StreamFuture(self.sid, task_handle=self)
         self.stream_future.set_size(stream.get_size())
 
         self.window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE)
@@ -184,6 +184,9 @@ def stop(self, error: Optional[StreamError] = None, notify=True):
 
         self.stopped = True
 
+        if self.task_future:
+            self.task_future.cancel()
+
         if not error:
             # Result is the number of bytes streamed
             if self.stream_future:
@@ -232,6 +235,9 @@ def handle_ack(self, message: Message):
 
     def start_task_thread(self, task_handler: Callable):
         self.task_future = stream_thread_pool.submit(task_handler, self)
 
+    def cancel(self):
+        self.stop(error=StreamError("cancelled"))
+
 
 class ByteStreamer:
diff --git a/nvflare/fuel/f3/streaming/stream_types.py b/nvflare/fuel/f3/streaming/stream_types.py
index bff6c23954..e604c6ac2c 100644
--- a/nvflare/fuel/f3/streaming/stream_types.py
+++ b/nvflare/fuel/f3/streaming/stream_types.py
@@ -110,13 +110,23 @@ def set_index(self, index: int):
         self.index = index
 
 
+class StreamTaskSpec(ABC):
+    def cancel(self):
+        """Cancel the task
+
+        Returns:
+
+        """
+        pass
+
+
 class StreamFuture:
     """Future class for all stream calls.
 
     Fashioned after concurrent.futures.Future
     """
 
-    def __init__(self, stream_id: int, headers: Optional[dict] = None):
+    def __init__(self, stream_id: int, headers: Optional[dict] = None, task_handle: StreamTaskSpec = None):
         self.stream_id = stream_id
         self.headers = headers
         self.waiter = threading.Event()
@@ -126,6 +136,7 @@ def __init__(self, stream_id: int, headers: Optional[dict] = None):
         self.size = 0
         self.progress = 0
         self.done_callbacks = []
+        self.task_handle = task_handle
 
     def get_stream_id(self) -> int:
         return self.stream_id
@@ -157,7 +168,8 @@ def cancel(self):
             return False
 
         self.error = StreamCancelled(f"Stream {self.stream_id} is cancelled")
-
+        if self.task_handle:
+            self.task_handle.cancel()
         return True
 
     def cancelled(self):
diff --git a/nvflare/fuel/utils/waiter_utils.py b/nvflare/fuel/utils/waiter_utils.py
new file mode 100644
index 0000000000..c0bb6e3b8a
--- /dev/null
+++ b/nvflare/fuel/utils/waiter_utils.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+import time
+
+from nvflare.apis.signal import Signal
+
+_SMALL_WAIT = 0.01
+
+
+class WaiterRC:
+    OK = 0
+    IS_SET = 1
+    TIMEOUT = 2
+    ABORTED = 3
+    ERROR = 4
+
+
+def conditional_wait(waiter: threading.Event, timeout: float, abort_signal: Signal, condition_cb=None, **cb_kwargs):
+    """Wait for an event until timeout, aborted, or some condition is met.
+
+    Args:
+        waiter: the event to wait
+        timeout: the max time to wait
+        abort_signal: signal to abort the wait
+        condition_cb: condition to check during waiting
+        **cb_kwargs: kwargs for the condition_cb
+
+    Returns: return code to indicate how the waiting is stopped:
+        IS_SET: the event is set
+        TIMEOUT: the event timed out
+        ABORTED: abort signal is triggered during the wait
+        ERROR: the condition_cb encountered unhandled exception
+        OK: only used by the condition_cb to say "all is normal"
+        other integers: returned by condition_cb for other conditions met
+
+    """
+    wait_time = min(_SMALL_WAIT, timeout)
+    start = time.time()
+    while True:
+        if waiter.wait(wait_time):
+            # the event just happened!
+            return WaiterRC.IS_SET
+
+        if time.time() - start >= timeout:
+            return WaiterRC.TIMEOUT
+
+        # check conditions
+        if abort_signal and abort_signal.triggered:
+            return WaiterRC.ABORTED
+
+        if condition_cb:
+            try:
+                rc = condition_cb(**cb_kwargs)
+                if rc != WaiterRC.OK:
+                    # a bad condition is detected by the condition_cb
+                    # we return the rc from the condition_cb
+                    return rc
+            except:
+                return WaiterRC.ERROR
diff --git a/nvflare/private/aux_runner.py b/nvflare/private/aux_runner.py
index 1047872ca4..83edac9ae3 100644
--- a/nvflare/private/aux_runner.py
+++ b/nvflare/private/aux_runner.py
@@ -280,7 +280,9 @@ def _send_multi_requests(
             )
 
         if timeout > 0:
-            cell_replies = cell.broadcast_multi_requests(target_messages, timeout, optional=optional, secure=secure)
+            cell_replies = cell.broadcast_multi_requests(
+                target_messages, timeout, optional=optional, secure=secure, abort_signal=fl_ctx.get_run_abort_signal()
+            )
             return self._process_cell_replies(cell_replies, topic, channel, fqcn_to_name)
         else:
             cell.fire_multi_requests_and_forget(
@@ -395,6 +397,7 @@ def _send_to_cell(
                 timeout=timeout,
                 optional=optional,
                 secure=secure,
+                abort_signal=fl_ctx.get_run_abort_signal(),
             )
             return self._process_cell_replies(cell_replies, topic, channel, fqcn_to_name)
         else:
diff --git a/nvflare/private/fed/client/client_engine_executor_spec.py b/nvflare/private/fed/client/client_engine_executor_spec.py
index 5551c463ee..062955d505 100644
--- a/nvflare/private/fed/client/client_engine_executor_spec.py
+++ b/nvflare/private/fed/client/client_engine_executor_spec.py
@@ -149,10 +149,11 @@ def fire_and_forget_aux_request(
         """Send an async request to Server via the aux channel.
 
         Args:
-            topic: topic of the request
+            topic: topic of the request.
             request: request to be sent
             fl_ctx: FL context
             optional: whether the request is optional
+            secure: whether to send the message in P2P secure mode
 
         Returns:
 
diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py
index 5a90b0ef2a..8a4653c130 100644
--- a/nvflare/private/fed/client/client_run_manager.py
+++ b/nvflare/private/fed/client/client_run_manager.py
@@ -244,7 +244,13 @@ def send_aux_request(
 
         if msg_targets:
             return self.aux_runner.send_aux_request(
-                msg_targets, topic, request, timeout, fl_ctx, optional=optional, secure=secure
+                msg_targets,
+                topic,
+                request,
+                timeout,
+                fl_ctx,
+                optional=optional,
+                secure=secure,
             )
         else:
             return {}
diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py
index 0012a91e78..aa8ff11683 100644
--- a/nvflare/private/fed/client/communicator.py
+++ b/nvflare/private/fed/client/communicator.py
@@ -28,8 +28,9 @@
 from nvflare.apis.fl_exception import FLCommunicationError
 from nvflare.apis.shareable import Shareable
 from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
-from nvflare.fuel.f3.cellnet.core_cell import FQCN, CoreCell
+from nvflare.fuel.f3.cellnet.cell import Cell
 from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey, ReturnCode
+from nvflare.fuel.f3.cellnet.fqcn import FQCN
 from nvflare.fuel.f3.cellnet.utils import format_size
 from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message
 from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec
@@ -67,7 +68,7 @@ def __init__(
         secure_train=False,
         client_state_processors: Optional[List[Filter]] = None,
         compression=None,
-        cell: CoreCell = None,
+        cell: Cell = None,
         client_register_interval=2,
         timeout=5.0,
         maint_msg_timeout=5.0,
@@ -312,6 +313,7 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None):
             request=task_message,
             timeout=timeout,
             optional=True,
+            abort_signal=fl_ctx.get_run_abort_signal(),
         )
         end_time = time.time()
         return_code = task.get_header(MessageHeaderKey.RETURN_CODE)
@@ -388,6 +390,7 @@ def submit_update(
             request=task_message,
             timeout=timeout,
             optional=optional,
+            abort_signal=fl_ctx.get_run_abort_signal(),
         )
         end_time = time.time()
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
diff --git a/nvflare/private/fed/server/server_commands.py b/nvflare/private/fed/server/server_commands.py
index 20e0e9a85d..b57e0e16f4 100644
--- a/nvflare/private/fed/server/server_commands.py
+++ b/nvflare/private/fed/server/server_commands.py
@@ -144,7 +144,7 @@ def get_command_name(self) -> str:
         return ServerCommandNames.GET_TASK
 
     def process(self, data: Shareable, fl_ctx: FLContext):
-        """Called to process the abort command.
+        """Called to process the GetTask command.
 
         Args:
             data: process data
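
Illustrative usage sketch (not part of the patch above): it shows how the new conditional_wait() helper added in nvflare/fuel/utils/waiter_utils.py is meant to be driven. The sketch assumes the existing nvflare.apis.signal.Signal API (trigger() / triggered); the timing values and variable names are arbitrary.

# Illustrative only -- exercises each outcome of conditional_wait().
import threading

from nvflare.apis.signal import Signal
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait

event = threading.Event()
abort_signal = Signal()

# 1. Nothing sets the event and nothing aborts: returns TIMEOUT after ~0.2s.
rc = conditional_wait(event, timeout=0.2, abort_signal=abort_signal)
assert rc == WaiterRC.TIMEOUT

# 2. Another thread sets the event: returns IS_SET well before the timeout.
threading.Timer(0.05, event.set).start()
rc = conditional_wait(event, timeout=5.0, abort_signal=abort_signal)
assert rc == WaiterRC.IS_SET

# 3. The abort signal fires: returns ABORTED instead of waiting out the timeout.
event.clear()
threading.Timer(0.05, abort_signal.trigger, args=(True,)).start()
rc = conditional_wait(event, timeout=5.0, abort_signal=abort_signal)
assert rc == WaiterRC.ABORTED

# 4. A condition_cb can end the wait early: any value other than WaiterRC.OK is
#    returned to the caller (this is how Cell._future_wait surfaces future.error
#    via its _check_error callback).
rc = conditional_wait(event, timeout=5.0, abort_signal=Signal(), condition_cb=lambda: -1)
assert rc == -1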