From 96bcdc61b66b6ad6b8c52641dc965b2ef4576b1f Mon Sep 17 00:00:00 2001 From: zwimer Date: Sat, 24 Feb 2024 10:16:33 -0500 Subject: [PATCH] Improve typing --- rpipe/client/crypt.py | 37 +++++++++++++++++++++++-------------- rpipe/version.py | 2 +- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/rpipe/client/crypt.py b/rpipe/client/crypt.py index 1b55e92..938d887 100644 --- a/rpipe/client/crypt.py +++ b/rpipe/client/crypt.py @@ -1,3 +1,4 @@ +from typing import NamedTuple, Self import hashlib import zlib @@ -8,18 +9,26 @@ _ZLIB_LEVEL: int = 6 -def _merge(*args: bytes) -> bytes: - line1 = b" ".join(str(len(i)).encode() for i in args) + b"\n" - return b"".join([line1, *args]) +class _EncryptedData(NamedTuple): + text: bytes + salt: bytes + nonce: bytes + tag: bytes + def encode(self) -> bytes: + line1 = b" ".join(str(len(i)).encode() for i in self) + b"\n" # pylint: disable=not-an-iterable + return line1 + b"".join(self) -def _split(raw: bytes) -> list[bytes]: - ret = [] - start = raw.index(b"\n") + 1 - for i in (int(k.decode()) for k in raw[: start - 1].split(b" ")): - ret.append(raw[start : start + i]) - start += i - return ret + @classmethod + def decode(cls, raw: bytes) -> Self: + parts = [] + start = raw.index(b"\n") + 1 + for i in (int(k.decode()) for k in raw[: start - 1].split(b" ")): + parts.append(raw[start : start + i]) + start += i + if len(parts) != len(cls._fields): + raise RuntimeError("Bad encrypted data") + return cls(*parts) def _opts(password: str) -> dict: @@ -32,12 +41,12 @@ def encrypt(data: bytes, password: str | None) -> bytes: salt = get_random_bytes(AES.block_size) conf = AES.new(hashlib.scrypt(salt=salt, **_opts(password)), AES.MODE_GCM) # type: ignore text, tag = conf.encrypt_and_digest(zlib.compress(data, level=_ZLIB_LEVEL)) - return _merge(text, salt, conf.nonce, tag) + return _EncryptedData(text, salt, conf.nonce, tag).encode() def decrypt(data: bytes, password: str | None) -> bytes: if not password or not data: return data - text, salt, nonce, tag = _split(data) - aes = AES.new(hashlib.scrypt(salt=salt, **_opts(password)), AES.MODE_GCM, nonce=nonce) # type: ignore - return zlib.decompress(aes.decrypt_and_verify(text, tag)) + e = _EncryptedData.decode(data) + aes = AES.new(hashlib.scrypt(salt=e.salt, **_opts(password)), AES.MODE_GCM, nonce=e.nonce) # type: ignore + return zlib.decompress(aes.decrypt_and_verify(e.text, e.tag)) diff --git a/rpipe/version.py b/rpipe/version.py index 36616e0..5681fb2 100644 --- a/rpipe/version.py +++ b/rpipe/version.py @@ -1,7 +1,7 @@ from __future__ import annotations -__version__: str = "5.5.3" # Must be "..", all numbers +__version__: str = "5.5.4" # Must be "..", all numbers class Version: