From 135626f1a16687d2820cbdd68d4d96d27fb15eb3 Mon Sep 17 00:00:00 2001 From: Devin Burke Date: Wed, 28 Aug 2024 14:37:18 +0200 Subject: [PATCH] Prototyping a general purpose ReadableDeviceConfig dataclass to be used with StandardReadables. --- src/ophyd_async/core/__init__.py | 2 + src/ophyd_async/core/_readable.py | 21 +++++ src/ophyd_async/core/_readable_config.py | 100 +++++++++++++++++++++++ tests/core/test_readable.py | 92 +++++++++++++++++++++ 4 files changed, 215 insertions(+) create mode 100644 src/ophyd_async/core/_readable_config.py diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 3f88752fdd..7cd9755857 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -44,6 +44,7 @@ YMDPathProvider, ) from ._readable import ConfigSignal, HintedSignal, StandardReadable +from ._readable_config import ReadableDeviceConfig from ._signal import ( Signal, SignalR, @@ -158,4 +159,5 @@ "get_unique", "in_micros", "wait_for_connection", + "ReadableDeviceConfig", ] diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index c63a0f5dcf..b8c05c1d68 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -1,3 +1,4 @@ +import asyncio import warnings from contextlib import contextmanager from typing import Callable, Dict, Generator, Optional, Sequence, Tuple, Type, Union @@ -6,6 +7,7 @@ from ._device import Device, DeviceVector from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable +from ._readable_config import ReadableDeviceConfig from ._signal import SignalR from ._status import AsyncStatus from ._utils import merge_gathered_dicts @@ -210,6 +212,25 @@ def add_readables( if isinstance(obj, HasHints): self._has_hints += (obj,) + @AsyncStatus.wrap + async def prepare(self, value: ReadableDeviceConfig) -> None: + tasks = [] + for dtype, signals in value.signals.items(): + for signal_name, (expected_dtype, val) in signals.items(): + if hasattr(self, signal_name): + attr = getattr(self, signal_name) + if isinstance(attr, (HintedSignal, ConfigSignal)): + attr = attr.signal + if attr._backend.datatype == expected_dtype: # noqa: SLF001 + tasks.append(attr.set(val)) + else: + raise TypeError( + f"Expected value of type {expected_dtype} for attribute" + f" '{signal_name}'," + f" got {type(attr._backend.datatype)}" # noqa: SLF001 + ) + await asyncio.gather(*tasks) + class ConfigSignal(AsyncConfigurable): def __init__(self, signal: ReadableChild) -> None: diff --git a/src/ophyd_async/core/_readable_config.py b/src/ophyd_async/core/_readable_config.py new file mode 100644 index 0000000000..7ce05d51a6 --- /dev/null +++ b/src/ophyd_async/core/_readable_config.py @@ -0,0 +1,100 @@ +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Tuple, Type + +import numpy as np + + +@dataclass +@dataclass +class ReadableDeviceConfig: + int_signals: Dict[str, Tuple[int, Any]] = field(default_factory=dict) + float_signals: Dict[str, Tuple[float, Any]] = field(default_factory=dict) + str_signals: Dict[str, Tuple[str, Any]] = field(default_factory=dict) + bool_signals: Dict[str, Tuple[bool, Any]] = field(default_factory=dict) + list_signals: Dict[str, Tuple[list, Any]] = field(default_factory=dict) + tuple_signals: Dict[str, Tuple[tuple, Any]] = field(default_factory=dict) + dict_signals: Dict[str, Tuple[dict, Any]] = field(default_factory=dict) + set_signals: Dict[str, Tuple[set, Any]] = field(default_factory=dict) + frozenset_signals: Dict[str, Tuple[frozenset, Any]] = field(default_factory=dict) + bytes_signals: Dict[str, Tuple[bytes, Any]] = field(default_factory=dict) + bytearray_signals: Dict[str, Tuple[bytearray, Any]] = field(default_factory=dict) + complex_signals: Dict[str, Tuple[complex, Any]] = field(default_factory=dict) + none_signals: Dict[str, Tuple[type(None), Any]] = field(default_factory=dict) + ndarray_signals: Dict[str, Tuple[np.ndarray, Any]] = field(default_factory=dict) + signals: Dict[Type[Any], Dict[str, Tuple[Type[Any], Any]]] = field(init=False) + + def __post_init__(self): + self.signals = { + int: self.int_signals, + float: self.float_signals, + str: self.str_signals, + bool: self.bool_signals, + list: self.list_signals, + tuple: self.tuple_signals, + dict: self.dict_signals, + set: self.set_signals, + frozenset: self.frozenset_signals, + bytes: self.bytes_signals, + bytearray: self.bytearray_signals, + complex: self.complex_signals, + type(None): self.none_signals, + np.ndarray: self.ndarray_signals, + } + self.attr_map = dict(self.signals.items()) + + class SignalAccessor: + def __init__(self, parent: "ReadableDeviceConfig", key: str): + self.parent = parent + self.key = key + + def __getitem__(self, dtype: Type[Any]) -> Any: + signals = self.parent.signals.get(dtype) + if signals and self.key in signals: + return signals[self.key][1] + raise AttributeError( + f"'{self.parent.__class__.__name__}'" + f" object has no attribute" + f" '{self.key}[{dtype.__name__}]'" + ) + + def __setitem__(self, dtype: Type[Any], value: Any) -> None: + signals = self.parent.signals.get(dtype) + if signals and self.key in signals: + expected_type, _ = signals[self.key] + if isinstance(value, expected_type): + signals[self.key] = (expected_type, value) + else: + raise TypeError( + f"Expected value of type {expected_type}" + f" for attribute '{self.key}', got {type(value)}" + ) + else: + raise KeyError( + f"Key '{self.key}' not found" f" in {self.parent.attr_map[dtype]}" + ) + + def __getattr__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor": + return self.SignalAccessor(self, key) + + def add_attribute(self, name: str, dtype: Type[Any], value: Any) -> None: + if not isinstance(value, dtype): + raise TypeError( + f"Expected value of type {dtype} for attribute" + f" '{name}', got {type(value)}" + ) + self.signals[dtype][name] = (dtype, value) + + def __setattr__(self, key: str, value: Any) -> None: + if "attr_map" in self.__dict__: + raise AttributeError( + f"Cannot set attribute '{key}'" + f" directly. Use 'add_attribute' method." + ) + else: + super().__setattr__(key, value) + + def __getitem__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor": + return self.SignalAccessor(self, key) + + def items(self): + return asdict(self).items() diff --git a/tests/core/test_readable.py b/tests/core/test_readable.py index 7d13f308c3..297bb06c5c 100644 --- a/tests/core/test_readable.py +++ b/tests/core/test_readable.py @@ -2,6 +2,7 @@ from typing import List, get_type_hints from unittest.mock import MagicMock +import numpy as np import pytest from bluesky.protocols import HasHints @@ -14,7 +15,10 @@ DeviceVector, HintedSignal, MockSignalBackend, + ReadableDeviceConfig, SignalR, + SignalRW, + SoftSignalBackend, StandardReadable, soft_signal_r_and_setter, ) @@ -238,3 +242,91 @@ def test_standard_readable_add_children_multi_nested(): with outer.add_children_as_readables(): outer.inner = inner assert outer + + +@pytest.fixture +def readable_device_config(): + return ReadableDeviceConfig() + + +test_data = [ + ("test_int", int, 42), + ("test_float", float, 3.14), + ("test_str", str, "hello"), + ("test_bool", bool, True), + ("test_list", list, [1, 2, 3]), + ("test_tuple", tuple, (1, 2, 3)), + ("test_dict", dict, {"key": "value"}), + ("test_set", set, {1, 2, 3}), + ("test_frozenset", frozenset, frozenset([1, 2, 3])), + ("test_bytes", bytes, b"hello"), + ("test_bytearray", bytearray, bytearray(b"hello")), + ("test_complex", complex, 1 + 2j), + ("test_nonetype", type(None), None), + ("test_ndarray", np.ndarray, np.array([1, 2, 3])), +] + + +@pytest.mark.parametrize("name,dtype,value", test_data) +def test_add_attribute(readable_device_config, name, dtype, value): + readable_device_config.add_attribute(name, dtype, value) + assert name in readable_device_config.signals[dtype] + assert readable_device_config.signals[dtype][name] == (dtype, value) + + +@pytest.mark.parametrize("name,dtype,value", test_data) +def test_get_attribute(readable_device_config, name, dtype, value): + readable_device_config.add_attribute(name, dtype, value) + if isinstance(value, np.ndarray): + assert np.array_equal(readable_device_config[name][dtype], value) + else: + assert readable_device_config[name][dtype] == value + + +@pytest.mark.parametrize("name,dtype,value", test_data) +def test_set_attribute(readable_device_config, name, dtype, value): + readable_device_config.add_attribute(name, dtype, value) + new_value = value if not isinstance(value, (int, float)) else value + 1 + if dtype is bool: + new_value = not value + if dtype is np.ndarray: + new_value = np.flip(value) + readable_device_config[name][dtype] = new_value + if isinstance(value, np.ndarray): + assert np.array_equal(readable_device_config[name][dtype], new_value) + else: + assert readable_device_config[name][dtype] == new_value + + +@pytest.mark.parametrize("name,dtype,value", test_data) +def test_invalid_type(readable_device_config, name, dtype, value): + with pytest.raises(TypeError): + if dtype is str: + readable_device_config.add_attribute(name, dtype, 1) + else: + readable_device_config.add_attribute(name, dtype, "invalid_type") + + +@pytest.mark.asyncio +async def test_readable_device_prepare(readable_device_config): + sr = StandardReadable() + mock = MagicMock() + sr.add_readables = mock + with sr.add_children_as_readables(ConfigSignal): + sr.a = SignalRW(name="a", backend=SoftSignalBackend(datatype=int)) + sr.b = SignalRW(name="b", backend=SoftSignalBackend(datatype=float)) + sr.c = SignalRW(name="c", backend=SoftSignalBackend(datatype=str)) + sr.d = SignalRW(name="d", backend=SoftSignalBackend(datatype=bool)) + + readable_device_config.add_attribute("a", int, 42) + readable_device_config.add_attribute("b", float, 3.14) + readable_device_config.add_attribute("c", str, "hello") + + await sr.prepare(readable_device_config) + assert await sr.a.get_value() == 42 + assert await sr.b.get_value() == 3.14 + assert await sr.c.get_value() == "hello" + + readable_device_config.add_attribute("d", int, 1) + with pytest.raises(TypeError): + await sr.prepare(readable_device_config)