Skip to content

Commit

Permalink
Refactored into PerSignalConfig, a simpler dict-like dataclass that w…
Browse files Browse the repository at this point in the history
…ill co-operate with compile time typecheckers.
  • Loading branch information
burkeds committed Sep 3, 2024
1 parent c01afb1 commit 7b2350d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 214 deletions.
5 changes: 2 additions & 3 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@
UUIDFilenameProvider,
YMDPathProvider,
)
from ._readable import ConfigSignal, HintedSignal, StandardReadable
from ._readable_config import ReadableDeviceConfig
from ._readable import ConfigSignal, HintedSignal, PerSignalConfig, StandardReadable
from ._signal import (
Signal,
SignalR,
Expand Down Expand Up @@ -159,5 +158,5 @@
"get_unique",
"in_micros",
"wait_for_connection",
"ReadableDeviceConfig",
"PerSignalConfig",
]
80 changes: 48 additions & 32 deletions src/ophyd_async/core/_readable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import asyncio
import warnings
from collections.abc import MutableMapping
from contextlib import contextmanager
from typing import Callable, Dict, Generator, Optional, Sequence, Tuple, Type, Union

from bluesky.protocols import DataKey, HasHints, Hints, Reading
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

from bluesky.protocols import DataKey, HasHints, Hints, Preparable, Reading

from ._device import Device, DeviceVector
from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable
from ._readable_config import ReadableDeviceConfig
from ._signal import SignalR
from ._signal import SignalR, SignalW
from ._status import AsyncStatus
from ._utils import merge_gathered_dicts

Expand All @@ -17,9 +30,35 @@
Callable[[ReadableChild], ReadableChild], Type["ConfigSignal"], Type["HintedSignal"]
]

T = TypeVar("T")


@dataclass
class PerSignalConfig(MutableMapping):
_signal_configuration: Dict[SignalW[Any], Any] = field(default_factory=dict)

@property
def signal_configuration(self) -> Dict[SignalW[Any], Any]:
return self._signal_configuration

def __setitem__(self, signal: SignalW[T], value: T):
self._signal_configuration[signal] = value

def __getitem__(self, signal: SignalW[T]) -> T:
return self._signal_configuration[signal]

def __delitem__(self, signal: SignalW[T]):
del self._signal_configuration[signal]

def __iter__(self) -> Iterator[SignalW[Any]]:
return iter(self._signal_configuration)

def __len__(self) -> int:
return len(self._signal_configuration)


