diff --git a/rpipe/__init__.py b/rpipe/__init__.py index 8aa8825..dec0f10 100644 --- a/rpipe/__init__.py +++ b/rpipe/__init__.py @@ -1 +1 @@ -__version__: str = "9.6.5" # Must be "..", all numbers +__version__: str = "9.6.6" # Must be "..", all numbers diff --git a/rpipe/client/admin.py b/rpipe/client/admin.py index acef726..9a8d111 100644 --- a/rpipe/client/admin.py +++ b/rpipe/client/admin.py @@ -150,20 +150,29 @@ def unlock(self) -> None: """ self._lock(False) - def ip(self, block: str | None, unblock: str | None) -> None: - """ - Request the blocked ip addresses, or block / unblock an ip address - """ + def _block(self, name: str, block: str | None, unblock: str | None) -> None: if block is not None and unblock is not None: raise ValueError("block and unblock may not both be non-None") if block is None and unblock is None: - blocked = self._request("/admin/ip", '{"ip": null}').text - print(f"Blocked IP addresses: {blocked}") + blocked = self._request(f"/admin/{name}", f'{{"{name}": null}}').text + print(f"Blocked {name}s: {blocked}") return ban = block is not None - addr = block if ban else unblock - self._request("/admin/ip", dumps({"ip": addr, "block": ban})) - print(f"{"" if ban else "UN"}BLOCKED: {addr}") + obj = block if ban else unblock + self._request(f"/admin/{name}", dumps({name: obj, "block": ban})) + print(f"{"" if ban else "UN"}BLOCKED: {obj}") + + def ip(self, block: str | None, unblock: str | None) -> None: + """ + Request the blocked ip addresses, or block / unblock an ip address + """ + self._block("ip", block, unblock) + + def route(self, block: str | None, unblock: str | None) -> None: + """ + Request the blocked routes, or block / unblock a route + """ + self._block("route", block, unblock) class Admin: diff --git a/rpipe/client/cli.py b/rpipe/client/cli.py index ae98330..01b8b88 100644 --- a/rpipe/client/cli.py +++ b/rpipe/client/cli.py @@ -180,12 +180,13 @@ def cli() -> None: log_lvl_p.add_argument("level", default=None, nargs="?", help="The log level for the server to use") admin.add_parser("lock", help="Lock the channel") admin.add_parser("unlock", help="Unlock the channel") - ip_p = admin.add_parser("ip", help="Block / unblock ip addresses, or get a list of blocked addresses") - m_g = ip_p.add_argument_group( - "Block / Unblock a given IP", - "If none of these are passed, the command will return the list of banned IP addresses", - ).add_mutually_exclusive_group(required=False) - m_g.add_argument("--block", help="Block a given IP address") - m_g.add_argument("--unblock", help="Unblock a given IP address") + for name in ("ip", "route"): + p2 = admin.add_parser(name, help=f"Block / unblock {name}s, or get a list of blocked {name}s") + m_g = p2.add_argument_group( + f"Block / Unblock a given {name}", + f"If none of these are passed, the command will return the list of banned {name}s", + ).add_mutually_exclusive_group(required=False) + m_g.add_argument("--block", help=f"Block a given {name}") + m_g.add_argument("--unblock", help=f"Unblock a given {name}") argcomplete.autocomplete(parser) # Tab completion _cli(parser, parser.parse_args()) diff --git a/rpipe/server/admin/admin.py b/rpipe/server/admin/admin.py index 43b4e16..62d2ebb 100644 --- a/rpipe/server/admin/admin.py +++ b/rpipe/server/admin/admin.py @@ -84,21 +84,29 @@ def lock(self, state: State, body: str) -> Response: s.locked = lock return Response(f"Channel {channel} is now {lock_s}", status=200) - def ip(self, _: State, body: str) -> Response: + def _block(self, name: str, body: str) -> Response: js = loads(body.strip()) - if (addr := js["ip"]) is None: - return json_response(self._blocked.data["ips"]) - lst = self._blocked.data["ips"] + if (obj := js[name]) is None: + return json_response(getattr(self._blocked.data, f"{name}s")) + lst = getattr(self._blocked.data, f"{name}s") if js["block"]: - if addr not in lst: - lst.append(addr) + if obj not in lst: + self._log.info("Blocking %s: %s", name, obj) + lst.append(obj) self._blocked.commit() - elif addr in lst: - while addr in lst: - lst.remove(addr) + elif obj in lst: + while obj in lst: + self._log.info("Unblocking %s: %s", name, obj) + lst.remove(obj) self._blocked.commit() return Response(status=200) + def ip(self, _: State, body: str) -> Response: + return self._block("ip", body) + + def route(self, _: State, body: str) -> Response: + return self._block("route", body) + class Admin: """ diff --git a/rpipe/server/app.py b/rpipe/server/app.py index 41c6961..c1bd69e 100644 --- a/rpipe/server/app.py +++ b/rpipe/server/app.py @@ -4,11 +4,8 @@ from dataclasses import dataclass from tempfile import mkstemp from functools import wraps -from fnmatch import fnmatch from pathlib import Path import atexit -import typing -import json from flask import Response, Flask, send_file, request from zstdlib.log import CuteFormatter @@ -17,6 +14,7 @@ from ..shared import BLOCKED_EC, TRACE, restrict_umask, remote_addr, log, __version__ from .util import MAX_SIZE_HARD, MIN_VERSION, json_response, plaintext from .channel import handler, query +from .blocked import Blocked from .server import Server from .admin import Admin @@ -42,34 +40,6 @@ class ServerConfig: key_files: list[Path] -class Blocked: - _DEFAULT: dict[str, list[str]] = {"ips": [], "routes": []} - - def __init__(self, file: Path | None) -> None: - self.data = dict(self._DEFAULT) if file is None else json.loads(file.read_text()) - self.file: Path | None = file - self._lg = getLogger("Blocked") - - def commit(self) -> None: - if self.file is None: - raise ValueError("Cannot save a block file when block-file not set") - self.file.write_text(json.dumps(self.data, indent=4)) - - def __call__(self) -> bool: - if self.file is None: - return False - ip = request.headers.get("X-Forwarded-For", request.remote_addr) - if ip in self.data["ips"]: - return True - pth = request.path - if any(fnmatch(pth, i) for i in self.data["routes"]): - self._lg.info("Blocking IP %s based on route: %s", ip, pth) - self.data["ips"].append(typing.cast(str, ip)) - self.commit() - return True - return False - - class App(Flask): @dataclass(frozen=True, slots=True) @@ -258,6 +228,11 @@ def _admin_ip(o: App.Objs) -> Response: return o.admin.ip(o.server.state) +@app.route("/admin/route", admin=True) +def _admin_route(o: App.Objs) -> Response: + return o.admin.route(o.server.state) + + # Main functions diff --git a/rpipe/server/blocked.py b/rpipe/server/blocked.py new file mode 100644 index 0000000..2b76980 --- /dev/null +++ b/rpipe/server/blocked.py @@ -0,0 +1,54 @@ +from __future__ import annotations +from dataclasses import dataclass, asdict, field +from typing import TYPE_CHECKING +from logging import getLogger +from fnmatch import fnmatch +import json + +from flask import request + +from ..shared import Version, version, __version__ + +if TYPE_CHECKING: + from pathlib import Path + + +@dataclass(kw_only=True) +class Data: + version: Version = field(default_factory=lambda: Version("0.0.1")) + ips: list[str] = field(default_factory=list) + routes: list[str] = field(default_factory=list) + whitelist: list[str] = field(default_factory=list) + + +class Blocked: + MIN_VERSION = Version("9.6.6") + + def __init__(self, file: Path | None) -> None: + js = {"version": __version__} if file is None else json.loads(file.read_text()) + if (old := Version(js.pop("version", ""))) < self.MIN_VERSION: + raise ValueError(f"Blocklist version too old: {old} <= {self.MIN_VERSION}") + self.data = Data(version=version, **js) # Use new version + self.file: Path | None = file + self._lg = getLogger("Blocked") + + def commit(self) -> None: + if self.file is None: + raise ValueError("Cannot save a block file when block-file not set") + self.file.write_text(json.dumps(asdict(self.data), default=str, indent=4)) + + def __call__(self) -> bool: + if self.file is None: + return False + ip = request.headers.get("X-Forwarded-For", request.remote_addr) + if ip in self.data.whitelist: + return False + if ip in self.data.ips: + return True + pth = request.path + if any(fnmatch(pth, i) for i in self.data.routes): + self._lg.info("Blocking IP %s based on route: %s", ip, pth) + self.data.ips.append(ip) # type: ignore + self.commit() + return True + return False diff --git a/rpipe/server/server/server.py b/rpipe/server/server/server.py index 4cb85a4..2c83b8d 100644 --- a/rpipe/server/server/server.py +++ b/rpipe/server/server/server.py @@ -10,7 +10,6 @@ from .prune_thread import PruneThread from .state import State -from ...shared import version if TYPE_CHECKING: from pathlib import Path @@ -54,7 +53,7 @@ def shutdown(self): def __init__(self, debug: bool, state_file: Path | None) -> None: self._log = getLogger(_LOG) - self._log.info("Initializing server v%s", version) + self._log.info("Initializing server") self._state_file: Path | None = state_file self.state = State(debug) # Flask reloader will just relaunch this so we skip most configuration (such as persistent items)