diff --git a/lib/logitech_receiver/base.py b/lib/logitech_receiver/base.py index b7d4114127..2ed6aeaa34 100644 --- a/lib/logitech_receiver/base.py +++ b/lib/logitech_receiver/base.py @@ -17,6 +17,8 @@ # Base low-level functions used by the API proper. # Unlikely to be used directly unless you're expanding the API. +from __future__ import annotations + import logging import threading as _threading @@ -73,9 +75,10 @@ def other_device_check(bus_id, vendor_id, product_id): return _bt_device(product_id) -def product_information(usb_id): +def product_information(usb_id: int | str) -> dict: if isinstance(usb_id, str): usb_id = int(usb_id, 16) + for r in _RECEIVER_USB_IDS: if usb_id == r.get("product_id"): return r diff --git a/tests/logitech_receiver/test_base.py b/tests/logitech_receiver/test_base.py new file mode 100644 index 0000000000..1df0da8e92 --- /dev/null +++ b/tests/logitech_receiver/test_base.py @@ -0,0 +1,21 @@ +import pytest + +from logitech_receiver import base + + +@pytest.mark.parametrize( + "usb_id, expected_name, expected_receiver_kind", + [ + ("0xC548", "Bolt Receiver", "bolt"), + ("0xC52B", "Unifying Receiver", "unifying"), + ("0xC531", "Nano Receiver", "nano"), + ("0xC53F", "Lightspeed Receiver", None), + ("0xC517", "EX100 Receiver 27 Mhz", "27Mhz"), + ], +) +def test_product_information(usb_id, expected_name, expected_receiver_kind): + res = base.product_information(usb_id) + + assert res["name"] == expected_name + if expected_receiver_kind: + assert res["receiver_kind"] == expected_receiver_kind