Skip to content

Commit

Permalink
Add --timeout, --checksum, and --total + refactor pbar into progress
Browse files Browse the repository at this point in the history
  • Loading branch information
zwimer committed Oct 29, 2024
1 parent 6b40165 commit 012b47d
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 142 deletions.
2 changes: 1 addition & 1 deletion rpipe/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__: str = "9.4.0" # Must be "<major>.<minor>.<patch>", all numbers
__version__: str = "9.5.0" # Must be "<major>.<minor>.<patch>", all numbers
9 changes: 9 additions & 0 deletions rpipe/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ def cli() -> None:
help="Pipe TTL in seconds; use server default if not passed",
)
write_g.add_argument( # Do not use default= for better error checking w.r.t. plaintext mode
"-Z",
"--zstd",
metavar="[1-22]",
choices=range(1, 23),
type=int,
help="Compression level to use; invalid in plaintext mode",
)
write_g.add_argument(
"-j",
"--threads",
metavar=f"[1-{cpu}]" if cpu > 1 else "1",
choices=range(1, cpu + 1),
Expand All @@ -92,10 +94,17 @@ def cli() -> None:
read_write_g.add_argument(
"-P", "--progress", metavar="SIZE", type=_si, default=False, const=True, nargs="?", help=msg
)
read_write_g.add_argument(
"-Y", "--total", action="store_true", help="Print the total number of bytes sent/received"
)
read_write_g.add_argument(
"-K", "--checksum", action="store_true", help="Checksum the data being sent/received"
)
# Config options
config = parser.add_argument_group("Configuration")
config.add_argument("-u", "--url", help="The pipe url to use")
config.add_argument("-c", "--channel", help="The channel to use")
config.add_argument("-T", "--timeout", type=float, help="The timeout for the HTTP requests")
config.add_argument(
"-k",
"--key-file",
Expand Down
28 changes: 14 additions & 14 deletions rpipe/client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@
from logging import getLogger
from json import dumps

from zstandard import ZstdCompressor

from ...shared import TRACE, QueryEC, Version, version
from .util import REQUEST_TIMEOUT, request
from .errors import UsageError, VersionError
from .data import Config, Mode
from .delete import delete
from .util import request
from .recv import recv
from .send import send

if TYPE_CHECKING:
from pathlib import Path
from .data import Result, Config, Mode


_LOG: str = "client"
_DEFAULT_LVL: int = 3


def _print_config(conf: Config, config_file: Path) -> None:
Expand All @@ -36,7 +33,7 @@ def _print_config(conf: Config, config_file: Path) -> None:
def _check_outdated(conf: Config) -> None:
log = getLogger(_LOG)
log.info("Mode: Outdated")
r = request("GET", f"{conf.url}/supported")
r = request("GET", f"{conf.url}/supported", timeout=conf.timeout)
if not r.ok:
raise RuntimeError(f"Failed to get server minimum version: {r}")
info = r.json()
Expand All @@ -51,7 +48,7 @@ def _query(conf: Config) -> None:
if not conf.channel:
raise UsageError("Channel unknown; try again with --channel")
log.info("Querying channel %s ...", conf.channel)
r = request("GET", f"{conf.url}/q/{conf.channel}")
r = request("GET", f"{conf.url}/q/{conf.channel}", timeout=conf.timeout)
log.debug("Got response %s", r)
log.log(TRACE, "Data: %s", r.content)
match r.status_code:
Expand Down Expand Up @@ -82,7 +79,7 @@ def _priority_actions(conf: Config, mode: Mode, config_file: Path) -> None:
_check_outdated(conf)
if mode.server_version:
log.info("Mode: Server Version")
r = request("GET", f"{conf.url}/version")
r = request("GET", f"{conf.url}/version", timeout=conf.timeout)
if not r.ok:
raise RuntimeError(f"Failed to get version: {r}")
print(f"rpipe_server {r.text}")
Expand All @@ -107,13 +104,16 @@ def rpipe(conf: Config, mode: Mode, config_file: Path) -> None:
if mode.zstd is not None:
raise UsageError("Cannot compress data in plaintext mode")
# Invoke mode
log.info("HTTP timeout set to %d seconds", REQUEST_TIMEOUT)
log.info("HTTP timeout set to: %s", "DEFAULT" if conf.timeout is None else f"{conf.timeout} seconds")
if mode.read:
recv(conf, mode.block, mode.peek, mode.force, mode.progress)
rv: Result = recv(conf, mode)
elif mode.write:
lvl = _DEFAULT_LVL if mode.zstd is None else mode.zstd
log.debug("Using compression level %d and %d threads", lvl, mode.threads)
compress = ZstdCompressor(write_checksum=True, level=lvl, threads=mode.threads).compress
send(conf, mode.ttl, mode.progress, compress)
rv = send(conf, mode)
else:
delete(conf)
return
# Print results
if rv.checksum is not None:
print(f"Blake2s: {rv.checksum.hexdigest()}")
if rv.total is not None:
print(f"Total bytes {'sent' if mode.write else 'received'}: {rv.total}")
22 changes: 20 additions & 2 deletions rpipe/client/client/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from urllib.parse import quote
from json import loads, dumps
from logging import getLogger
from hashlib import blake2s
from pathlib import Path

from human_readable import listing
Expand All @@ -17,7 +18,21 @@
_CONFIG_LOG: str = "Config"


@dataclass(kw_only=True, frozen=True)
@dataclass(init=False, slots=True)
class Result:
"""
Result of a successful send/receive
"""

checksum: blake2s | None
total: int | None

def __init__(self, total: bool, checksum: bool) -> None:
self.checksum = blake2s(usedforsecurity=False) if checksum else None
self.total = 0 if total else None


@dataclass(kw_only=True, frozen=True, slots=True)
class Config:
"""
Information about where the remote pipe is
Expand All @@ -27,6 +42,7 @@ class Config:
url: str = ""
channel: str = ""
password: str = ""
timeout: int | None = None
key_file: Path | None = None

def __post_init__(self):
Expand Down Expand Up @@ -91,7 +107,7 @@ def __repr__(self) -> str:


# pylint: disable=too-many-instance-attributes
@dataclass(kw_only=True, frozen=True)
@dataclass(kw_only=True, frozen=True, slots=True)
class Mode:
"""
Arguments used to decide how rpipe should operate
Expand Down Expand Up @@ -119,6 +135,8 @@ class Mode:
# Read / Write options
encrypt: bool
progress: bool | int
total: bool
checksum: bool

@classmethod
def keys(cls) -> tuple[str, ...]:
Expand Down
17 changes: 1 addition & 16 deletions rpipe/client/client/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,8 @@ def delete(conf: Config) -> None:
Delete the channel
"""
getLogger("delete").info("Deleting channel %s", conf.channel)
r = request("DELETE", conf.channel_url())
r = request("DELETE", conf.channel_url(), timeout=conf.timeout)
if r.status_code == DeleteEC.locked:
raise ChannelLocked(r.text)
if not r.ok:
raise RuntimeError(r)


class DeleteOnFail:
def __init__(self, config: Config):
self.catch = KeyboardInterrupt | Exception
self.config = config
self.armed = False

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.armed and isinstance(exc_val, self.catch):
getLogger("DeleteOnFail").warning("Caught %s; deleting channel", type(exc_val))
delete(self.config)
60 changes: 0 additions & 60 deletions rpipe/client/client/pbar.py

This file was deleted.

63 changes: 63 additions & 0 deletions rpipe/client/client/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Self
from logging import getLogger
import sys

from tqdm.contrib.logging import logging_redirect_tqdm
import tqdm

from .delete import delete
from .data import Result

if TYPE_CHECKING:
from .data import Config, Mode


class Progress:
"""
A small class to handle progress bars and progression data of send/recv
"""

DOF_EXC = KeyboardInterrupt | Exception

__slots__ = ("result", "dof", "_config", "_redir", "_pbar")

def __init__(self, config: Config, mode: Mode):
self.result = Result(total=mode.total, checksum=mode.checksum)
self.dof: bool = False
self._config = config
self._redir = logging_redirect_tqdm()
self._pbar = tqdm.tqdm(
disable=mode.progress is False,
total=None if isinstance(mode.progress, bool) else mode.progress,
dynamic_ncols=True,
leave=False,
unit_divisor=1000,
unit_scale=True,
unit="B",
)

def __enter__(self) -> Self:
self._redir.__enter__()
self._pbar.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.dof and isinstance(exc_val, self.DOF_EXC):
log = getLogger("Progress")
log.warning("Caught %s. Deleting channel: %s", type(exc_val), self._config.channel)
delete(self._config)
r1 = bool(self._pbar.__exit__(exc_type, exc_val, exc_tb))
r2 = bool(self._redir.__exit__(exc_type, exc_val, exc_tb))
return r1 or r2

def update(self, data: bytes, *, stdout: bool = False) -> None:
self.dof = True
self._pbar.update(len(data))
if self.result.total is not None:
self.result.total += len(data)
if self.result.checksum is not None:
self.result.checksum.update(data)
if stdout:
sys.stdout.buffer.write(data)
sys.stdout.flush()
Loading

0 comments on commit 012b47d

Please sign in to comment.