Skip to content

Commit

Permalink
refactor: add class that handles request records
Browse files Browse the repository at this point in the history
  • Loading branch information
betaboon committed Nov 25, 2024
1 parent 1d43713 commit 3b7a324
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 123 deletions.
34 changes: 24 additions & 10 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import collections
import itertools
import os
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

import mocket.inject
from mocket.recording import MocketRecordStorage

# NOTE this is here for backwards-compat to keep old import-paths working
# from mocket.socket import MocketSocket as MocketSocket
Expand All @@ -20,21 +22,28 @@ class Mocket:
_address: ClassVar[Address] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
_requests: ClassVar[list] = []
_namespace: ClassVar[str] = str(id(_entries))
_truesocket_recording_dir: ClassVar[str | None] = None
_record_storage: ClassVar[MocketRecordStorage | None] = None

@classmethod
def enable(
cls,
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError
if namespace is None:
namespace = str(id(cls._entries))

cls._namespace = namespace
cls._truesocket_recording_dir = truesocket_recording_dir
if truesocket_recording_dir is not None:
recording_dir = Path(truesocket_recording_dir)

if not recording_dir.is_dir():
# JSON dumps will be saved here
raise AssertionError

cls._record_storage = MocketRecordStorage(
directory=recording_dir,
namespace=namespace,
)

mocket.inject.enable()

Expand Down Expand Up @@ -87,6 +96,7 @@ def reset(cls) -> None:
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []
cls._record_storage = None

@classmethod
def last_request(cls):
Expand All @@ -107,12 +117,16 @@ def has_requests(cls) -> bool:
return bool(cls.request_list())

@classmethod
def get_namespace(cls) -> str:
return cls._namespace
def get_namespace(cls) -> str | None:
if not cls._record_storage:
return None
return cls._record_storage.namespace

@classmethod
def get_truesocket_recording_dir(cls) -> str | None:
return cls._truesocket_recording_dir
if not cls._record_storage:
return None
return str(cls._record_storage.directory)

@classmethod
def assert_fail_if_entries_not_served(cls) -> None:
Expand Down
172 changes: 172 additions & 0 deletions mocket/recording.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

import binascii
import contextlib
import hashlib
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

from devtools import debug

from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.types import Address

hash_function = hashlib.md5

with contextlib.suppress(ImportError):
from xxhash_cffi import xxh32 as xxhash_cffi_xxh32

hash_function = xxhash_cffi_xxh32

with contextlib.suppress(ImportError):
from xxhash import xxh32 as xxhash_xxh32

hash_function = xxhash_xxh32


def _hash_prepare_request(data: bytes) -> bytes:
_data = decode_from_bytes(data)
return encode_to_bytes("".join(sorted(_data.split("\r\n"))))


def _hash_request(data: bytes) -> str:
_data = _hash_prepare_request(data)
return hash_function(_data).hexdigest()


def _hash_request_fallback(data: bytes) -> str:
_data = _hash_prepare_request(data)
return hashlib.md5(_data).hexdigest()


def hexdump(binary_string: bytes) -> str:
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
True
"""
bs = decode_from_bytes(binascii.hexlify(binary_string).upper())
return " ".join(a + b for a, b in zip(bs[::2], bs[1::2]))


def hexload(string: str) -> bytes:
r"""
>>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo")
True
"""
string_no_spaces = "".join(string.split())
return encode_to_bytes(binascii.unhexlify(string_no_spaces))


@dataclass
class MocketRecord:
host: str
port: int
request: bytes
response: bytes


class MocketRecordStorage:
def __init__(self, directory: Path, namespace: str) -> None:
self._directory = directory
self._namespace = namespace
self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = (
defaultdict(defaultdict)
)

self._load()

@property
def directory(self) -> Path:
return self._directory

@property
def namespace(self) -> str:
return self._namespace

@property
def file(self) -> Path:
return self._directory / f"{self._namespace}.json"

def _load(self) -> None:
if not self.file.exists():
return

json_data = self.file.read_text()
records = json.loads(json_data)
for host, port_signature_record in records.items():
for port, signature_record in port_signature_record.items():
for signature, record in signature_record.items():
# NOTE backward-compat
try:
request_data = hexload(record["request"])
except binascii.Error:
request_data = record["request"]

self._records[(host, int(port))][signature] = MocketRecord(
host=host,
port=port,
request=request_data,
response=hexload(record["response"]),
)

def _save(self) -> None:
# FIXME change name and type
d = defaultdict(lambda: defaultdict(defaultdict))
for address, signature_record in self._records.items():
host, port = address
for signature, record in signature_record.items():
d[host][str(port)][signature] = dict(
request=decode_from_bytes(record.request),
response=hexdump(record.response),
)

json_data = json.dumps(d, indent=4, sort_keys=True)
self.file.parent.mkdir(exist_ok=True)
self.file.write_text(json_data)

def get_records(self, address: Address) -> list[MocketRecord]:
return list(self._records[address].values())

def get_record(self, address: Address, request: bytes) -> MocketRecord | None:
# FIXME encode should not be required
request = encode_to_bytes(request)

# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
return self._records[address].get(request_signature_fallback)

request_signature = _hash_request(request)
if request_signature in self._records[address]:
return self._records[address][request_signature]

return None

def put_record(
self,
address: Address,
request: bytes,
response: bytes,
) -> None:
# FIXME encode should not be required
request = encode_to_bytes(request)

host, port = address
record = MocketRecord(
host=host,
port=port,
request=request,
response=response,
)

# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
self._records[address][request_signature_fallback] = record
return

request_signature = _hash_request(request)
self._records[address][request_signature] = record
self._save()
Loading

0 comments on commit 3b7a324

Please sign in to comment.