Skip to content

Commit

Permalink
Blocklist improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
zwimer committed Feb 3, 2025
1 parent 05b7fd2 commit aa53f51
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 46 deletions.
2 changes: 1 addition & 1 deletion rpipe/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__: str = "9.6.6" # Must be "<major>.<minor>.<patch>", all numbers
__version__: str = "9.7.0" # Must be "<major>.<minor>.<patch>", all numbers
2 changes: 2 additions & 0 deletions rpipe/client/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from json import loads, dumps
from base64 import b85encode
from pathlib import Path
import json
import zlib

from cryptography.hazmat.primitives.serialization import load_ssh_private_key # type: ignore[attr-defined]
Expand Down Expand Up @@ -155,6 +156,7 @@ def _block(self, name: str, block: str | None, unblock: str | None) -> None:
raise ValueError("block and unblock may not both be non-None")
if block is None and unblock is None:
blocked = self._request(f"/admin/{name}", f'{{"{name}": null}}').text
blocked = json.dumps(json.loads(blocked), indent=4)
print(f"Blocked {name}s: {blocked}")
return
ban = block is not None
Expand Down
33 changes: 16 additions & 17 deletions rpipe/server/admin/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def log_level(self, state: State, body: str) -> Response:
return Response(f"{old}\n{new}", status=200, mimetype="text/plain")

@staticmethod
def stats(state: State, _: str) -> Response:
def stats(state: State, _: str, blocked: dict[str, list[list[str]]]) -> Response:
with state as s:
stats = asdict(s.stats)
return json_response(stats)
return json_response({"server": stats, "blocked": blocked})

@staticmethod
def channels(state: State, _: str) -> Response:
Expand All @@ -86,19 +86,18 @@ def lock(self, state: State, body: str) -> Response:

def _block(self, name: str, body: str) -> Response:
js = loads(body.strip())
if (obj := js[name]) is None:
return json_response(getattr(self._blocked.data, f"{name}s"))
lst = getattr(self._blocked.data, f"{name}s")
if js["block"]:
if obj not in lst:
self._log.info("Blocking %s: %s", name, obj)
lst.append(obj)
self._blocked.commit()
elif obj in lst:
while obj in lst:
self._log.info("Unblocking %s: %s", name, obj)
lst.remove(obj)
self._blocked.commit()
with self._blocked as data:
if (obj := js[name]) is None:
return json_response(getattr(data, f"{name}s"))
lst = getattr(data, f"{name}s")
if js["block"]:
if obj not in lst:
self._log.info("Blocking %s: %s", name, obj)
lst.append(obj)
elif obj in lst:
while obj in lst:
self._log.info("Unblocking %s: %s", name, obj)
lst.remove(obj)
return Response(status=200)

def ip(self, _: State, body: str) -> Response:
Expand Down Expand Up @@ -129,10 +128,10 @@ def __getattr__(self, item: str) -> Any:
if item.startswith("_"):
raise AttributeError(f"{item} is a private member")

def wrapper(state: State) -> Response:
def wrapper(state: State, *args, **kwargs) -> Response:
assert self._verify is not None, "Admin not initialized"
if isinstance(rv := self._verify(item, state), str):
return getattr(self._methods, item)(state, rv)
return getattr(self._methods, item)(state, rv, *args, **kwargs)
return rv

return wrapper
Expand Down
7 changes: 4 additions & 3 deletions rpipe/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ServerConfig:
port: int
debug: bool
state_file: Path | None
block_file: Path | None
blocklist: Path | None
key_files: list[Path]


Expand All @@ -60,7 +60,7 @@ def start(self, conf: ServerConfig, log_file: Path, favicon: Path | None):
if favicon is not None and not favicon.is_file():
lg.error("Favicon file not found: %s", favicon)
favicon = None
blocked = Blocked(conf.block_file)
blocked = Blocked(conf.blocklist, conf.debug)
admin = Admin(log_file, conf.key_files, blocked)
lg.info("Starting server version: %s", __version__)
# pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -205,7 +205,8 @@ def _admin_channels(o: App.Objs) -> Response:

@app.route("/admin/stats", admin=True)
def _admin_stats(o: App.Objs) -> Response:
return o.admin.stats(o.server.state)
with o.blocked as data:
return o.admin.stats(o.server.state, data.stats)


@app.route("/admin/log", admin=True)
Expand Down
109 changes: 85 additions & 24 deletions rpipe/server/blocked.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,114 @@
from dataclasses import dataclass, asdict, field
from typing import TYPE_CHECKING
from logging import getLogger
from datetime import datetime
from fnmatch import fnmatch
from threading import RLock
import atexit
import json

