Skip to content

Commit

Permalink
Support aborting messages (#3053)
Browse files Browse the repository at this point in the history
* support abort of messages

* removed unused imports

* support aborting messages

---------

Co-authored-by: Zhihong Zhang <[email protected]>
  • Loading branch information
yanchengnv and nvidianz authored Oct 31, 2024
1 parent d91b111 commit 80f9e82
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 28 deletions.
59 changes: 43 additions & 16 deletions nvflare/fuel/f3/cellnet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import uuid
from typing import Dict, List, Union

from nvflare.apis.signal import Signal
from nvflare.fuel.f3.cellnet.core_cell import CoreCell, TargetMessage
from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey, MessageType, ReturnCode
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, make_reply
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stream_cell import StreamCell
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait
from nvflare.security.logging import secure_format_exception

CHANNELS_TO_EXCLUDE = (
Expand Down Expand Up @@ -147,6 +149,7 @@ def _broadcast_request(
timeout=None,
secure=False,
optional=False,
abort_signal: Signal = None,
) -> Dict[str, Message]:
"""
Send a message over a channel to specified destination cell(s), and wait for reply
Expand All @@ -159,6 +162,7 @@ def _broadcast_request(
timeout: how long to wait for replies
secure: End-end encryption
optional: whether the message is optional
abort_signal: signal to abort the message
Returns: a dict of: cell_id => reply message
Expand All @@ -181,6 +185,7 @@ def _broadcast_request(
req = Message(copy.deepcopy(request.headers), request.payload)
target_argument["request"] = TargetMessage(t, channel, topic, req).message
target_argument["target"] = t
target_argument["abort_signal"] = abort_signal
target_argument.update(fixed_dict)
f = executor.submit(self._send_one_request, **target_argument)
future_to_target[f] = t
Expand Down Expand Up @@ -232,26 +237,43 @@ def _get_result(self, req_id):
waiter = self.requests_dict.pop(req_id)
return waiter.result

def _future_wait(self, future, timeout):
def _check_error(self, future):
if future.error:
# must return a negative number
return -1
else:
return WaiterRC.OK

def _future_wait(self, future, timeout, abort_signal: Signal):
# future could have an error!
last_progress = 0
while not future.waiter.wait(timeout):
if future.error:
return False
current_progress = future.get_progress()
if last_progress == current_progress:
return False
while True:
rc = conditional_wait(future.waiter, timeout, abort_signal, condition_cb=self._check_error, future=future)
if rc == WaiterRC.IS_SET:
# waiter has been set!
break
elif rc == WaiterRC.TIMEOUT:
# timed out: check whether any progress has been made during this time
current_progress = future.get_progress()
if last_progress == current_progress:
# no progress in timeout secs: consider this to be a failure
return False
else:
# good progress
self.logger.debug(f"{current_progress=}")
last_progress = current_progress
else:
self.logger.debug(f"{current_progress=}")
last_progress = current_progress
# error condition: aborted or future error
return False

if future.error:
return False
else:
return True

def _encode_message(self, msg: Message):
def _encode_message(self, msg: Message) -> int:
try:
encode_payload(msg, StreamHeaderKey.PAYLOAD_ENCODING)
return encode_payload(msg, StreamHeaderKey.PAYLOAD_ENCODING)
except BaseException as exc:
self.logger.error(f"Can't encode {msg=} {exc=}")
raise exc
Expand All @@ -265,6 +287,7 @@ def _send_request(
timeout=10.0,
secure=False,
optional=False,
abort_signal: Signal = None,
):
"""Stream one request to the target
Expand All @@ -276,12 +299,13 @@ def _send_request(
timeout: how long to wait
secure: is P2P security to be applied
optional: is the message optional
abort_signal: signal to abort the message
Returns: reply data
"""
self._encode_message(request)
return self._send_one_request(channel, target, topic, request, timeout, secure, optional)
return self._send_one_request(channel, target, topic, request, timeout, secure, optional, abort_signal)

def _send_one_request(
self,
Expand All @@ -292,6 +316,7 @@ def _send_one_request(
timeout=10.0,
secure=False,
optional=False,
abort_signal=None,
):
req_id = str(uuid.uuid4())
request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id})
Expand All @@ -312,7 +337,7 @@ def _send_one_request(
# Three stages, sending, waiting for receiving first byte, receiving
# sending with progress timeout
self.logger.debug(f"{req_id=}: entering sending wait {timeout=}")
sending_complete = self._future_wait(future, timeout)
sending_complete = self._future_wait(future, timeout, abort_signal)
if not sending_complete:
self.logger.debug(f"{req_id=}: sending timeout {timeout=}")
return self._get_result(req_id)
Expand All @@ -321,15 +346,17 @@ def _send_one_request(

# waiting for receiving first byte
self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}")
if not waiter.in_receiving.wait(timeout):
self.logger.debug(f"{req_id=}: remote processing timeout {timeout=}")

waiter_rc = conditional_wait(waiter.in_receiving, timeout, abort_signal)
if waiter_rc != WaiterRC.IS_SET:
self.logger.debug(f"{req_id=}: remote processing timeout {timeout=} {waiter_rc=}")
return self._get_result(req_id)
self.logger.debug(f"{req_id=}: in receiving")

# receiving with progress timeout
r_future = waiter.receiving_future
self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}")
receiving_complete = self._future_wait(r_future, timeout)
receiving_complete = self._future_wait(r_future, timeout, abort_signal)
if not receiving_complete:
self.logger.info(f"{req_id=}: receiving timeout {timeout=}")
return self._get_result(req_id)
Expand Down
18 changes: 17 additions & 1 deletion nvflare/fuel/f3/cellnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,22 @@ def format_log_message(fqcn: str, message: Message, log: str) -> str:
return " ".join(context) + f"] {log}"


def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING) -> int:
"""Encode the payload of the specified message.
Args:
message: the message to be encoded
encoding_key: the key name of the encoding property in the message header. If the encoding property is not
set in the message header, then it means that the message payload has not been encoded. If the property is
already set, then the message payload is already encoded, and no processing is done.
If encoding is needed, we will determine the encoding scheme based on the data type of the payload:
- If the payload is None, encoding scheme is NONE
- If the payload data type is like bytes, encoding scheme is BYTES
- Otherwise, encoding scheme is FOBS, and the payload is serialized with FOBS.
Returns: the encoded payload size.
"""
encoding = message.get_header(encoding_key)
if not encoding:
if message.payload is None:
Expand All @@ -97,6 +112,7 @@ def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCOD

size = buffer_len(message.payload)
message.set_header(MessageHeaderKey.PAYLOAD_LEN, size)
return size


def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
Expand Down
12 changes: 9 additions & 3 deletions nvflare/fuel/f3/streaming/byte_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
StreamDataType,
StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture, StreamTaskSpec
from nvflare.fuel.f3.streaming.stream_utils import (
ONE_MB,
gen_stream_id,
Expand All @@ -49,7 +49,7 @@
log = logging.getLogger(__name__)


class TxTask:
class TxTask(StreamTaskSpec):
def __init__(
self,
cell: CoreCell,
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
self.optional = optional
self.stopped = False

self.stream_future = StreamFuture(self.sid)
self.stream_future = StreamFuture(self.sid, task_handle=self)
self.stream_future.set_size(stream.get_size())

self.window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE)
Expand Down Expand Up @@ -184,6 +184,9 @@ def stop(self, error: Optional[StreamError] = None, notify=True):

self.stopped = True

if self.task_future:
self.task_future.cancel()

if not error:
# Result is the number of bytes streamed
if self.stream_future:
Expand Down Expand Up @@ -232,6 +235,9 @@ def handle_ack(self, message: Message):
def start_task_thread(self, task_handler: Callable):
self.task_future = stream_thread_pool.submit(task_handler, self)

def cancel(self):
self.stop(error=StreamError("cancelled"))


class ByteStreamer:

Expand Down
16 changes: 14 additions & 2 deletions nvflare/fuel/f3/streaming/stream_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,23 @@ def set_index(self, index: int):
self.index = index


class StreamTaskSpec(ABC):
def cancel(self):
"""Cancel the task
Returns:
"""
pass


class StreamFuture:
"""Future class for all stream calls.
Fashioned after concurrent.futures.Future
"""

def __init__(self, stream_id: int, headers: Optional[dict] = None):
def __init__(self, stream_id: int, headers: Optional[dict] = None, task_handle: StreamTaskSpec = None):
self.stream_id = stream_id
self.headers = headers
self.waiter = threading.Event()
Expand All @@ -126,6 +136,7 @@ def __init__(self, stream_id: int, headers: Optional[dict] = None):
self.size = 0
self.progress = 0
self.done_callbacks = []
self.task_handle = task_handle

def get_stream_id(self) -> int:
return self.stream_id
Expand Down Expand Up @@ -157,7 +168,8 @@ def cancel(self):
return False

self.error = StreamCancelled(f"Stream {self.stream_id} is cancelled")

if self.task_handle:
self.task_handle.cancel()
return True

def cancelled(self):
Expand Down
71 changes: 71 additions & 0 deletions nvflare/fuel/utils/waiter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

from nvflare.apis.signal import Signal

_SMALL_WAIT = 0.01


class WaiterRC:
OK = 0
IS_SET = 1
TIMEOUT = 2
ABORTED = 3
ERROR = 4


def conditional_wait(waiter: threading.Event, timeout: float, abort_signal: Signal, condition_cb=None, **cb_kwargs):
"""Wait for an event until timeout, aborted, or some condition is met.
Args:
waiter: the event to wait
timeout: the max time to wait
abort_signal: signal to abort the wait
condition_cb: condition to check during waiting
**cb_kwargs: kwargs for the condition_cb
Returns: return code to indicate how the waiting is stopped:
IS_SET: the event is set
TIMEOUT: the event timed out
ABORTED: abort signal is triggered during the wait
ERROR: the condition_cb encountered unhandled exception
OK: only used by the condition_cb to say "all is normal"
other integers: returned by condition_cb for other conditions met
"""
wait_time = min(_SMALL_WAIT, timeout)
start = time.time()
while True:
if waiter.wait(wait_time):
# the event just happened!
return WaiterRC.IS_SET

if time.time() - start >= timeout:
return WaiterRC.TIMEOUT

# check conditions
if abort_signal and abort_signal.triggered:
return WaiterRC.ABORTED

if condition_cb:
try:
rc = condition_cb(**cb_kwargs)
if rc != WaiterRC.OK:
# a bad condition is detected by the condition_cb
# we return the rc from the condition_cb
return rc
except:
return WaiterRC.ERROR
5 changes: 4 additions & 1 deletion nvflare/private/aux_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def _send_multi_requests(
)

if timeout > 0:
cell_replies = cell.broadcast_multi_requests(target_messages, timeout, optional=optional, secure=secure)
cell_replies = cell.broadcast_multi_requests(
target_messages, timeout, optional=optional, secure=secure, abort_signal=fl_ctx.get_run_abort_signal()
)
return self._process_cell_replies(cell_replies, topic, channel, fqcn_to_name)
else:
cell.fire_multi_requests_and_forget(
Expand Down Expand Up @@ -395,6 +397,7 @@ def _send_to_cell(
timeout=timeout,
optional=optional,
secure=secure,
abort_signal=fl_ctx.get_run_abort_signal(),
)
return self._process_cell_replies(cell_replies, topic, channel, fqcn_to_name)
else:
Expand Down
3 changes: 2 additions & 1 deletion nvflare/private/fed/client/client_engine_executor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ def fire_and_forget_aux_request(
"""Send an async request to Server via the aux channel.
Args:
topic: topic of the request
topic: topic of the request.
request: request to be sent
fl_ctx: FL context
optional: whether the request is optional
secure: whether to send the message in P2P secure mode
Returns:
Expand Down
8 changes: 7 additions & 1 deletion nvflare/private/fed/client/client_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,13 @@ def send_aux_request(

if msg_targets:
return self.aux_runner.send_aux_request(
msg_targets, topic, request, timeout, fl_ctx, optional=optional, secure=secure
msg_targets,
topic,
request,
timeout,
fl_ctx,
optional=optional,
secure=secure,
)
else:
return {}
Expand Down
Loading

0 comments on commit 80f9e82

Please sign in to comment.