From 838f8eb8cf8cd41a44f0e25435359bcf91f24387 Mon Sep 17 00:00:00 2001 From: zwimer Date: Mon, 15 Jan 2024 15:18:33 -0500 Subject: [PATCH] Client: use custom io class instead of sys.stdin.buffer.read as that can hang when not needed --- rpipe/client/io.py | 81 ++++++++++++++++++++++++++++++++++++++++++++ rpipe/client/send.py | 4 ++- rpipe/version.py | 2 +- 3 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 rpipe/client/io.py diff --git a/rpipe/client/io.py b/rpipe/client/io.py new file mode 100644 index 0000000..f46809a --- /dev/null +++ b/rpipe/client/io.py @@ -0,0 +1,81 @@ +from threading import Thread, Condition +from collections import deque +from logging import getLogger +import os + + +_LOG = "IO" + + +class IO: + """ + A better version of stdin.read that doesn't hang as often + Only meant to ever be used from a single thread at a time + Will preload about 2n bytes for reading in chunks of <= n + May use about an extra n bytes for stitching data together + """ + + def __init__(self, fd: int, n: int) -> None: + self._log = getLogger(_LOG) + self._log.debug("Constructed IO on fd %d with n=%d", fd, n) + self._buffer: deque[bytes] = deque() + self._cond = Condition() + self._eof: bool = False + self._fd: int = fd + self._n: int = n + self._log.debug("Starting up IO thread.") + self._thread = Thread(target=self, daemon=True) + self._thread.start() + + # Main Thread: + + def eof(self) -> bool: + return self._eof and not self._buffer + + def read(self) -> bytes: + """ + :param delay: sleep delay ms to allow more IO to load + :return: Up to n bytes; returns b"" only upon final read + """ + with self._cond: + self._cond.wait_for(lambda: self._buffer or self._eof) + ret: bytes = self._read() + self._cond.notify() + self._log.debug("Read %d bytes of data", len(ret)) + return ret + + def _read(self) -> bytes: + """ + A helper to read that assumes it owns self._buffer + """ + if not self._buffer: + return b"" + # Calculate how many pieces to stitch together + count = 0 + total = 0 + for i in self._buffer: + total += len(i) + count += 1 + if total > self._n: + break + if len(self._buffer) > 1: + count -= 1 + assert count > 0, "Write thread wrote too much data" + # Stitch together pieces as efficiently as possible + if count == 1: + return self._buffer.popleft() + return b"".join(self._buffer.popleft() for _ in range(count)) + + # Worker thread + + def __call__(self) -> None: + until = lambda: sum(len(i) for i in self._buffer) < self._n + while data := os.read(self._fd, self._n): # Can read in small bursts + with self._cond: + self._cond.wait_for(until) + self._buffer.append(data) + self._cond.notify() + with self._cond: + self._eof = True + self._cond.notify() + self._log.debug("IO has terminated successfully") diff --git a/rpipe/client/send.py b/rpipe/client/send.py index aabbe94..3c9868e 100644 --- a/rpipe/client/send.py +++ b/rpipe/client/send.py @@ -9,6 +9,7 @@ from .errors import MultipleClients, ReportThis, VersionError from .util import WAIT_DELAY_SEC, request, channel_url from .crypt import encrypt +from .io import IO if TYPE_CHECKING: from requests import Response @@ -64,7 +65,8 @@ def send(config: ValidConfig) -> None: log.debug("Writing to channel %s with block size of %s", config.channel, block_size) # Send params.stream_id = headers.stream_id - while block := sys.stdin.buffer.read(block_size): + io = IO(sys.stdin.fileno(), block_size) + while block := io.read(): _send_block(encrypt(block, config.password), config, params) # Finalize params.final = True diff --git a/rpipe/version.py b/rpipe/version.py index f9dedc8..455c0b6 100644 --- a/rpipe/version.py +++ b/rpipe/version.py @@ -1,7 +1,7 @@ from __future__ import annotations -__version__: str = "5.1.0" # Must be "..", all numbers +__version__: str = "5.2.0" # Must be "..", all numbers class Version: