From 4b37b428cdb6ca1d0362fdaa4915b90644b0af40 Mon Sep 17 00:00:00 2001 From: Gary van der Merwe Date: Tue, 28 May 2024 10:54:27 +0200 Subject: [PATCH] Refactor unmarshaller: * This breaks up the unmarshaller into separate classes for better state encapsulation. * Abstract the align maths to a common method. * Simplify read_header_fields by reusing read_struct --- dbus_ezy/_private/unmarshaller.py | 565 +++++++++++++++--------------- dbus_ezy/aio/message_bus.py | 5 +- dbus_ezy/glib/message_bus.py | 2 +- dbus_ezy/message_bus.py | 4 +- test/test_fd_passing.py | 81 +++-- test/test_marshaller.py | 33 +- 6 files changed, 342 insertions(+), 348 deletions(-) diff --git a/dbus_ezy/_private/unmarshaller.py b/dbus_ezy/_private/unmarshaller.py index cb30060..ff0fa9f 100644 --- a/dbus_ezy/_private/unmarshaller.py +++ b/dbus_ezy/_private/unmarshaller.py @@ -1,9 +1,10 @@ -import array +import codecs import io import socket -import sys +from dataclasses import dataclass +from functools import partial from struct import Struct -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple from ..constants import MessageFlag, MessageType from ..errors import InvalidMessageError @@ -16,323 +17,327 @@ HeaderField, ) -MAX_UNIX_FDS = 16 +# This is unfortunately slower than the version from dbus-next. The reason why version +# from dbus-next is faster is it inlines the align, read_byte, and read_range code, +# hence reducing function calls. For now, I prefer the code reuse, as it makes the +# code more readable. +# +# Potential solutions to get back the performance: +# * Make the BodyReader methods, and the read_xxx methods cython methods. +# * Compile signature into inline python code -UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"} -UNPACK_LENGTHS = {BIG_ENDIAN: Struct(">III"), LITTLE_ENDIAN: Struct(" None | Message: + """Unmarshall the message. -READER_TYPE = Dict[ - str, - Tuple[ - Optional[Callable[["Unmarshaller", Signature], Any]], - Optional[str], - Optional[int], - Optional[Struct], - ], -] + The underlying read function will raise BlockingIOError + if there are not enough bytes in the buffer. This allows unmarshall + to be resumed when more data comes in over the wire. + """ + try: + self.message = None + + if self.header is None: + header_buffer, unix_fds = self.read(HEADER_SIGNATURE_SIZE) + self.unix_fds.extend(unix_fds) + if header_buffer is None: + return None + self.header = read_header(header_buffer) + + body_buffer, unix_fds = self.read(self.header.msg_len) + self.unix_fds.extend(unix_fds) + if body_buffer is None: + return None + self.message = read_body(body_buffer, self.header, self.unix_fds) + except BlockingIOError: + # print("BlockingIOError") + return None + else: + self.header = None + self.unix_fds = [] + return self.message -class MarshallerStreamEndError(Exception): - """This exception is raised when the end of the stream is reached. +def read_stream(stream: io.RawIOBase, size: int) -> Tuple[bytes, Iterable[int]]: + data = stream.read(size) + if data == b"": + raise EOFError() + return data, () - This means more data is expected on the wire that has not yet been - received. The caller should call unmarshall later when more data is - available. - """ - pass +MAX_UNIX_FDS = 16 +FD_STRUCT = Struct("I") # What about endian +FD_CMSG_LEN = socket.CMSG_LEN(MAX_UNIX_FDS * FD_STRUCT.size) -# -# Alignment padding is handled with the following formula below -# -# For any align value, the correct padding formula is: -# -# (align - (offset % align)) % align -# -# However, if align is a power of 2 (always the case here), the slow MOD -# operator can be replaced by a bitwise AND: -# -# (align - (offset & (align - 1))) & (align - 1) -# -# Which can be simplified to: -# -# (-offset) & (align - 1) -# -# -class Unmarshaller: - buf: bytearray - view: memoryview - message: Message - unpack: Dict[str, Struct] - readers: READER_TYPE +class SocketReader: + # This basically does what socket.SocketIO + BufferedReader does, but: + # 1. It won't return unless the full requested size is received + # 2. It handles receiving unix fd's - def __init__(self, stream: io.BufferedRWPair, sock=None): - self.unix_fds: List[int] = [] - self.can_cast = False - self.buf = bytearray() # Actual buffer - self.view = None # Memory view of the buffer - self.offset = 0 - self.stream = stream + def __init__(self, sock: socket.socket): self.sock = sock - self.message = None - self.readers = None - self.body_len: int | None = None - self.serial: int | None = None - self.header_len: int | None = None - self.message_type: MessageType | None = None - self.flag: MessageFlag | None = None - - def read_sock(self, length: int) -> bytes: - """reads from the socket, storing any fds sent and handling errors - from the read itself""" - unix_fd_list = array.array("i") + self.buffer: bytearray = None + self.view: memoryview = None + self.unix_fds: List[int] = None + self.recv_size: int = 0 + socket.recv_fds + + def __call__(self, size: int): + if self.buffer is None: + self.buffer = bytearray(size) + self.view = memoryview(self.buffer) + self.unix_fds = list() + self.recv_size = 0 + else: + # Due to the way this get's called by unmarshall, the buffer len should always equal the size + assert len(self.buffer) == size - try: - msg, ancdata, *_ = self.sock.recvmsg( - length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize) + while self.recv_size < size: + recv_size, ancdata, *_ = self.sock.recvmsg_into( + (self.view[self.recv_size :],), FD_CMSG_LEN ) - except BlockingIOError: - raise MarshallerStreamEndError() + if recv_size == 0: + raise EOFError() + for level, type_, data in ancdata: + if level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS: + for fd_item in FD_STRUCT.iter_unpack(data): + self.unix_fds.append(fd_item[0]) + self.recv_size += recv_size + else: + ret = self.buffer, self.unix_fds + self.buffer = None + self.view = None + self.unix_fds = None + return ret - for level, type_, data in ancdata: - if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS): - continue - unix_fd_list.frombytes(data[: len(data) - (len(data) % unix_fd_list.itemsize)]) - self.unix_fds.extend(list(unix_fd_list)) - return msg +HEADER_SIGNATURE_SIZE = 16 +HEADER_FIELD_SIGNATURE = parse_single_type("(yv)") +HEADER_UNPACK_LENGTHS = {BIG_ENDIAN: Struct(">III"), LITTLE_ENDIAN: Struct(" None: - """ - Read from underlying socket into buffer. +UTF_8 = codecs.lookup("utf-8") +ASCII = codecs.lookup("ascii") - Raises MarshallerStreamEndError if there is not enough data to be read. +UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"} +STRUCT_BY_ENDIAN_DBUS_TYPE: Dict[Tuple[int, str], Struct] = { + (endian, dbus_type): Struct(f"{UNPACK_SYMBOL[endian]}{ctype}") + for endian in (BIG_ENDIAN, LITTLE_ENDIAN) + for dbus_type, ctype in ( + ("n", "h"), # int16 + ("q", "H"), # uint16 + ("i", "i"), # int32 + ("u", "I"), # uint32 + ("x", "q"), # int64 + ("t", "Q"), # uint64 + ("d", "d"), # double + ("h", "I"), # uint32 + ) +} - :arg offset: - The offset to read to. If not enough bytes are available in the - buffer, read more from it. - :returns: - None - """ - start_len = len(self.buf) - missing_bytes = offset - (start_len - self.offset) - if self.sock is None: - data = self.stream.read(missing_bytes) - else: - data = self.read_sock(missing_bytes) - if data == b"": - raise EOFError() - if data is None: - raise MarshallerStreamEndError() - self.buf.extend(data) - if len(data) + start_len != offset: - raise MarshallerStreamEndError() - - def read_boolean(self, _=None): - return bool(self.read_argument(UINT32_SIGNATURE)) - - def read_string(self, _=None): - str_length = self.read_argument(UINT32_SIGNATURE) - str_start = self.offset - # read terminating '\0' byte as well (str_length + 1) - self.offset += str_length + 1 - return self.buf[str_start : str_start + str_length].decode() +@dataclass(slots=True, init=False) +class Header: + endian: int + message_type: MessageType + flag: MessageFlag + protocol_version: int + body_len: int + serial: int + header_len: int + msg_len: int + + +def read_header(buffer: bytes): + """Read the header of the message.""" + + # Signature is of the header is + # BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT) + header = Header() + header.endian = buffer[0] + header.message_type = MessageType(buffer[1]) + header.flag = MessageFlag(buffer[2]) + header.protocol_version = buffer[3] + + if header.endian != LITTLE_ENDIAN and header.endian != BIG_ENDIAN: + raise InvalidMessageError( + f"Expecting endianness as the first byte, got {header.endian} from {buffer}" + ) + if header.protocol_version != PROTOCOL_VERSION: + raise InvalidMessageError(f"got unknown protocol version: {header.protocol_version}") + + header.body_len, header.serial, header.header_len = HEADER_UNPACK_LENGTHS[ + header.endian + ].unpack_from(buffer, 4) + header.msg_len = header.header_len + (-header.header_len & 7) + header.body_len # align 8 + return header + + +def read_body(buffer: bytes, header: Header, unix_fds: Sequence[int]): + """Read the body of the message.""" + body_reader = BodyReader(buffer, header.endian) + header_fields = dict(body_reader.read_header_fields(header.header_len)) + signature = parse_signature(header_fields.get(HeaderField.SIGNATURE, "")) + body_reader.align(8) + body = [body_reader.read_item(t) for t in signature.children] if header.body_len else [] + + return Message( + destination=header_fields.get(HeaderField.DESTINATION), + path=header_fields.get(HeaderField.PATH), + interface=header_fields.get(HeaderField.INTERFACE), + member=header_fields.get(HeaderField.MEMBER), + message_type=header.message_type, + flags=header.flag, + error_name=header_fields.get(HeaderField.ERROR_NAME), + reply_serial=header_fields.get(HeaderField.REPLY_SERIAL), + sender=header_fields.get(HeaderField.SENDER), + unix_fds=unix_fds, + signature=signature, + body=body, + serial=header.serial, + ) + + +class BodyReader: + slots = ["buffer", "offset", "endian"] + + def __init__(self, buffer: bytes, endian: int): + self.buffer = memoryview(buffer) + self.endian = endian + self.offset = 0 + + def align(self, align: int): + # Alignment padding is handled with the following formula below + # + # For any align value, the correct padding formula is: + # + # (align - (offset % align)) % align + # + # However, if align is a power of 2 (always the case here), the slow MOD + # operator can be replaced by a bitwise AND: + # + # (align - (offset & (align - 1))) & (align - 1) + # + # Which can be simplified to: + # + # -offset & (align - 1) + self.offset += -self.offset & (align - 1) + + def read_range(self, size: int): + start = self.offset + self.offset = self.offset + size + ret = self.buffer[start : self.offset] + # print(f"read_range {start=:02x} {self.offset=:02x} {ret=}") + return ret + + def read_byte(self, _=None): + ret = self.buffer[self.offset] + # print(f"read_range {self.offset=:02x} {ret=}") + self.offset += 1 + return ret + + def read_header_fields(self, header_len: int): + # Header fields are always a(yv) + while self.offset < header_len - 1: + field_id, field_value = self.read_struct(HEADER_FIELD_SIGNATURE) + yield HeaderField(field_id), field_value.value + + def read_uint32(self) -> int: + struct = STRUCT_BY_ENDIAN_DBUS_TYPE[(self.endian, "u")] + self.align(struct.size) + buffer = self.read_range(struct.size) + return struct.unpack_from(buffer)[0] + + def read_item(self, signature: Signature) -> Any: + """Dispatch to an argument reader or cast/unpack a C type.""" + type_code = signature.type_code + + if ctype_struct := STRUCT_BY_ENDIAN_DBUS_TYPE.get((self.endian, type_code)): + self.align(ctype_struct.size) + buffer = self.read_range(ctype_struct.size) + return ctype_struct.unpack_from(buffer)[0] + + if complex_reader := self.COMPLEX_PARSERS.get(type_code): + return complex_reader(self, signature) + + def read_boolean(self, _): + return bool(self.read_uint32()) + + def read_string(self, _): + string_length = self.read_uint32() + bytes_ = self.read_range(string_length) + # Check for the terminating '\0' + assert self.read_byte() == 0 + return UTF_8.decode(bytes_)[0] def read_signature(self, _=None): - signature_len = self.view[self.offset] # byte - o = self.offset + 1 - # read terminating '\0' byte as well (str_length + 1) - self.offset = o + signature_len + 1 - return self.buf[o : o + signature_len].decode() + signature_len = self.read_byte() + bytes_ = self.read_range(signature_len) + # Check for the terminating '\0' + assert self.read_byte() == 0 + return ASCII.decode(bytes_)[0] def read_variant(self, _=None): signature = parse_single_type(self.read_signature()) # verify in Variant is only useful on construction not unmarshalling - return Variant(signature, self.read_argument(signature), verify=False) + return Variant(signature, self.read_item(signature), verify=False) def read_struct(self, signature: Signature): - self.offset += -self.offset & 7 # align 8 - return [self.read_argument(child_type) for child_type in signature.children] - - def read_dict_entry(self, signature: Signature): - self.offset += -self.offset & 7 # align 8 - return self.read_argument(signature.children[0]), self.read_argument(signature.children[1]) + self.align(8) + return [self.read_item(child_type) for child_type in signature.children] def read_array(self, signature: Signature): - self.offset += -self.offset & 3 # align 4 for the array - array_length = self.read_argument(UINT32_SIGNATURE) + array_length = self.read_uint32() - child_type = signature.children[0] - if child_type.type_code in "xtd{(": - # the first alignment is not included in the array size - self.offset += -self.offset & 7 # align 8 + child_signature = signature.children[0] - if child_type.type_code == "y": - self.offset += array_length - return self.buf[self.offset - array_length : self.offset] + if child_signature.type_code == "y": + return self.read_range(array_length).tobytes() - beginning_offset = self.offset + if child_signature.type_code in "xtd{(": + # the first alignment is not included in the array size, so align before + # calculating stop_offset + self.align(8) - if child_type.type_code == "{": + stop_offset = self.offset + array_length - 1 + + if child_signature.type_code == "{": result_dict = {} - while self.offset - beginning_offset < array_length: - key, value = self.read_dict_entry(child_type) + while self.offset < stop_offset: + self.align(8) + key = self.read_item(child_signature.children[0]) + value = self.read_item(child_signature.children[1]) result_dict[key] = value return result_dict result_list = [] - while self.offset - beginning_offset < array_length: - result_list.append(self.read_argument(child_type)) + while self.offset < stop_offset: + result_list.append(self.read_item(child_signature)) return result_list - def read_argument(self, signature: Signature) -> Any: - """Dispatch to an argument reader or cast/unpack a C type.""" - type_code = signature.type_code - reader, ctype, size, struct = self.readers[type_code] - if reader: # complex type - return reader(self, signature) - self.offset += size + (-self.offset & (size - 1)) # align - if self.can_cast: - return self.view[self.offset - size : self.offset].cast(ctype)[0] - return struct.unpack_from(self.view, self.offset - size)[0] - - def header_fields(self, header_length): - """Header fields are always a(yv).""" - beginning_offset = self.offset - headers = {} - while self.offset - beginning_offset < header_length: - # Now read the y (byte) of struct (yv) - self.offset += (-self.offset & 7) + 1 # align 8 + 1 for 'y' byte - field_0 = self.view[self.offset - 1] - - # Now read the v (variant) of struct (yv) - signature_len = self.view[self.offset] # byte - o = self.offset + 1 - self.offset += signature_len + 2 # one for the byte, one for the '\0' - signature = parse_single_type(self.buf[o : o + signature_len].decode()) - headers[HeaderField(field_0).name] = self.read_argument(signature) - return headers - - def _read_header(self): - """Read the header of the message.""" - # Signature is of the header is - # BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT) - self.read_to_offset(HEADER_SIGNATURE_SIZE) - buffer = self.buf - endian = buffer[0] - self.message_type = MessageType(buffer[1]) - self.flag = MessageFlag(buffer[2]) - protocol_version = buffer[3] - - if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN: - raise InvalidMessageError( - f"Expecting endianness as the first byte, got {endian} from {buffer}" - ) - if protocol_version != PROTOCOL_VERSION: - raise InvalidMessageError(f"got unknown protocol version: {protocol_version}") - - self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[endian].unpack_from(buffer, 4) - self.msg_len = self.header_len + (-self.header_len & 7) + self.body_len # align 8 - if IS_BIG_ENDIAN and endian == BIG_ENDIAN: - self.can_cast = True - elif IS_LITTLE_ENDIAN and endian == LITTLE_ENDIAN: - self.can_cast = True - self.readers = self._readers_by_type[endian] - - def _read_body(self): - """Read the body of the message.""" - self.read_to_offset(HEADER_SIGNATURE_SIZE + self.msg_len) - self.view = memoryview(self.buf) - self.offset = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION - header_fields = self.header_fields(self.header_len) - self.offset += -self.offset & 7 # align 8 - signature = parse_signature(header_fields.get(HeaderField.SIGNATURE.name, "")) - self.message = Message( - destination=header_fields.get(HEADER_DESTINATION), - path=header_fields.get(HEADER_PATH), - interface=header_fields.get(HEADER_INTERFACE), - member=header_fields.get(HEADER_MEMBER), - message_type=self.message_type, - flags=self.flag, - error_name=header_fields.get(HEADER_ERROR_NAME), - reply_serial=header_fields.get(HEADER_REPLY_SERIAL), - sender=header_fields.get(HEADER_SENDER), - unix_fds=self.unix_fds, - signature=signature, - body=[self.read_argument(t) for t in signature.children] if self.body_len else [], - serial=self.serial, - ) - - def unmarshall(self): - """Unmarshall the message. - - The underlying read function will raise MarshallerStreamEndError - if there are not enough bytes in the buffer. This allows unmarshall - to be resumed when more data comes in over the wire. - """ - try: - if not self.message_type: - self._read_header() - self._read_body() - except MarshallerStreamEndError: - return None - return self.message - - _complex_parsers: Dict[ - str, Tuple[Callable[["Unmarshaller", Signature], Any], None, None, None] - ] = { - "b": (read_boolean, None, None, None), - "o": (read_string, None, None, None), - "s": (read_string, None, None, None), - "g": (read_signature, None, None, None), - "a": (read_array, None, None, None), - "(": (read_struct, None, None, None), - "{": (read_dict_entry, None, None, None), - "v": (read_variant, None, None, None), - } - - _ctype_by_endian: Dict[int, Dict[str, Tuple[None, str, int, Struct]]] = { - endian: { - dbus_type: ( - None, - *ctype_size, - Struct(f"{UNPACK_SYMBOL[endian]}{ctype_size[0]}"), - ) - for dbus_type, ctype_size in DBUS_TO_CTYPE.items() - } - for endian in (BIG_ENDIAN, LITTLE_ENDIAN) - } - - _readers_by_type: Dict[int, READER_TYPE] = { - BIG_ENDIAN: {**_ctype_by_endian[BIG_ENDIAN], **_complex_parsers}, - LITTLE_ENDIAN: {**_ctype_by_endian[LITTLE_ENDIAN], **_complex_parsers}, + COMPLEX_PARSERS: Dict[str, Callable[["BodyReader", Signature], Any]] = { + "y": read_byte, + "b": read_boolean, + "o": read_string, + "s": read_string, + "g": read_signature, + "a": read_array, + "(": read_struct, + "v": read_variant, } diff --git a/dbus_ezy/aio/message_bus.py b/dbus_ezy/aio/message_bus.py index 131941e..557c925 100644 --- a/dbus_ezy/aio/message_bus.py +++ b/dbus_ezy/aio/message_bus.py @@ -428,10 +428,7 @@ async def _authenticate(self): break def _create_unmarshaller(self): - sock = None - if self._negotiate_unix_fd: - sock = self._sock - return Unmarshaller(self._stream, sock) + return Unmarshaller(self._sock) def _finalize(self, err=None): try: diff --git a/dbus_ezy/glib/message_bus.py b/dbus_ezy/glib/message_bus.py index 4a5333d..8f5edd3 100644 --- a/dbus_ezy/glib/message_bus.py +++ b/dbus_ezy/glib/message_bus.py @@ -45,7 +45,7 @@ def dispatch(self, callback, user_data): try: while self.bus._stream.readable(): if not self.unmarshaller: - self.unmarshaller = Unmarshaller(self.bus._stream) + self.unmarshaller = Unmarshaller(self.bus._sock) if self.unmarshaller.unmarshall(): callback(self.unmarshaller.message) diff --git a/dbus_ezy/message_bus.py b/dbus_ezy/message_bus.py index 86da9f7..3f73dbc 100644 --- a/dbus_ezy/message_bus.py +++ b/dbus_ezy/message_bus.py @@ -633,7 +633,7 @@ def _setup_socket(self): try: self._sock.connect(filename) - self._sock.setblocking(False) + self._sock.settimeout(0) break except Exception as e: err = e @@ -650,7 +650,7 @@ def _setup_socket(self): try: self._sock.connect((ip_addr, ip_port)) - self._sock.setblocking(False) + self._sock.settimeout(0) break except Exception as e: err = e diff --git a/test/test_fd_passing.py b/test/test_fd_passing.py index d22e913..df56c63 100644 --- a/test/test_fd_passing.py +++ b/test/test_fd_passing.py @@ -1,7 +1,8 @@ """This tests the ability to send and receive file descriptors in dbus messages""" import os -from asyncio import get_event_loop +from asyncio import Future, get_event_loop +from contextlib import AsyncExitStack, contextmanager import pytest @@ -70,50 +71,48 @@ def assert_fds_equal(fd1, fd2): assert stat1.st_rdev == stat2.st_rdev +@contextmanager +def fd_closer(fd): + try: + yield fd + finally: + os.close(fd) + + @pytest.mark.asyncio async def test_sending_file_descriptor_low_level(): - bus1 = await MessageBus(negotiate_unix_fd=True).connect() - bus2 = await MessageBus(negotiate_unix_fd=True).connect() + async with AsyncExitStack() as stack: + sender_bus = await stack.enter_async_context(MessageBus(negotiate_unix_fd=True)) + reviver_bus = await stack.enter_async_context(MessageBus(negotiate_unix_fd=True)) + + sender_fd = stack.enter_context(fd_closer(open_file())) + sender_msg = Message( + destination=sender_bus.unique_name, + path="/org/test/path", + interface="org.test.iface", + member="SomeMember", + body=[0], + signature="h", + unix_fds=[sender_fd], + ) - fd_before = open_file() - fd_after = None - - msg = Message( - destination=bus1.unique_name, - path="/org/test/path", - interface="org.test.iface", - member="SomeMember", - body=[0], - signature="h", - unix_fds=[fd_before], - ) + receiver_msg_fut = Future() - def message_handler(sent): - nonlocal fd_after - if sent.sender == bus2.unique_name and sent.serial == msg.serial: - assert sent.path == msg.path - assert sent.serial == msg.serial - assert sent.interface == msg.interface - assert sent.member == msg.member - assert sent.body == [0] - assert len(sent.unix_fds) == 1 - fd_after = sent.unix_fds[0] - bus1.send(Message.new_method_return(sent, "s", ["got it"])) - bus1.remove_message_handler(message_handler) - return True - - bus1.add_message_handler(message_handler) - - reply = await bus2.call(msg) - assert reply.body == ["got it"] - assert fd_after is not None - - assert_fds_equal(fd_before, fd_after) - - for fd in [fd_before, fd_after]: - os.close(fd) - for bus in [bus1, bus2]: - bus.disconnect() + def message_handler(msg): + nonlocal receiver_msg_fut + if msg.sender == reviver_bus.unique_name and msg.serial == sender_msg.serial: + receiver_msg_fut.set_result(msg) + return True + + sender_bus.add_message_handler(message_handler) + + await reviver_bus.send(sender_msg) + receiver_msg = await receiver_msg_fut + + assert len(receiver_msg.unix_fds) == 1 + receiver_fd = stack.enter_context(fd_closer(receiver_msg.unix_fds[0])) + + assert_fds_equal(sender_fd, receiver_fd) @pytest.mark.asyncio diff --git a/test/test_marshaller.py b/test/test_marshaller.py index ca3089a..0265166 100644 --- a/test/test_marshaller.py +++ b/test/test_marshaller.py @@ -3,6 +3,7 @@ import os from dataclasses import dataclass from pprint import pprint +from socket import socketpair import pytest @@ -114,7 +115,7 @@ def test_unmarshall(item: MessageExample): def test_unmarshall_can_resume(): """Verify resume works.""" - bluez_rssi_message = ( + bluez_rssi_message = bytes.fromhex( "6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465" "765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b" "746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e" @@ -122,27 +123,19 @@ def test_unmarshall_can_resume(): "110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001" "6e00a7ff000000000000" ) - message_bytes = bytes.fromhex(bluez_rssi_message) + chunks = [bluez_rssi_message[i : i + 4] for i in range(0, len(bluez_rssi_message), 4)] - class SlowStream(io.IOBase): - """A fake stream that will only give us one byte at a time.""" + send_sock, recv_sock = socketpair() + recv_sock.settimeout(0) + unmarshaller = Unmarshaller(recv_sock) + for chunk in chunks[:-1]: + send_sock.send(chunk) + message = unmarshaller.unmarshall() + assert message is None - def __init__(self): - self.data = message_bytes - self.pos = 0 - - def read(self, n) -> bytes: - data = self.data[self.pos : self.pos + 1] - self.pos += 1 - return data - - stream = SlowStream() - unmarshaller = Unmarshaller(stream) - - for _ in range(len(bluez_rssi_message)): - if unmarshaller.unmarshall(): - break - assert unmarshaller.message is not None + send_sock.send(chunks[-1]) + message = unmarshaller.unmarshall() + assert message is not None def test_ay_buffer():