Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
MattHag committed Nov 6, 2024
1 parent f2e4e8c commit 46a8fd8
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 32 deletions.
10 changes: 9 additions & 1 deletion lib/hidapi/hidapi_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,18 @@ def __init__(self, device_callback, polling_delay=5.0):
self.device_callback = device_callback
self.polling_delay = polling_delay
self.prev_devices = None
self.alive = False
self.abort_triggered = False
# daemon threads are automatically killed when main thread exits
super().__init__(daemon=True)

def run(self):
self.alive = True
# Populate initial set of devices so startup doesn't cause any callbacks
self.prev_devices = {tuple(dev.items()): dev for dev in _enumerate_devices()}

# Continously enumerate devices and raise callback for changes
while True:
while not self.abort_triggered:
current_devices = {tuple(dev.items()): dev for dev in _enumerate_devices()}
for key, device in self.prev_devices.items():
if key not in current_devices:
Expand All @@ -225,6 +228,11 @@ def run(self):
self.prev_devices = current_devices
sleep(self.polling_delay)

self.alive = False

def stop(self):
self.abort_triggered = True


def _match(
action: str,
Expand Down
36 changes: 27 additions & 9 deletions lib/logitech_receiver/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,42 @@
import queue
import threading

from . import base
from typing import Any
from typing import Protocol

from . import exceptions

logger = logging.getLogger(__name__)


class LowLevelInterface(Protocol):
def open_path(self, path):
...

def ping(self, handle, number, long_message=False):
...

def make_notification(self, report_id: int, devnumber: int, data: bytes) -> Any:
...

def close(self, handle):
...


class _ThreadedHandle:
"""A thread-local wrapper with different open handles for each thread.
Closing a ThreadedHandle will close all handles.
"""

__slots__ = ("path", "_local", "_handles", "_listener")
__slots__ = ("path", "_local", "_handles", "_listener", "_base")

def __init__(self, listener, path, handle):
def __init__(self, listener, path, handle, low_level_api: LowLevelInterface):
assert listener is not None
assert path is not None
assert handle is not None
assert isinstance(handle, int)

self._base = low_level_api
self._listener = listener
self.path = path
self._local = threading.local()
Expand All @@ -46,7 +63,7 @@ def __init__(self, listener, path, handle):
self._handles = [handle]

def _open(self):
handle = base.open_path(self.path)
handle = self._base.open_path(self.path)
if handle is None:
logger.error("%r failed to open new handle", self)
else:
Expand All @@ -63,7 +80,7 @@ def close(self):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("%r closing %s", self, handles)
for h in handles:
base.close(h)
self._base.close(h)

@property
def notifications_hook(self):
Expand Down Expand Up @@ -112,12 +129,13 @@ class EventsListener(threading.Thread):
Incoming packets will be passed to the callback function in sequence.
"""

def __init__(self, receiver, notifications_callback):
def __init__(self, receiver, notifications_callback, low_level: LowLevelInterface):
try:
path_name = receiver.path.split("/")[2]
except IndexError:
path_name = receiver.path
super().__init__(name=self.__class__.__name__ + ":" + path_name)
self._base = low_level
self.daemon = True
self._active = False
self.receiver = receiver
Expand All @@ -127,7 +145,7 @@ def __init__(self, receiver, notifications_callback):
def run(self):
self._active = True
# replace the handle with a threaded one
self.receiver.handle = _ThreadedHandle(self, self.receiver.path, self.receiver.handle)
self.receiver.handle = _ThreadedHandle(self, self.receiver.path, self.receiver.handle, self._base)
if logger.isEnabledFor(logging.INFO):
logger.info("started with %s (%d)", self.receiver, int(self.receiver.handle))
self.has_started()
Expand All @@ -139,13 +157,13 @@ def run(self):
while self._active:
if self._queued_notifications.empty():
try:
n = base.read(self.receiver.handle, _EVENT_READ_TIMEOUT)
n = self._base.read(self.receiver.handle, _EVENT_READ_TIMEOUT)
except exceptions.NoReceiver:
logger.warning("%s disconnected", self.receiver.name)
self.receiver.close()
break
if n:
n = base.make_notification(*n)
n = self._base.make_notification(*n)
else:
n = self._queued_notifications.get() # deliver any queued notifications
if n:
Expand Down
22 changes: 19 additions & 3 deletions lib/solaar/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from collections import namedtuple
from functools import partial
from typing import Any
from typing import Protocol

import gi
import logitech_receiver
Expand Down Expand Up @@ -58,12 +60,26 @@ def _ghost(device):
)


class LowLevelInterface(Protocol):
def open_path(self, path):
...

def ping(self, handle, number, long_message=False):
...

def make_notification(self, report_id: int, devnumber: int, data: bytes) -> Any:
...

def close(self, handle):
...


class SolaarListener(listener.EventsListener):
"""Keeps the status of a Receiver or Device (member name is receiver but it can also be a device)."""

def __init__(self, receiver, status_changed_callback):
def __init__(self, receiver, status_changed_callback, low_level):
assert status_changed_callback
super().__init__(receiver, self._notifications_handler)
super().__init__(receiver, self._notifications_handler, low_level)
self.status_changed_callback = status_changed_callback
receiver.status_callback = self._status_changed

Expand Down Expand Up @@ -275,7 +291,7 @@ def _start(device_info):
receiver_.cleanups.append(_cleanup_bluez_dbus)

if receiver_:
rl = SolaarListener(receiver_, _status_callback)
rl = SolaarListener(receiver_, _status_callback, base)
rl.start()
_all_listeners[device_info.path] = rl
return rl
Expand Down
10 changes: 5 additions & 5 deletions tests/integrationtests/test_device_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ def test_device_monitor(mocker):
from hidapi.hidapi_impl import DeviceMonitor

mock_callback = mocker.Mock()
monitor = DeviceMonitor(device_callback=mock_callback, polling_delay=1)
monitor = DeviceMonitor(device_callback=mock_callback, polling_delay=0.1)
monitor.start()

while not monitor.is_alive():
time.sleep(0.1)
while not monitor.alive:
time.sleep(0.01)

assert monitor.alive

monitor.stop()

while monitor.is_alive():
time.sleep(0.1)
while monitor.alive:
time.sleep(0.01)

assert not monitor.alive
8 changes: 4 additions & 4 deletions tests/integrationtests/test_events_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

def test_events_listener(mocker):
receiver = mocker.MagicMock()
receiver.handle = 1
receiver.path = "pathname"
status_callback = mocker.MagicMock()
low_level_mock = mocker.MagicMock()

e = EventsListener(receiver, status_callback)
e = EventsListener(receiver, status_callback, low_level_mock)
e.start()

assert bool(e)

e.stop()

assert not bool(e)
assert status_callback.call_count == 0
12 changes: 8 additions & 4 deletions tests/integrationtests/test_solaar_listener.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from solaar.listener import SolaarListener


# @pytest.mark.skip(reason="Unstable")
def test_solaar_listener(mocker):
receiver = mocker.MagicMock()
receiver.handle = 1
receiver.handle = mocker.MagicMock()
receiver.path = "dsda"
status_callback = mocker.MagicMock()
low_level_mock = mocker.MagicMock()

rl = SolaarListener(receiver, status_callback)
# rl.run()
# rl.stop()
rl = SolaarListener(receiver, status_callback, low_level_mock)
rl.start()
rl.stop()

rl.join()

assert not rl.is_alive()
assert status_callback.call_count == 0
13 changes: 7 additions & 6 deletions tests/integrationtests/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ def run_task():

def test_task_runner(mocker):
tr = tasks.TaskRunner(name="Testrunner")
tr.queue.put((run_task, {}, {}))
# tr.run()
# tr.stop()
# assert tr.alive
# tr.stop()
# assert not tr.alive
tr.start()
assert tr.alive

tr(run_task)

tr.stop()
assert not tr.alive
3 changes: 3 additions & 0 deletions tests/unittests/logitech_receiver/test_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def ping(self, handle, number, long_message=False):
def request(self, handle, devnumber, request_id, *params, **kwargs):
pass

def close(self):
pass


@pytest.mark.parametrize(
"sub_id, notification_data, expected_error, expected_new_device",
Expand Down

0 comments on commit 46a8fd8

Please sign in to comment.