class StandardReadable(
Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints
Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints, Preparable
):
"""Device that owns its children and provides useful default behavior.
Expand Down Expand Up @@ -213,35 +252,12 @@ def add_readables(
self._has_hints += (obj,)

@AsyncStatus.wrap
async def prepare(self, value: ReadableDeviceConfig) -> None:
async def prepare(self, config: PerSignalConfig) -> None:
tasks = []
for dtype, signals in value.signals.items():
for signal_name, (expected_dtype, val) in signals.items():
if hasattr(self, signal_name):
attr = getattr(self, signal_name)
if isinstance(attr, (HintedSignal, ConfigSignal)):
attr = attr.signal
if attr._backend.datatype == expected_dtype: # noqa: SLF001
if val is not None:
tasks.append(attr.set(val))
else:
raise TypeError(
f"Expected value of type {expected_dtype} for attribute"
f" '{signal_name}',"
f" got {type(attr._backend.datatype)}" # noqa: SLF001
)
for sig, value in config.items():
tasks.append(sig.set(value))
await asyncio.gather(*tasks)

def get_config(self) -> ReadableDeviceConfig:
config = ReadableDeviceConfig()
for readable in self._configurables:
if isinstance(readable, Union[ConfigSignal, HintedSignal]):
readable = readable.signal
name = readable.name.split("-")[-1]
dtype = readable._backend.datatype # noqa: SLF001
config.add_attribute(name, dtype)
return config


class ConfigSignal(AsyncConfigurable):
def __init__(self, signal: ReadableChild) -> None:
Expand Down
97 changes: 0 additions & 97 deletions src/ophyd_async/core/_readable_config.py

This file was deleted.

140 changes: 58 additions & 82 deletions tests/core/test_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
DeviceVector,
HintedSignal,
MockSignalBackend,
ReadableDeviceConfig,
PerSignalConfig,
SignalR,
SignalRW,
SignalW,
SoftSignalBackend,
StandardReadable,
soft_signal_r_and_setter,
Expand Down Expand Up @@ -245,8 +246,8 @@ def test_standard_readable_add_children_multi_nested():


@pytest.fixture
def readable_device_config():
return ReadableDeviceConfig()
def standard_readable_config():
return PerSignalConfig()


test_data = [
Expand All @@ -267,97 +268,72 @@ def readable_device_config():
]


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_add_attribute(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, value)
assert name in readable_device_config.signals[dtype]
assert readable_device_config.signals[dtype][name] == (dtype, value)
def test_config_initialization(standard_readable_config):
assert len(standard_readable_config) == 0


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_get_attribute(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, value)
if isinstance(value, np.ndarray):
assert np.array_equal(readable_device_config[name][dtype], value)
@pytest.mark.parametrize("name, type_, value", test_data)
def test_config_set_get_item(standard_readable_config, name, type_, value):
mock_signal = MagicMock(spec=SignalW)
standard_readable_config[mock_signal] = value
if type_ is np.ndarray:
assert np.array_equal(standard_readable_config[mock_signal], value)
else:
assert readable_device_config[name][dtype] == value


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_set_attribute(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, value)
new_value = value if not isinstance(value, (int, float)) else value + 1
if dtype is bool:
new_value = not value
if dtype is np.ndarray:
new_value = np.flip(value)
readable_device_config[name][dtype] = new_value
if isinstance(value, np.ndarray):
assert np.array_equal(readable_device_config[name][dtype], new_value)
else:
assert readable_device_config[name][dtype] == new_value
assert standard_readable_config[mock_signal] == value


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_invalid_type(readable_device_config, name, dtype, value):
with pytest.raises(TypeError):
if dtype is str:
readable_device_config.add_attribute(name, dtype, 1)
else:
readable_device_config.add_attribute(name, dtype, "invalid_type")
@pytest.mark.parametrize("name, type_, value", test_data)
def test_config_del_item(standard_readable_config, name, type_, value):
mock_signal = MagicMock(spec=SignalW)
standard_readable_config[mock_signal] = value
del standard_readable_config[mock_signal]
with pytest.raises(KeyError):
_ = standard_readable_config[mock_signal]


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_add_attribute_default_value(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype)
assert name in readable_device_config.signals[dtype]
# Check that the default value is of the correct type
assert readable_device_config.signals[dtype][name][1] is None
def test_config_iteration(standard_readable_config):
mock_signal1 = MagicMock(spec=SignalW)
mock_signal2 = MagicMock(spec=SignalW)
standard_readable_config[mock_signal1] = 42
standard_readable_config[mock_signal2] = 43
signals = list(standard_readable_config)
assert mock_signal1 in signals
assert mock_signal2 in signals


@pytest.mark.asyncio
async def test_readable_device_prepare(readable_device_config):
sr = StandardReadable()
mock = MagicMock()
sr.add_readables = mock
with sr.add_children_as_readables(ConfigSignal):
sr.a = SignalRW(name="a", backend=SoftSignalBackend(datatype=int))
sr.b = SignalRW(name="b", backend=SoftSignalBackend(datatype=float))
sr.c = SignalRW(name="c", backend=SoftSignalBackend(datatype=str))
sr.d = SignalRW(name="d", backend=SoftSignalBackend(datatype=bool))

readable_device_config.add_attribute("a", int, 42)
readable_device_config.add_attribute("b", float, 3.14)
readable_device_config.add_attribute("c", str, "hello")

await sr.prepare(readable_device_config)
assert await sr.a.get_value() == 42
assert await sr.b.get_value() == 3.14
assert await sr.c.get_value() == "hello"

readable_device_config.add_attribute("d", int, 1)
with pytest.raises(TypeError):
await sr.prepare(readable_device_config)
def test_config_length(standard_readable_config):
mock_signal1 = MagicMock(spec=SignalW)
mock_signal2 = MagicMock(spec=SignalW)
standard_readable_config[mock_signal1] = 42
standard_readable_config[mock_signal2] = 43
assert len(standard_readable_config) == 2


def test_get_config():
sr = StandardReadable()
@pytest.mark.asyncio
@pytest.mark.parametrize("name, type_, value", test_data)
async def test_config_prepare(standard_readable_config, name, type_, value):
readable = StandardReadable()
if type_ is np.ndarray:
readable.mock_signal1 = SignalRW(
name="mock_signal1",
backend=SoftSignalBackend(
datatype=type_, initial_value=np.ndarray([0, 0, 0])
),
)
else:
readable.mock_signal1 = SignalRW(
name="mock_signal1", backend=SoftSignalBackend(datatype=type_)
)

hinted = SignalRW(name="hinted", backend=SoftSignalBackend(datatype=int))
configurable = SignalRW(
name="configurable", backend=SoftSignalBackend(datatype=int)
)
normal = SignalRW(name="normal", backend=SoftSignalBackend(datatype=int))
readable.add_readables([readable.mock_signal1])

sr.add_readables([configurable], ConfigSignal)
sr.add_readables([hinted], HintedSignal)
sr.add_readables([normal])
config = PerSignalConfig()
config[readable.mock_signal1] = value

config = sr.get_config()
await readable.prepare(config)
val = await readable.mock_signal1.get_value()

# Check that configurable is in the config
assert config["configurable"][int] is None
with pytest.raises(AttributeError):
config["hinted"][int]
with pytest.raises(AttributeError):
config["normal"][int]
if type_ is np.ndarray:
assert np.array_equal(val, value)
else:
assert await readable.mock_signal1.get_value() == value

0 comments on commit 7b2350d

Please sign in to comment.