Skip to content

Commit

Permalink
Prototyping a general purpose ReadableDeviceConfig dataclass to be us…
Browse files Browse the repository at this point in the history
…ed with StandardReadables.
  • Loading branch information
burkeds committed Aug 28, 2024
1 parent d7f0748 commit 135626f
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
YMDPathProvider,
)
from ._readable import ConfigSignal, HintedSignal, StandardReadable
from ._readable_config import ReadableDeviceConfig
from ._signal import (
Signal,
SignalR,
Expand Down Expand Up @@ -158,4 +159,5 @@
"get_unique",
"in_micros",
"wait_for_connection",
"ReadableDeviceConfig",
]
21 changes: 21 additions & 0 deletions src/ophyd_async/core/_readable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import warnings
from contextlib import contextmanager
from typing import Callable, Dict, Generator, Optional, Sequence, Tuple, Type, Union
Expand All @@ -6,6 +7,7 @@

from ._device import Device, DeviceVector
from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable
from ._readable_config import ReadableDeviceConfig
from ._signal import SignalR
from ._status import AsyncStatus
from ._utils import merge_gathered_dicts
Expand Down Expand Up @@ -210,6 +212,25 @@ def add_readables(
if isinstance(obj, HasHints):
self._has_hints += (obj,)

@AsyncStatus.wrap
async def prepare(self, value: ReadableDeviceConfig) -> None:
tasks = []
for dtype, signals in value.signals.items():
for signal_name, (expected_dtype, val) in signals.items():
if hasattr(self, signal_name):
attr = getattr(self, signal_name)
if isinstance(attr, (HintedSignal, ConfigSignal)):
attr = attr.signal
if attr._backend.datatype == expected_dtype: # noqa: SLF001
tasks.append(attr.set(val))
else:
raise TypeError(
f"Expected value of type {expected_dtype} for attribute"
f" '{signal_name}',"
f" got {type(attr._backend.datatype)}" # noqa: SLF001
)
await asyncio.gather(*tasks)


class ConfigSignal(AsyncConfigurable):
def __init__(self, signal: ReadableChild) -> None:
Expand Down
100 changes: 100 additions & 0 deletions src/ophyd_async/core/_readable_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Tuple, Type

import numpy as np


@dataclass
@dataclass
class ReadableDeviceConfig:
int_signals: Dict[str, Tuple[int, Any]] = field(default_factory=dict)
float_signals: Dict[str, Tuple[float, Any]] = field(default_factory=dict)
str_signals: Dict[str, Tuple[str, Any]] = field(default_factory=dict)
bool_signals: Dict[str, Tuple[bool, Any]] = field(default_factory=dict)
list_signals: Dict[str, Tuple[list, Any]] = field(default_factory=dict)
tuple_signals: Dict[str, Tuple[tuple, Any]] = field(default_factory=dict)
dict_signals: Dict[str, Tuple[dict, Any]] = field(default_factory=dict)
set_signals: Dict[str, Tuple[set, Any]] = field(default_factory=dict)
frozenset_signals: Dict[str, Tuple[frozenset, Any]] = field(default_factory=dict)
bytes_signals: Dict[str, Tuple[bytes, Any]] = field(default_factory=dict)
bytearray_signals: Dict[str, Tuple[bytearray, Any]] = field(default_factory=dict)
complex_signals: Dict[str, Tuple[complex, Any]] = field(default_factory=dict)
none_signals: Dict[str, Tuple[type(None), Any]] = field(default_factory=dict)
ndarray_signals: Dict[str, Tuple[np.ndarray, Any]] = field(default_factory=dict)
signals: Dict[Type[Any], Dict[str, Tuple[Type[Any], Any]]] = field(init=False)

def __post_init__(self):
self.signals = {
int: self.int_signals,
float: self.float_signals,
str: self.str_signals,
bool: self.bool_signals,
list: self.list_signals,
tuple: self.tuple_signals,
dict: self.dict_signals,
set: self.set_signals,
frozenset: self.frozenset_signals,
bytes: self.bytes_signals,
bytearray: self.bytearray_signals,
complex: self.complex_signals,
type(None): self.none_signals,
np.ndarray: self.ndarray_signals,
}
self.attr_map = dict(self.signals.items())

class SignalAccessor:
def __init__(self, parent: "ReadableDeviceConfig", key: str):
self.parent = parent
self.key = key

def __getitem__(self, dtype: Type[Any]) -> Any:
signals = self.parent.signals.get(dtype)
if signals and self.key in signals:
return signals[self.key][1]
raise AttributeError(
f"'{self.parent.__class__.__name__}'"
f" object has no attribute"
f" '{self.key}[{dtype.__name__}]'"
)

def __setitem__(self, dtype: Type[Any], value: Any) -> None:
signals = self.parent.signals.get(dtype)
if signals and self.key in signals:
expected_type, _ = signals[self.key]
if isinstance(value, expected_type):
signals[self.key] = (expected_type, value)
else:
raise TypeError(
f"Expected value of type {expected_type}"
f" for attribute '{self.key}', got {type(value)}"
)
else:
raise KeyError(
f"Key '{self.key}' not found" f" in {self.parent.attr_map[dtype]}"
)

def __getattr__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor":
return self.SignalAccessor(self, key)

def add_attribute(self, name: str, dtype: Type[Any], value: Any) -> None:
if not isinstance(value, dtype):
raise TypeError(
f"Expected value of type {dtype} for attribute"
f" '{name}', got {type(value)}"
)
self.signals[dtype][name] = (dtype, value)

def __setattr__(self, key: str, value: Any) -> None:
if "attr_map" in self.__dict__:
raise AttributeError(
f"Cannot set attribute '{key}'"
f" directly. Use 'add_attribute' method."
)
else:
super().__setattr__(key, value)

def __getitem__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor":
return self.SignalAccessor(self, key)

def items(self):
return asdict(self).items()
92 changes: 92 additions & 0 deletions tests/core/test_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, get_type_hints
from unittest.mock import MagicMock

import numpy as np
import pytest
from bluesky.protocols import HasHints

Expand All @@ -14,7 +15,10 @@
DeviceVector,
HintedSignal,
MockSignalBackend,
ReadableDeviceConfig,
SignalR,
SignalRW,
SoftSignalBackend,
StandardReadable,
soft_signal_r_and_setter,
)
Expand Down Expand Up @@ -238,3 +242,91 @@ def test_standard_readable_add_children_multi_nested():
with outer.add_children_as_readables():
outer.inner = inner
assert outer


