diff --git a/lib/hidapi/hidapi_impl.py b/lib/hidapi/hidapi_impl.py index 062e47e7d9..a6fe80616c 100644 --- a/lib/hidapi/hidapi_impl.py +++ b/lib/hidapi/hidapi_impl.py @@ -206,15 +206,18 @@ def __init__(self, device_callback, polling_delay=5.0): self.device_callback = device_callback self.polling_delay = polling_delay self.prev_devices = None + self.alive = False + self.abort_triggered = False # daemon threads are automatically killed when main thread exits super().__init__(daemon=True) def run(self): + self.alive = True # Populate initial set of devices so startup doesn't cause any callbacks self.prev_devices = {tuple(dev.items()): dev for dev in _enumerate_devices()} # Continously enumerate devices and raise callback for changes - while True: + while not self.abort_triggered: current_devices = {tuple(dev.items()): dev for dev in _enumerate_devices()} for key, device in self.prev_devices.items(): if key not in current_devices: @@ -225,6 +228,11 @@ def run(self): self.prev_devices = current_devices sleep(self.polling_delay) + self.alive = False + + def stop(self): + self.abort_triggered = True + def _match( action: str, diff --git a/lib/logitech_receiver/listener.py b/lib/logitech_receiver/listener.py index 44c7487e4f..4641a679b4 100644 --- a/lib/logitech_receiver/listener.py +++ b/lib/logitech_receiver/listener.py @@ -19,25 +19,42 @@ import queue import threading -from . import base +from typing import Any +from typing import Protocol + from . import exceptions logger = logging.getLogger(__name__) +class LowLevelInterface(Protocol): + def open_path(self, path): + ... + + def ping(self, handle, number, long_message=False): + ... + + def make_notification(self, report_id: int, devnumber: int, data: bytes) -> Any: + ... + + def close(self, handle): + ... + + class _ThreadedHandle: """A thread-local wrapper with different open handles for each thread. Closing a ThreadedHandle will close all handles. """ - __slots__ = ("path", "_local", "_handles", "_listener") + __slots__ = ("path", "_local", "_handles", "_listener", "_base") - def __init__(self, listener, path, handle): + def __init__(self, listener, path, handle, low_level_api: LowLevelInterface): assert listener is not None assert path is not None assert handle is not None assert isinstance(handle, int) + self._base = low_level_api self._listener = listener self.path = path self._local = threading.local() @@ -46,7 +63,7 @@ def __init__(self, listener, path, handle): self._handles = [handle] def _open(self): - handle = base.open_path(self.path) + handle = self._base.open_path(self.path) if handle is None: logger.error("%r failed to open new handle", self) else: @@ -63,7 +80,7 @@ def close(self): if logger.isEnabledFor(logging.DEBUG): logger.debug("%r closing %s", self, handles) for h in handles: - base.close(h) + self._base.close(h) @property def notifications_hook(self): @@ -112,12 +129,13 @@ class EventsListener(threading.Thread): Incoming packets will be passed to the callback function in sequence. """ - def __init__(self, receiver, notifications_callback): + def __init__(self, receiver, notifications_callback, low_level: LowLevelInterface): try: path_name = receiver.path.split("/")[2] except IndexError: path_name = receiver.path super().__init__(name=self.__class__.__name__ + ":" + path_name) + self._base = low_level self.daemon = True self._active = False self.receiver = receiver @@ -127,7 +145,7 @@ def __init__(self, receiver, notifications_callback): def run(self): self._active = True # replace the handle with a threaded one - self.receiver.handle = _ThreadedHandle(self, self.receiver.path, self.receiver.handle) + self.receiver.handle = _ThreadedHandle(self, self.receiver.path, self.receiver.handle, self._base) if logger.isEnabledFor(logging.INFO): logger.info("started with %s (%d)", self.receiver, int(self.receiver.handle)) self.has_started() @@ -139,13 +157,13 @@ def run(self): while self._active: if self._queued_notifications.empty(): try: - n = base.read(self.receiver.handle, _EVENT_READ_TIMEOUT) + n = self._base.read(self.receiver.handle, _EVENT_READ_TIMEOUT) except exceptions.NoReceiver: logger.warning("%s disconnected", self.receiver.name) self.receiver.close() break if n: - n = base.make_notification(*n) + n = self._base.make_notification(*n) else: n = self._queued_notifications.get() # deliver any queued notifications if n: diff --git a/lib/solaar/listener.py b/lib/solaar/listener.py index 76a4b1afde..f18c8336bf 100644 --- a/lib/solaar/listener.py +++ b/lib/solaar/listener.py @@ -22,6 +22,8 @@ from collections import namedtuple from functools import partial +from typing import Any +from typing import Protocol import gi import logitech_receiver @@ -58,12 +60,26 @@ def _ghost(device): ) +class LowLevelInterface(Protocol): + def open_path(self, path): + ... + + def ping(self, handle, number, long_message=False): + ... + + def make_notification(self, report_id: int, devnumber: int, data: bytes) -> Any: + ... + + def close(self, handle): + ... + + class SolaarListener(listener.EventsListener): """Keeps the status of a Receiver or Device (member name is receiver but it can also be a device).""" - def __init__(self, receiver, status_changed_callback): + def __init__(self, receiver, status_changed_callback, low_level): assert status_changed_callback - super().__init__(receiver, self._notifications_handler) + super().__init__(receiver, self._notifications_handler, low_level) self.status_changed_callback = status_changed_callback receiver.status_callback = self._status_changed @@ -275,7 +291,7 @@ def _start(device_info): receiver_.cleanups.append(_cleanup_bluez_dbus) if receiver_: - rl = SolaarListener(receiver_, _status_callback) + rl = SolaarListener(receiver_, _status_callback, base) rl.start() _all_listeners[device_info.path] = rl return rl diff --git a/tests/integrationtests/test_device_monitor.py b/tests/integrationtests/test_device_monitor.py index 699fd6c9f1..0b19f45d5a 100644 --- a/tests/integrationtests/test_device_monitor.py +++ b/tests/integrationtests/test_device_monitor.py @@ -9,17 +9,17 @@ def test_device_monitor(mocker): from hidapi.hidapi_impl import DeviceMonitor mock_callback = mocker.Mock() - monitor = DeviceMonitor(device_callback=mock_callback, polling_delay=1) + monitor = DeviceMonitor(device_callback=mock_callback, polling_delay=0.1) monitor.start() - while not monitor.is_alive(): - time.sleep(0.1) + while not monitor.alive: + time.sleep(0.01) assert monitor.alive monitor.stop() - while monitor.is_alive(): - time.sleep(0.1) + while monitor.alive: + time.sleep(0.01) assert not monitor.alive diff --git a/tests/integrationtests/test_events_listener.py b/tests/integrationtests/test_events_listener.py index 8c08c85ae6..4e68770e02 100644 --- a/tests/integrationtests/test_events_listener.py +++ b/tests/integrationtests/test_events_listener.py @@ -3,14 +3,14 @@ def test_events_listener(mocker): receiver = mocker.MagicMock() + receiver.handle = 1 + receiver.path = "pathname" status_callback = mocker.MagicMock() + low_level_mock = mocker.MagicMock() - e = EventsListener(receiver, status_callback) + e = EventsListener(receiver, status_callback, low_level_mock) e.start() assert bool(e) e.stop() - - assert not bool(e) - assert status_callback.call_count == 0 diff --git a/tests/integrationtests/test_solaar_listener.py b/tests/integrationtests/test_solaar_listener.py index 166d601b07..e9e415ffe9 100644 --- a/tests/integrationtests/test_solaar_listener.py +++ b/tests/integrationtests/test_solaar_listener.py @@ -1,15 +1,19 @@ from solaar.listener import SolaarListener +# @pytest.mark.skip(reason="Unstable") def test_solaar_listener(mocker): receiver = mocker.MagicMock() - receiver.handle = 1 + receiver.handle = mocker.MagicMock() receiver.path = "dsda" status_callback = mocker.MagicMock() + low_level_mock = mocker.MagicMock() - rl = SolaarListener(receiver, status_callback) - # rl.run() - # rl.stop() + rl = SolaarListener(receiver, status_callback, low_level_mock) + rl.start() + rl.stop() + + rl.join() assert not rl.is_alive() assert status_callback.call_count == 0 diff --git a/tests/integrationtests/test_task_runner.py b/tests/integrationtests/test_task_runner.py index 0ff7c3e894..d36e3f850f 100644 --- a/tests/integrationtests/test_task_runner.py +++ b/tests/integrationtests/test_task_runner.py @@ -7,9 +7,10 @@ def run_task(): def test_task_runner(mocker): tr = tasks.TaskRunner(name="Testrunner") - tr.queue.put((run_task, {}, {})) - # tr.run() - # tr.stop() - # assert tr.alive - # tr.stop() - # assert not tr.alive + tr.start() + assert tr.alive + + tr(run_task) + + tr.stop() + assert not tr.alive diff --git a/tests/unittests/logitech_receiver/test_notifications.py b/tests/unittests/logitech_receiver/test_notifications.py index 9cc8dd3338..426f42fb0a 100644 --- a/tests/unittests/logitech_receiver/test_notifications.py +++ b/tests/unittests/logitech_receiver/test_notifications.py @@ -22,6 +22,9 @@ def ping(self, handle, number, long_message=False): def request(self, handle, devnumber, request_id, *params, **kwargs): pass + def close(self): + pass + @pytest.mark.parametrize( "sub_id, notification_data, expected_error, expected_new_device",