diff --git a/ophyd/v2/core.py b/ophyd/v2/core.py index 559c41d75..897abd65e 100644 --- a/ophyd/v2/core.py +++ b/ophyd/v2/core.py @@ -327,25 +327,24 @@ async def _on_exit(self) -> None: async def _wait_for_tasks(self, tasks: Dict[asyncio.Task, str]): done, pending = await asyncio.wait(tasks, timeout=self._timeout) + + # Handle all devices where connection has timed out if pending: - msg = f"{len(pending)} Devices did not connect:" + logging.error(f"{len(pending)} Devices did not connect:") for t in pending: t.cancel() with suppress(Exception): await t - e = t.exception() - msg += f"\n {tasks[t]}: {type(e).__name__}" - lines = str(e).splitlines() - if len(lines) <= 1: - msg += f": {e}" - else: - msg += "".join(f"\n {line}" for line in lines) - logging.error(msg) + logging.exception(f" {tasks[t]}:", exc_info=t.exception()) + + # Handle all devices where connection has raised an error before + # timeout raised = [t for t in done if t.exception()] if raised: logging.error(f"{len(raised)} Devices raised an error:") for t in raised: logging.exception(f" {tasks[t]}:", exc_info=t.exception()) + if pending or raised: raise NotConnected("Not all Devices connected") diff --git a/ophyd/v2/tests/test_core.py b/ophyd/v2/tests/test_core.py index aa2566151..f2b653cdf 100644 --- a/ophyd/v2/tests/test_core.py +++ b/ophyd/v2/tests/test_core.py @@ -1,9 +1,10 @@ import asyncio +import logging import re import time import traceback from enum import Enum -from typing import Any, Callable, Sequence, Tuple, Type +from typing import Any, Callable, Sequence, Tuple, Type, cast from unittest.mock import Mock import bluesky.plan_stubs as bps @@ -18,10 +19,12 @@ Device, DeviceCollector, DeviceVector, + NotConnected, Signal, SignalBackend, SignalRW, SimSignalBackend, + StandardReadable, T, get_device_children, set_and_wait_for_value, @@ -195,6 +198,19 @@ async def connect(self, sim=False): self.connected = True +class DummyDeviceThatErrorsWhenConnecting(Device): + async def connect(self, sim: bool = False): + raise IOError("Connection failed") + + +class DummyDeviceThatTimesOutWhenConnecting(StandardReadable): + async def connect(self, sim: bool = False): + try: + await asyncio.Future() + except asyncio.CancelledError: + raise NotConnected("source: foo") + + class DummyDeviceGroup(Device): def __init__(self, name: str) -> None: self.child1 = DummyBaseDevice() @@ -205,6 +221,25 @@ def __init__(self, name: str) -> None: self.set_name(name) +class DummyDeviceGroupThatTimesOut(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyDeviceThatTimesOutWhenConnecting() + self.set_name(name) + + +class DummyDeviceGroupThatErrors(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyDeviceThatErrorsWhenConnecting() + self.set_name(name) + + +class DummyDeviceGroupThatErrorsAndTimesOut(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyDeviceThatErrorsWhenConnecting() + self.child2 = DummyDeviceThatTimesOutWhenConnecting() + self.set_name(name) + + def test_get_device_children(): parent = DummyDeviceGroup("parent") @@ -246,6 +281,131 @@ async def test_device_with_device_collector(): assert parent.dict_with_children[123].connected +@pytest.mark.parametrize( + "device_constructor", + [ + DummyDeviceThatErrorsWhenConnecting, + DummyDeviceThatTimesOutWhenConnecting, + DummyDeviceGroupThatErrors, + DummyDeviceGroupThatTimesOut, + DummyDeviceGroupThatErrorsAndTimesOut, + ], +) +async def test_device_collector_propagates_errors_and_timeouts( + device_constructor: Callable[[str], Device] +): + await _assert_failing_device_does_not_connect(device_constructor) + + +@pytest.mark.parametrize( + "device_constructor_1,device_constructor_2", + [ + (DummyDeviceThatErrorsWhenConnecting, DummyDeviceThatTimesOutWhenConnecting), + (DummyDeviceGroupThatErrors, DummyDeviceGroupThatTimesOut), + (DummyDeviceGroupThatErrors, DummyDeviceGroupThatErrorsAndTimesOut), + (DummyDeviceThatErrorsWhenConnecting, DummyDeviceGroupThatErrors), + ], +) +async def test_device_collector_propagates_errors_and_timeouts_from_multiple_devices( + device_constructor_1: Callable[[str], Device], + device_constructor_2: Callable[[str], Device], +): + await _assert_failing_devices_do_not_connect( + device_constructor_1, + device_constructor_2, + ) + + +async def test_device_collector_logs_exceptions_for_raised_errors( + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.INFO) + await _assert_failing_device_does_not_connect(DummyDeviceGroupThatErrors) + assert caplog.records[0].message == "1 Devices raised an error:" + assert caplog.records[1].message == " should_fail:" + assert_exception_type_and_message( + caplog.records[1], + OSError, + "Connection failed", + ) + + +async def test_device_collector_logs_exceptions_for_timeouts( + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.INFO) + await _assert_failing_device_does_not_connect(DummyDeviceGroupThatTimesOut) + assert caplog.records[0].message == "1 Devices did not connect:" + assert caplog.records[1].message == " should_fail:" + assert_exception_type_and_message( + caplog.records[1], + NotConnected, + "child1: source: foo", + ) + + +async def test_device_collector_logs_exceptions_for_multiple_devices( + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.INFO) + await _assert_failing_devices_do_not_connect( + DummyDeviceGroupThatErrorsAndTimesOut, DummyDeviceGroupThatErrors + ) + assert caplog.records[0].message == "1 Devices did not connect:" + assert caplog.records[1].message == " should_fail_1:" + assert_exception_type_and_message( + caplog.records[1], + OSError, + "Connection failed", + ) + assert caplog.records[2].message == "1 Devices raised an error:" + assert caplog.records[3].message == " should_fail_2:" + assert_exception_type_and_message( + caplog.records[3], + OSError, + "Connection failed", + ) + + +async def _assert_failing_device_does_not_connect( + device_constructor: Callable[[str], Device] +) -> pytest.ExceptionInfo[NotConnected]: + with pytest.raises(NotConnected) as excepton_info: + async with DeviceCollector( + sim=False, + timeout=1.0, + ): + should_fail = device_constructor("should_fail") # noqa: F841 + return excepton_info + + +async def _assert_failing_devices_do_not_connect( + device_constructor_1: Callable[[str], Device], + device_constructor_2: Callable[[str], Device], +) -> pytest.ExceptionInfo[NotConnected]: + with pytest.raises(NotConnected) as excepton_info: + async with DeviceCollector( + sim=False, + timeout=1.0, + ): + should_fail_1 = device_constructor_1("should_fail_1") # noqa: F841 + should_fail_2 = device_constructor_2("should_fail_2") # noqa: F841 + return excepton_info + + +def assert_exception_type_and_message( + record: logging.LogRecord, + expected_type: Type[Exception], + expected_message: str, +): + exception_type, exception, _ = cast( + Tuple[Type[Exception], Exception, str], + record.exc_info, + ) + assert expected_type is exception_type + assert (expected_message,) == exception.args + + async def normal_coroutine(time: float): await asyncio.sleep(time) diff --git a/ophyd/v2/tests/test_epicsdemo.py b/ophyd/v2/tests/test_epicsdemo.py index 1038f1a1b..8dd70eaf5 100644 --- a/ophyd/v2/tests/test_epicsdemo.py +++ b/ophyd/v2/tests/test_epicsdemo.py @@ -1,6 +1,7 @@ import asyncio -from typing import Dict -from unittest.mock import Mock, call, patch +import logging +from typing import Dict, Tuple, Type, cast +from unittest.mock import Mock, call import pytest from bluesky.protocols import Reading @@ -130,18 +131,25 @@ async def test_mover_disconncted(): assert m.name == "mover" -async def test_sensor_disconncted(): - with patch("ophyd.v2.core.logging") as mock_logging: - with pytest.raises(NotConnected, match="Not all Devices connected"): - async with DeviceCollector(timeout=0.1): - s = epicsdemo.Sensor("ca://PRE:", name="sensor") - mock_logging.error.assert_called_once_with( - """\ -1 Devices did not connect: - s: NotConnected - value: ca://PRE:Value - mode: ca://PRE:Mode""" - ) +async def test_sensor_disconncted(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.INFO) + with pytest.raises(NotConnected, match="Not all Devices connected"): + async with DeviceCollector(timeout=0.1): + s = epicsdemo.Sensor("ca://PRE:", name="sensor") + + # Check log messages + assert caplog.records[0].message == "1 Devices did not connect:" + assert caplog.records[1].message == " s:" + + # Check logged exception + exception_type, exception, _ = cast( + Tuple[Type[Exception], Exception, str], + caplog.records[1].exc_info, + ) + assert NotConnected is exception_type + assert ("value: ca://PRE:Value", "mode: ca://PRE:Mode") == exception.args + + # Ensure correct device assert s.name == "sensor"