From 7b2350d7aacc15da32f3b4c544263a40fec61805 Mon Sep 17 00:00:00 2001 From: Devin Burke Date: Tue, 3 Sep 2024 15:42:35 +0200 Subject: [PATCH] Refactored into PerSignalConfig, a simpler dict-like dataclass that will co-operate with compile time typecheckers. --- src/ophyd_async/core/__init__.py | 5 +- src/ophyd_async/core/_readable.py | 80 +++++++------ src/ophyd_async/core/_readable_config.py | 97 ---------------- tests/core/test_readable.py | 140 ++++++++++------------- 4 files changed, 108 insertions(+), 214 deletions(-) delete mode 100644 src/ophyd_async/core/_readable_config.py diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 7cd9755857..d40e7e3a6f 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -43,8 +43,7 @@ UUIDFilenameProvider, YMDPathProvider, ) -from ._readable import ConfigSignal, HintedSignal, StandardReadable -from ._readable_config import ReadableDeviceConfig +from ._readable import ConfigSignal, HintedSignal, PerSignalConfig, StandardReadable from ._signal import ( Signal, SignalR, @@ -159,5 +158,5 @@ "get_unique", "in_micros", "wait_for_connection", - "ReadableDeviceConfig", + "PerSignalConfig", ] diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index 94e6d37544..8055196634 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -1,14 +1,27 @@ import asyncio import warnings +from collections.abc import MutableMapping from contextlib import contextmanager -from typing import Callable, Dict, Generator, Optional, Sequence, Tuple, Type, Union - -from bluesky.protocols import DataKey, HasHints, Hints, Reading +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + +from bluesky.protocols import DataKey, HasHints, Hints, Preparable, Reading from ._device import Device, DeviceVector from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable -from ._readable_config import ReadableDeviceConfig -from ._signal import SignalR +from ._signal import SignalR, SignalW from ._status import AsyncStatus from ._utils import merge_gathered_dicts @@ -17,9 +30,35 @@ Callable[[ReadableChild], ReadableChild], Type["ConfigSignal"], Type["HintedSignal"] ] +T = TypeVar("T") + + +@dataclass +class PerSignalConfig(MutableMapping): + _signal_configuration: Dict[SignalW[Any], Any] = field(default_factory=dict) + + @property + def signal_configuration(self) -> Dict[SignalW[Any], Any]: + return self._signal_configuration + + def __setitem__(self, signal: SignalW[T], value: T): + self._signal_configuration[signal] = value + + def __getitem__(self, signal: SignalW[T]) -> T: + return self._signal_configuration[signal] + + def __delitem__(self, signal: SignalW[T]): + del self._signal_configuration[signal] + + def __iter__(self) -> Iterator[SignalW[Any]]: + return iter(self._signal_configuration) + + def __len__(self) -> int: + return len(self._signal_configuration) + class StandardReadable( - Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints + Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints, Preparable ): """Device that owns its children and provides useful default behavior. @@ -213,35 +252,12 @@ def add_readables( self._has_hints += (obj,) @AsyncStatus.wrap - async def prepare(self, value: ReadableDeviceConfig) -> None: + async def prepare(self, config: PerSignalConfig) -> None: tasks = [] - for dtype, signals in value.signals.items(): - for signal_name, (expected_dtype, val) in signals.items(): - if hasattr(self, signal_name): - attr = getattr(self, signal_name) - if isinstance(attr, (HintedSignal, ConfigSignal)): - attr = attr.signal - if attr._backend.datatype == expected_dtype: # noqa: SLF001 - if val is not None: - tasks.append(attr.set(val)) - else: - raise TypeError( - f"Expected value of type {expected_dtype} for attribute" - f" '{signal_name}'," - f" got {type(attr._backend.datatype)}" # noqa: SLF001 - ) + for sig, value in config.items(): + tasks.append(sig.set(value)) await asyncio.gather(*tasks) - def get_config(self) -> ReadableDeviceConfig: - config = ReadableDeviceConfig() - for readable in self._configurables: - if isinstance(readable, Union[ConfigSignal, HintedSignal]): - readable = readable.signal - name = readable.name.split("-")[-1] - dtype = readable._backend.datatype # noqa: SLF001 - config.add_attribute(name, dtype) - return config - class ConfigSignal(AsyncConfigurable): def __init__(self, signal: ReadableChild) -> None: diff --git a/src/ophyd_async/core/_readable_config.py b/src/ophyd_async/core/_readable_config.py deleted file mode 100644 index 687e8e3e55..0000000000 --- a/src/ophyd_async/core/_readable_config.py +++ /dev/null @@ -1,97 +0,0 @@ -from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Tuple, Type - -import numpy as np - - -@dataclass -class ReadableDeviceConfig: - int_signals: Dict[str, Tuple[int, Any]] = field(default_factory=dict) - float_signals: Dict[str, Tuple[float, Any]] = field(default_factory=dict) - str_signals: Dict[str, Tuple[str, Any]] = field(default_factory=dict) - bool_signals: Dict[str, Tuple[bool, Any]] = field(default_factory=dict) - list_signals: Dict[str, Tuple[list, Any]] = field(default_factory=dict) - tuple_signals: Dict[str, Tuple[tuple, Any]] = field(default_factory=dict) - dict_signals: Dict[str, Tuple[dict, Any]] = field(default_factory=dict) - set_signals: Dict[str, Tuple[set, Any]] = field(default_factory=dict) - frozenset_signals: Dict[str, Tuple[frozenset, Any]] = field(default_factory=dict) - bytes_signals: Dict[str, Tuple[bytes, Any]] = field(default_factory=dict) - bytearray_signals: Dict[str, Tuple[bytearray, Any]] = field(default_factory=dict) - complex_signals: Dict[str, Tuple[complex, Any]] = field(default_factory=dict) - none_signals: Dict[str, Tuple[type(None), Any]] = field(default_factory=dict) - ndarray_signals: Dict[str, Tuple[np.ndarray, Any]] = field(default_factory=dict) - signals: Dict[Type[Any], Dict[str, Tuple[Type[Any], Any]]] = field(init=False) - - def __post_init__(self): - self.signals = { - int: self.int_signals, - float: self.float_signals, - str: self.str_signals, - bool: self.bool_signals, - list: self.list_signals, - tuple: self.tuple_signals, - dict: self.dict_signals, - set: self.set_signals, - frozenset: self.frozenset_signals, - bytes: self.bytes_signals, - bytearray: self.bytearray_signals, - complex: self.complex_signals, - type(None): self.none_signals, - np.ndarray: self.ndarray_signals, - } - self.attr_map = dict(self.signals.items()) - - class SignalAccessor: - def __init__(self, parent: "ReadableDeviceConfig", key: str): - self.parent = parent - self.key = key - - def __getitem__(self, dtype: Type[Any]) -> Any: - signals = self.parent.signals.get(dtype) - if signals and self.key in signals: - return signals[self.key][1] - raise AttributeError( - f"'{self.parent.__class__.__name__}' object has no attribute" - f" '{self.key}[{dtype.__name__}]'" - ) - - def __setitem__(self, dtype: Type[Any], value: Any) -> None: - signals = self.parent.signals.get(dtype) - if signals and self.key in signals: - expected_type, _ = signals[self.key] - if isinstance(value, expected_type) or value is None: - signals[self.key] = (expected_type, value) - else: - raise TypeError( - f"Expected value of type {expected_type} for attribute" - f" '{self.key}', got {type(value)}" - ) - else: - raise KeyError( - f"Key '{self.key}' not found in {self.parent.attr_map[dtype]}" - ) - - def __getattr__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor": - return self.SignalAccessor(self, key) - - def add_attribute(self, name: str, dtype: Type[Any], value: Any = None) -> None: - if value is not None and not isinstance(value, dtype): - raise TypeError( - f"Expected value of type {dtype} for attribute '{name}'," - f" got {type(value)}" - ) - self.signals[dtype][name] = (dtype, value) - - def __setattr__(self, key: str, value: Any) -> None: - if "attr_map" in self.__dict__: - raise AttributeError( - f"Cannot set attribute '{key}' directly. Use 'add_attribute' method." - ) - else: - super().__setattr__(key, value) - - def __getitem__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor": - return self.SignalAccessor(self, key) - - def items(self): - return asdict(self).items() diff --git a/tests/core/test_readable.py b/tests/core/test_readable.py index 3f3d2a75bf..b6e99598dc 100644 --- a/tests/core/test_readable.py +++ b/tests/core/test_readable.py @@ -15,9 +15,10 @@ DeviceVector, HintedSignal, MockSignalBackend, - ReadableDeviceConfig, + PerSignalConfig, SignalR, SignalRW, + SignalW, SoftSignalBackend, StandardReadable, soft_signal_r_and_setter, @@ -245,8 +246,8 @@ def test_standard_readable_add_children_multi_nested(): @pytest.fixture -def readable_device_config(): - return ReadableDeviceConfig() +def standard_readable_config(): + return PerSignalConfig() test_data = [ @@ -267,97 +268,72 @@ def readable_device_config(): ] -@pytest.mark.parametrize("name,dtype,value", test_data) -def test_add_attribute(readable_device_config, name, dtype, value): - readable_device_config.add_attribute(name, dtype, value) - assert name in readable_device_config.signals[dtype] - assert readable_device_config.signals[dtype][name] == (dtype, value) +def test_config_initialization(standard_readable_config): + assert len(standard_readable_config) == 0 -@pytest.mark.parametrize("name,dtype,value", test_data) -def test_get_attribute(readable_device_config, name, dtype, value): - readable_device_config.add_attribute(name, dtype, value) - if isinstance(value, np.ndarray): - assert np.array_equal(readable_device_config[name][dtype], value) +@pytest.mark.parametrize("name, type_, value", test_data) +def test_config_set_get_item(standard_readable_config, name, type_, value): + mock_signal = MagicMock(spec=SignalW) + standard_readable_config[mock_signal] = value + if type_ is np.ndarray: + assert np.array_equal(standard_readable_config[mock_signal], value) else: - assert readable_device_config[name][dtype] == value - - -@pytest.mark.parametrize("name,dtype,value", test_data) -def test_set_attribute(readable_device_config, name, dtype, value): - readable_device_config.add_attribute(name, dtype, value) - new_value = value if not isinstance(value, (int, float)) else value + 1 - if dtype is bool: - new_value = not value - if dtype is np.ndarray: - new_value = np.flip(value) - readable_device_config[name][dtype] = new_value - if isinstance(value, np.ndarray): - assert np.array_equal(readable_device_config[name][dtype], new_value) - else: - assert readable_device_config[name][dtype] == new_value + assert standard_readable_config[mock_signal] == value -@pytest.mark.parametrize("name,dtype,value", test_data) -def test_invalid_type(readable_device_config, name, dtype, value): - with pytest.raises(TypeError): - if dtype is str: - readable_device_config.add_attribute(name, dtype, 1) - else: - readable_device_config.add_attribute(name, dtype, "invalid_type") +@pytest.mark.parametrize("name, type_, value", test_data) +def test_config_del_item(standard_readable_config, name, type_, value): + mock_signal = MagicMock(spec=SignalW) + standard_readable_config[mock_signal] = value + del standard_readable_config[mock_signal] + with pytest.raises(KeyError): + _ = standard_readable_config[mock_signal] -@pytest.mark.parametrize("name,dtype,value", test_data) -def test_add_attribute_default_value(readable_device_config, name, dtype, value): - readable_device_config.add_attribute(name, dtype) - assert name in readable_device_config.signals[dtype] - # Check that the default value is of the correct type - assert readable_device_config.signals[dtype][name][1] is None +def test_config_iteration(standard_readable_config): + mock_signal1 = MagicMock(spec=SignalW) + mock_signal2 = MagicMock(spec=SignalW) + standard_readable_config[mock_signal1] = 42 + standard_readable_config[mock_signal2] = 43 + signals = list(standard_readable_config) + assert mock_signal1 in signals + assert mock_signal2 in signals -@pytest.mark.asyncio -async def test_readable_device_prepare(readable_device_config): - sr = StandardReadable() - mock = MagicMock() - sr.add_readables = mock - with sr.add_children_as_readables(ConfigSignal): - sr.a = SignalRW(name="a", backend=SoftSignalBackend(datatype=int)) - sr.b = SignalRW(name="b", backend=SoftSignalBackend(datatype=float)) - sr.c = SignalRW(name="c", backend=SoftSignalBackend(datatype=str)) - sr.d = SignalRW(name="d", backend=SoftSignalBackend(datatype=bool)) - - readable_device_config.add_attribute("a", int, 42) - readable_device_config.add_attribute("b", float, 3.14) - readable_device_config.add_attribute("c", str, "hello") - - await sr.prepare(readable_device_config) - assert await sr.a.get_value() == 42 - assert await sr.b.get_value() == 3.14 - assert await sr.c.get_value() == "hello" - - readable_device_config.add_attribute("d", int, 1) - with pytest.raises(TypeError): - await sr.prepare(readable_device_config) +def test_config_length(standard_readable_config): + mock_signal1 = MagicMock(spec=SignalW) + mock_signal2 = MagicMock(spec=SignalW) + standard_readable_config[mock_signal1] = 42 + standard_readable_config[mock_signal2] = 43 + assert len(standard_readable_config) == 2 -def test_get_config(): - sr = StandardReadable() +@pytest.mark.asyncio +@pytest.mark.parametrize("name, type_, value", test_data) +async def test_config_prepare(standard_readable_config, name, type_, value): + readable = StandardReadable() + if type_ is np.ndarray: + readable.mock_signal1 = SignalRW( + name="mock_signal1", + backend=SoftSignalBackend( + datatype=type_, initial_value=np.ndarray([0, 0, 0]) + ), + ) + else: + readable.mock_signal1 = SignalRW( + name="mock_signal1", backend=SoftSignalBackend(datatype=type_) + ) - hinted = SignalRW(name="hinted", backend=SoftSignalBackend(datatype=int)) - configurable = SignalRW( - name="configurable", backend=SoftSignalBackend(datatype=int) - ) - normal = SignalRW(name="normal", backend=SoftSignalBackend(datatype=int)) + readable.add_readables([readable.mock_signal1]) - sr.add_readables([configurable], ConfigSignal) - sr.add_readables([hinted], HintedSignal) - sr.add_readables([normal]) + config = PerSignalConfig() + config[readable.mock_signal1] = value - config = sr.get_config() + await readable.prepare(config) + val = await readable.mock_signal1.get_value() - # Check that configurable is in the config - assert config["configurable"][int] is None - with pytest.raises(AttributeError): - config["hinted"][int] - with pytest.raises(AttributeError): - config["normal"][int] + if type_ is np.ndarray: + assert np.array_equal(val, value) + else: + assert await readable.mock_signal1.get_value() == value