@pytest.fixture
def readable_device_config():
return ReadableDeviceConfig()


test_data = [
("test_int", int, 42),
("test_float", float, 3.14),
("test_str", str, "hello"),
("test_bool", bool, True),
("test_list", list, [1, 2, 3]),
("test_tuple", tuple, (1, 2, 3)),
("test_dict", dict, {"key": "value"}),
("test_set", set, {1, 2, 3}),
("test_frozenset", frozenset, frozenset([1, 2, 3])),
("test_bytes", bytes, b"hello"),
("test_bytearray", bytearray, bytearray(b"hello")),
("test_complex", complex, 1 + 2j),
("test_nonetype", type(None), None),
("test_ndarray", np.ndarray, np.array([1, 2, 3])),
]


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_add_attribute(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, value)
assert name in readable_device_config.signals[dtype]
assert readable_device_config.signals[dtype][name] == (dtype, value)


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_get_attribute(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, value)
if isinstance(value, np.ndarray):
assert np.array_equal(readable_device_config[name][dtype], value)
else:
assert readable_device_config[name][dtype] == value


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_set_attribute(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, value)
new_value = value if not isinstance(value, (int, float)) else value + 1
if dtype is bool:
new_value = not value
if dtype is np.ndarray:
new_value = np.flip(value)
readable_device_config[name][dtype] = new_value
if isinstance(value, np.ndarray):
assert np.array_equal(readable_device_config[name][dtype], new_value)
else:
assert readable_device_config[name][dtype] == new_value


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_invalid_type(readable_device_config, name, dtype, value):
with pytest.raises(TypeError):
if dtype is str:
readable_device_config.add_attribute(name, dtype, 1)
else:
readable_device_config.add_attribute(name, dtype, "invalid_type")


@pytest.mark.asyncio
async def test_readable_device_prepare(readable_device_config):
sr = StandardReadable()
mock = MagicMock()
sr.add_readables = mock
with sr.add_children_as_readables(ConfigSignal):
sr.a = SignalRW(name="a", backend=SoftSignalBackend(datatype=int))
sr.b = SignalRW(name="b", backend=SoftSignalBackend(datatype=float))
sr.c = SignalRW(name="c", backend=SoftSignalBackend(datatype=str))
sr.d = SignalRW(name="d", backend=SoftSignalBackend(datatype=bool))

readable_device_config.add_attribute("a", int, 42)
readable_device_config.add_attribute("b", float, 3.14)
readable_device_config.add_attribute("c", str, "hello")

await sr.prepare(readable_device_config)
assert await sr.a.get_value() == 42
assert await sr.b.get_value() == 3.14
assert await sr.c.get_value() == "hello"

readable_device_config.add_attribute("d", int, 1)
with pytest.raises(TypeError):
await sr.prepare(readable_device_config)

0 comments on commit 135626f

Please sign in to comment.