Skip to content

Commit

Permalink
Client: use custom io class instead of sys.stdin.buffer.read as that …
Browse files Browse the repository at this point in the history
…can hang when not needed
  • Loading branch information
zwimer committed Jan 15, 2024
1 parent 6a573a7 commit 838f8eb
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 2 deletions.
81 changes: 81 additions & 0 deletions rpipe/client/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from threading import Thread, Condition
from collections import deque
from logging import getLogger
import os


_LOG = "IO"


class IO:
"""
A better version of stdin.read that doesn't hang as often
Only meant to ever be used from a single thread at a time
Will preload about 2n bytes for reading in chunks of <= n
May use about an extra n bytes for stitching data together
"""

def __init__(self, fd: int, n: int) -> None:
self._log = getLogger(_LOG)
self._log.debug("Constructed IO on fd %d with n=%d", fd, n)
self._buffer: deque[bytes] = deque()
self._cond = Condition()
self._eof: bool = False
self._fd: int = fd
self._n: int = n
self._log.debug("Starting up IO thread.")
self._thread = Thread(target=self, daemon=True)
self._thread.start()

# Main Thread:

def eof(self) -> bool:
return self._eof and not self._buffer

def read(self) -> bytes:
"""
:param delay: sleep delay ms to allow more IO to load
:return: Up to n bytes; returns b"" only upon final read
"""
with self._cond:
self._cond.wait_for(lambda: self._buffer or self._eof)
ret: bytes = self._read()
self._cond.notify()
self._log.debug("Read %d bytes of data", len(ret))
return ret

def _read(self) -> bytes:
"""
A helper to read that assumes it owns self._buffer
"""
if not self._buffer:
return b""
# Calculate how many pieces to stitch together
count = 0
total = 0
for i in self._buffer:
total += len(i)
count += 1
if total > self._n:
break
if len(self._buffer) > 1:
count -= 1
assert count > 0, "Write thread wrote too much data"
# Stitch together pieces as efficiently as possible
if count == 1:
return self._buffer.popleft()
return b"".join(self._buffer.popleft() for _ in range(count))

# Worker thread

def __call__(self) -> None:
until = lambda: sum(len(i) for i in self._buffer) < self._n
while data := os.read(self._fd, self._n): # Can read in small bursts
with self._cond:
self._cond.wait_for(until)
self._buffer.append(data)
self._cond.notify()
with self._cond:
self._eof = True
self._cond.notify()
self._log.debug("IO has terminated successfully")
4 changes: 3 additions & 1 deletion rpipe/client/send.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .errors import MultipleClients, ReportThis, VersionError
from .util import WAIT_DELAY_SEC, request, channel_url
from .crypt import encrypt
from .io import IO

if TYPE_CHECKING:
from requests import Response
Expand Down Expand Up @@ -64,7 +65,8 @@ def send(config: ValidConfig) -> None:
log.debug("Writing to channel %s with block size of %s", config.channel, block_size)
# Send
params.stream_id = headers.stream_id
while block := sys.stdin.buffer.read(block_size):
io = IO(sys.stdin.fileno(), block_size)
while block := io.read():
_send_block(encrypt(block, config.password), config, params)
# Finalize
params.final = True
Expand Down
2 changes: 1 addition & 1 deletion rpipe/version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


__version__: str = "5.1.0" # Must be "<major>.<minor>.<patch>", all numbers
__version__: str = "5.2.0" # Must be "<major>.<minor>.<patch>", all numbers


class Version:
Expand Down

0 comments on commit 838f8eb

Please sign in to comment.