from werkzeug.serving import is_running_from_reloader
from flask import request

from ..shared import Version, version, __version__
from ..shared import Version, version, __version__, remote_addr

if TYPE_CHECKING:
from pathlib import Path


@dataclass(kw_only=True)
@dataclass(kw_only=True, slots=True)
class Data:
"""
Contains blocklist data such as ips, routes, whitelists, etc
"""

version: Version = field(default_factory=lambda: Version("0.0.1"))
ips: list[str] = field(default_factory=list)
routes: list[str] = field(default_factory=list)
whitelist: list[str] = field(default_factory=list)
stats: dict[str, list[list[str]]] = field(default_factory=dict)


class Blocked: # Move into server? Move stats into Stats?
"""
Used to determine if requests should be blocked or not
"""

class Blocked:
_INIT = {"version": __version__}
MIN_VERSION = Version("9.6.6")

def __init__(self, file: Path | None) -> None:
js = {"version": __version__} if file is None else json.loads(file.read_text())
def __init__(self, file: Path | None, debug: bool) -> None:
self._log = getLogger("Blocked")
js = self._INIT if file is None or not file.is_file() else json.loads(file.read_text())
if (old := Version(js.pop("version", ""))) < self.MIN_VERSION:
raise ValueError(f"Blocklist version too old: {old} <= {self.MIN_VERSION}")
self.data = Data(version=version, **js) # Use new version
self.file: Path | None = file
self._lg = getLogger("Blocked")
self._data = Data(version=version, **js) # Use new version
self._file: Path | None = file
self._lock = RLock()
# Initialize file as needed
if file is None:
self._log.warning("No blocklist is set, blocklist changes will not persist across restarts")
return
if not file.exists():
self._log.warning("Blocklist %s not found. Using defaults", file)
# Setup saving on exit
if debug and not is_running_from_reloader(): # Flask will reload the program, skip atexit
self._log.info("Skipping initialization until reload")
return
self._log.info("Installing atexit shutdown handler for saving blocklist")
atexit.register(self._save)

def __enter__(self) -> Data:
"""
Returns the Data object of Blocked
"""
self._lock.acquire()
return self._data

def __exit__(self, exc_type, exc_val, exc_tb):
self._lock.release()

def commit(self) -> None:
if self.file is None:
raise ValueError("Cannot save a block file when block-file not set")
self.file.write_text(json.dumps(asdict(self.data), default=str, indent=4))
def _save(self) -> None:
"""
Save data to blocklist
This function assumes self._file is not None
"""
if self._file is None:
self._log.critical("_save called when blocklist file is not set; changes will not persist")
return
try:
self._log.info("Saving blocklist: %s", self._file)
with self as data:
self._file.write_text(json.dumps(asdict(data), default=str, indent=4))
except OSError:
self._log.exception("Failed to save blocklist %s", self._file)

def _notate(self) -> None:
"""
Log the blocked route (should be called by __call__)
"""
ip = remote_addr()
pth = request.path
self._log.info("Blocking IP %s based on route: %s", ip, pth)
with self as data:
if ip not in data.stats:
data.stats[ip] = []
data.stats[ip].append([str(datetime.now()), pth])

def __call__(self) -> bool:
if self.file is None:
return False
"""
:return: True if the given request should be blocked
"""
ip = request.headers.get("X-Forwarded-For", request.remote_addr)
if ip in self.data.whitelist:
return False
if ip in self.data.ips:
return True
pth = request.path
if any(fnmatch(pth, i) for i in self.data.routes):
self._lg.info("Blocking IP %s based on route: %s", ip, pth)
self.data.ips.append(ip) # type: ignore
self.commit()
return True
with self as data:
if ip in self._data.whitelist:
return False
if ip in self._data.ips:
self._notate()
return True
pth = request.path
if any(fnmatch(pth, i) for i in data.routes):
data.ips.append(ip) # type: ignore
self._notate()
return True
return False
2 changes: 1 addition & 1 deletion rpipe/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def cli() -> None:
)
parser.add_argument("port", type=int, help="The port waitress will listen on")
parser.add_argument("--host", default="0.0.0.0", help="The host waitress will bind to for listening")
parser.add_argument("--block-file", type=Path, help="A json of IP addresses and routes to ban")
parser.add_argument("-b", "--blocklist", type=Path, help="The blocklist configuration file")
parser.add_argument("-s", "--state-file", type=Path, help="The save state file, if desired")
parser.add_argument(
"-k",
Expand Down

0 comments on commit aa53f51

Please sign in to comment.