Support aborting messages #3053

Merged 4 commits on Oct 31, 2024
59 changes: 43 additions & 16 deletions nvflare/fuel/f3/cellnet/cell.py
@@ -19,13 +19,15 @@
import uuid
from typing import Dict, List, Union

from nvflare.apis.signal import Signal
from nvflare.fuel.f3.cellnet.core_cell import CoreCell, TargetMessage
from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey, MessageType, ReturnCode
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, make_reply
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stream_cell import StreamCell
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait
from nvflare.security.logging import secure_format_exception

CHANNELS_TO_EXCLUDE = (
@@ -147,6 +149,7 @@ def _broadcast_request(
timeout=None,
secure=False,
optional=False,
abort_signal: Signal = None,
) -> Dict[str, Message]:
"""
Send a message over a channel to specified destination cell(s), and wait for reply
@@ -159,6 +162,7 @@
timeout: how long to wait for replies
secure: End-end encryption
optional: whether the message is optional
abort_signal: signal to abort the message

Returns: a dict of: cell_id => reply message

@@ -181,6 +185,7 @@
req = Message(copy.deepcopy(request.headers), request.payload)
target_argument["request"] = TargetMessage(t, channel, topic, req).message
target_argument["target"] = t
target_argument["abort_signal"] = abort_signal
target_argument.update(fixed_dict)
f = executor.submit(self._send_one_request, **target_argument)
future_to_target[f] = t
@@ -232,26 +237,43 @@ def _get_result(self, req_id):
waiter = self.requests_dict.pop(req_id)
return waiter.result

def _future_wait(self, future, timeout):
def _check_error(self, future):
if future.error:
# must return a negative number
return -1
else:
return WaiterRC.OK

def _future_wait(self, future, timeout, abort_signal: Signal):
# future could have an error!
last_progress = 0
while not future.waiter.wait(timeout):
if future.error:
return False
current_progress = future.get_progress()
if last_progress == current_progress:
return False
while True:
rc = conditional_wait(future.waiter, timeout, abort_signal, condition_cb=self._check_error, future=future)
if rc == WaiterRC.IS_SET:
# waiter has been set!
break
elif rc == WaiterRC.TIMEOUT:
# timed out: check whether any progress has been made during this time
current_progress = future.get_progress()
if last_progress == current_progress:
# no progress in timeout secs: consider this to be a failure
return False
else:
# good progress
self.logger.debug(f"{current_progress=}")
last_progress = current_progress
else:
self.logger.debug(f"{current_progress=}")
last_progress = current_progress
# error condition: aborted or future error
return False

if future.error:
return False
else:
return True

def _encode_message(self, msg: Message):
def _encode_message(self, msg: Message) -> int:
try:
encode_payload(msg, StreamHeaderKey.PAYLOAD_ENCODING)
return encode_payload(msg, StreamHeaderKey.PAYLOAD_ENCODING)
except BaseException as exc:
self.logger.error(f"Can't encode {msg=} {exc=}")
raise exc
@@ -265,6 +287,7 @@ def _send_request(
timeout=10.0,
secure=False,
optional=False,
abort_signal: Signal = None,
):
"""Stream one request to the target

@@ -276,12 +299,13 @@
timeout: how long to wait
secure: is P2P security to be applied
optional: is the message optional
abort_signal: signal to abort the message

Returns: reply data

"""
self._encode_message(request)
return self._send_one_request(channel, target, topic, request, timeout, secure, optional)
return self._send_one_request(channel, target, topic, request, timeout, secure, optional, abort_signal)

def _send_one_request(
self,
@@ -292,6 +316,7 @@
timeout=10.0,
secure=False,
optional=False,
abort_signal=None,
):
req_id = str(uuid.uuid4())
request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id})
@@ -312,7 +337,7 @@
# Three stages, sending, waiting for receiving first byte, receiving
# sending with progress timeout
self.logger.debug(f"{req_id=}: entering sending wait {timeout=}")
sending_complete = self._future_wait(future, timeout)
sending_complete = self._future_wait(future, timeout, abort_signal)
if not sending_complete:
self.logger.debug(f"{req_id=}: sending timeout {timeout=}")
return self._get_result(req_id)
@@ -321,15 +346,17 @@

# waiting for receiving first byte
self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}")
if not waiter.in_receiving.wait(timeout):
self.logger.debug(f"{req_id=}: remote processing timeout {timeout=}")

waiter_rc = conditional_wait(waiter.in_receiving, timeout, abort_signal)
if waiter_rc != WaiterRC.IS_SET:
self.logger.debug(f"{req_id=}: remote processing timeout {timeout=} {waiter_rc=}")
return self._get_result(req_id)
self.logger.debug(f"{req_id=}: in receiving")

# receiving with progress timeout
r_future = waiter.receiving_future
self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}")
receiving_complete = self._future_wait(r_future, timeout)
receiving_complete = self._future_wait(r_future, timeout, abort_signal)
if not receiving_complete:
self.logger.info(f"{req_id=}: receiving timeout {timeout=}")
return self._get_result(req_id)
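For context on how the new parameter is meant to be used, here is a minimal sketch (not part of this PR's diff) of a caller that aborts an in-flight streamed request from another thread. It assumes _send_request is reached as shown above; the channel, target, and topic values are placeholders.

import threading
import time

from nvflare.apis.signal import Signal


def send_with_abort(cell, request):
    # Illustrative only: issue a streamed request that a watchdog thread can abort.
    abort_signal = Signal()

    def watchdog():
        # Placeholder condition: abort after 5 seconds.
        time.sleep(5.0)
        abort_signal.trigger(True)

    threading.Thread(target=watchdog, daemon=True).start()

    # _send_request is the internal method extended by this PR; real callers go through
    # the aux-request path, which now passes fl_ctx.get_run_abort_signal() (see aux_runner.py).
    return cell._send_request(
        channel="my_channel",  # placeholder
        target="server",       # placeholder
        topic="my_topic",      # placeholder
        request=request,
        timeout=10.0,
        abort_signal=abort_signal,
    )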
18 changes: 17 additions & 1 deletion nvflare/fuel/f3/cellnet/utils.py
@@ -83,7 +83,22 @@ def format_log_message(fqcn: str, message: Message, log: str) -> str:
return " ".join(context) + f"] {log}"


def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING) -> int:
"""Encode the payload of the specified message.

Args:
message: the message to be encoded
encoding_key: the key name of the encoding property in the message header. If the encoding property is not
set in the message header, then it means that the message payload has not been encoded. If the property is
already set, then the message payload is already encoded, and no processing is done.
If encoding is needed, we will determine the encoding scheme based on the data type of the payload:
- If the payload is None, encoding scheme is NONE
- If the payload data type is like bytes, encoding scheme is BYTES
- Otherwise, encoding scheme is FOBS, and the payload is serialized with FOBS.

Returns: the encoded payload size.

"""
encoding = message.get_header(encoding_key)
if not encoding:
if message.payload is None:
@@ -97,6 +112,7 @@ def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCOD

size = buffer_len(message.payload)
message.set_header(MessageHeaderKey.PAYLOAD_LEN, size)
return size


def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING):
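A quick sketch of the new return value (illustrative only; the headers and payload used here are arbitrary):

from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
from nvflare.fuel.f3.cellnet.utils import encode_payload
from nvflare.fuel.f3.message import Message

# A dict payload is not bytes-like, so it is serialized with FOBS before sending.
msg = Message({}, {"greeting": "hello"})
size = encode_payload(msg)

# The returned size matches the PAYLOAD_LEN header set on the message.
assert size == msg.get_header(MessageHeaderKey.PAYLOAD_LEN)
print(f"encoded payload size: {size} bytes")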
12 changes: 9 additions & 3 deletions nvflare/fuel/f3/streaming/byte_streamer.py
@@ -27,7 +27,7 @@
StreamDataType,
StreamHeaderKey,
)
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture, StreamTaskSpec
from nvflare.fuel.f3.streaming.stream_utils import (
ONE_MB,
gen_stream_id,
@@ -49,7 +49,7 @@
log = logging.getLogger(__name__)


class TxTask:
class TxTask(StreamTaskSpec):
def __init__(
self,
cell: CoreCell,
@@ -84,7 +84,7 @@ def __init__(
self.optional = optional
self.stopped = False

self.stream_future = StreamFuture(self.sid)
self.stream_future = StreamFuture(self.sid, task_handle=self)
self.stream_future.set_size(stream.get_size())

self.window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE)
@@ -184,6 +184,9 @@ def stop(self, error: Optional[StreamError] = None, notify=True):

self.stopped = True

if self.task_future:
self.task_future.cancel()

if not error:
# Result is the number of bytes streamed
if self.stream_future:
@@ -232,6 +235,9 @@ def handle_ack(self, message: Message):
def start_task_thread(self, task_handler: Callable):
self.task_future = stream_thread_pool.submit(task_handler, self)

def cancel(self):
self.stop(error=StreamError("cancelled"))


class ByteStreamer:

16 changes: 14 additions & 2 deletions nvflare/fuel/f3/streaming/stream_types.py
@@ -110,13 +110,23 @@ def set_index(self, index: int):
self.index = index


class StreamTaskSpec(ABC):
def cancel(self):
"""Cancel the task

Returns:

"""
pass


class StreamFuture:
"""Future class for all stream calls.

Fashioned after concurrent.futures.Future
"""

def __init__(self, stream_id: int, headers: Optional[dict] = None):
def __init__(self, stream_id: int, headers: Optional[dict] = None, task_handle: StreamTaskSpec = None):
self.stream_id = stream_id
self.headers = headers
self.waiter = threading.Event()
Expand All @@ -126,6 +136,7 @@ def __init__(self, stream_id: int, headers: Optional[dict] = None):
self.size = 0
self.progress = 0
self.done_callbacks = []
self.task_handle = task_handle

def get_stream_id(self) -> int:
return self.stream_id
@@ -157,7 +168,8 @@ def cancel(self):
return False

self.error = StreamCancelled(f"Stream {self.stream_id} is cancelled")

if self.task_handle:
self.task_handle.cancel()
return True

def cancelled(self):
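A minimal sketch of the new cancellation hook, assuming a freshly created future can be cancelled; DummyTask is invented for illustration:

from nvflare.fuel.f3.streaming.stream_types import StreamFuture, StreamTaskSpec


class DummyTask(StreamTaskSpec):
    # Illustrative task that just records whether it was cancelled.
    def __init__(self):
        self.was_cancelled = False

    def cancel(self):
        self.was_cancelled = True


task = DummyTask()
future = StreamFuture(stream_id=1, task_handle=task)

# Cancelling the future now propagates to the owning task (e.g. TxTask in byte_streamer.py).
future.cancel()
assert task.was_cancelled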
71 changes: 71 additions & 0 deletions nvflare/fuel/utils/waiter_utils.py
@@ -0,0 +1,71 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

from nvflare.apis.signal import Signal

_SMALL_WAIT = 0.01


class WaiterRC:
OK = 0
IS_SET = 1
TIMEOUT = 2
ABORTED = 3
ERROR = 4


def conditional_wait(waiter: threading.Event, timeout: float, abort_signal: Signal, condition_cb=None, **cb_kwargs):
"""Wait for an event until timeout, aborted, or some condition is met.

Args:
waiter: the event to wait
timeout: the max time to wait
abort_signal: signal to abort the wait
condition_cb: condition to check during waiting
**cb_kwargs: kwargs for the condition_cb

Returns: return code to indicate how the waiting is stopped:
IS_SET: the event is set
TIMEOUT: the event timed out
ABORTED: abort signal is triggered during the wait
ERROR: the condition_cb encountered unhandled exception
OK: only used by the condition_cb to say "all is normal"
other integers: returned by condition_cb for other conditions met

"""
wait_time = min(_SMALL_WAIT, timeout)
start = time.time()
while True:
if waiter.wait(wait_time):
# the event just happened!
return WaiterRC.IS_SET

if time.time() - start >= timeout:
return WaiterRC.TIMEOUT

# check conditions
if abort_signal and abort_signal.triggered:
return WaiterRC.ABORTED

if condition_cb:
try:
rc = condition_cb(**cb_kwargs)
if rc != WaiterRC.OK:
# a bad condition is detected by the condition_cb
# we return the rc from the condition_cb
return rc
except:
return WaiterRC.ERROR
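To make the intended behavior concrete, here is a small usage sketch (the event and timer wiring are made up for illustration):

import threading

from nvflare.apis.signal import Signal
from nvflare.fuel.utils.waiter_utils import WaiterRC, conditional_wait

done = threading.Event()
abort_signal = Signal()

# Trigger the abort shortly after the wait starts (stand-in for a real "stop run" request).
threading.Timer(0.1, abort_signal.trigger, args=(True,)).start()

rc = conditional_wait(done, timeout=5.0, abort_signal=abort_signal)
if rc == WaiterRC.IS_SET:
    print("event was set")
elif rc == WaiterRC.ABORTED:
    print("wait was aborted")  # expected here, since the signal fires before the timeout
elif rc == WaiterRC.TIMEOUT:
    print("timed out")
else:
    print(f"stopped by condition_cb: {rc}")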
5 changes: 4 additions & 1 deletion nvflare/private/aux_runner.py
@@ -280,7 +280,9 @@ def _send_multi_requests(
)

if timeout > 0:
cell_replies = cell.broadcast_multi_requests(target_messages, timeout, optional=optional, secure=secure)
cell_replies = cell.broadcast_multi_requests(
target_messages, timeout, optional=optional, secure=secure, abort_signal=fl_ctx.get_run_abort_signal()
)
return self._process_cell_replies(cell_replies, topic, channel, fqcn_to_name)
else:
cell.fire_multi_requests_and_forget(
@@ -395,6 +397,7 @@ def _send_to_cell(
timeout=timeout,
optional=optional,
secure=secure,
abort_signal=fl_ctx.get_run_abort_signal(),
)
return self._process_cell_replies(cell_replies, topic, channel, fqcn_to_name)
else:
3 changes: 2 additions & 1 deletion nvflare/private/fed/client/client_engine_executor_spec.py
@@ -149,10 +149,11 @@ def fire_and_forget_aux_request(
"""Send an async request to Server via the aux channel.

Args:
topic: topic of the request
topic: topic of the request.
request: request to be sent
fl_ctx: FL context
optional: whether the request is optional
secure: whether to send the message in P2P secure mode

Returns:

8 changes: 7 additions & 1 deletion nvflare/private/fed/client/client_run_manager.py
@@ -244,7 +244,13 @@ def send_aux_request(

if msg_targets:
return self.aux_runner.send_aux_request(
msg_targets, topic, request, timeout, fl_ctx, optional=optional, secure=secure
msg_targets,
topic,
request,
timeout,
fl_ctx,
optional=optional,
secure=secure,
)
else:
return {}