From 957e0e10f189997641cfd12a7e13e8f897854f11 Mon Sep 17 00:00:00 2001 From: Devin Burke Date: Wed, 28 Aug 2024 15:12:04 +0200 Subject: [PATCH] Added get_config method to StandardReadable which returns a ReadableDeviceConfig instance containing entries for each signal in _configurables. --- src/ophyd_async/core/_readable.py | 13 +++++++++- src/ophyd_async/core/_readable_config.py | 23 ++++++++---------- tests/core/test_readable.py | 31 ++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index b8c05c1d68..94e6d37544 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -222,7 +222,8 @@ async def prepare(self, value: ReadableDeviceConfig) -> None: if isinstance(attr, (HintedSignal, ConfigSignal)): attr = attr.signal if attr._backend.datatype == expected_dtype: # noqa: SLF001 - tasks.append(attr.set(val)) + if val is not None: + tasks.append(attr.set(val)) else: raise TypeError( f"Expected value of type {expected_dtype} for attribute" @@ -231,6 +232,16 @@ async def prepare(self, value: ReadableDeviceConfig) -> None: ) await asyncio.gather(*tasks) + def get_config(self) -> ReadableDeviceConfig: + config = ReadableDeviceConfig() + for readable in self._configurables: + if isinstance(readable, Union[ConfigSignal, HintedSignal]): + readable = readable.signal + name = readable.name.split("-")[-1] + dtype = readable._backend.datatype # noqa: SLF001 + config.add_attribute(name, dtype) + return config + class ConfigSignal(AsyncConfigurable): def __init__(self, signal: ReadableChild) -> None: diff --git a/src/ophyd_async/core/_readable_config.py b/src/ophyd_async/core/_readable_config.py index 7ce05d51a6..687e8e3e55 100644 --- a/src/ophyd_async/core/_readable_config.py +++ b/src/ophyd_async/core/_readable_config.py @@ -4,7 +4,6 @@ import numpy as np -@dataclass @dataclass class ReadableDeviceConfig: int_signals: Dict[str, Tuple[int, Any]] = field(default_factory=dict) @@ -52,8 +51,7 @@ def __getitem__(self, dtype: Type[Any]) -> Any: if signals and self.key in signals: return signals[self.key][1] raise AttributeError( - f"'{self.parent.__class__.__name__}'" - f" object has no attribute" + f"'{self.parent.__class__.__name__}' object has no attribute" f" '{self.key}[{dtype.__name__}]'" ) @@ -61,34 +59,33 @@ def __setitem__(self, dtype: Type[Any], value: Any) -> None: signals = self.parent.signals.get(dtype) if signals and self.key in signals: expected_type, _ = signals[self.key] - if isinstance(value, expected_type): + if isinstance(value, expected_type) or value is None: signals[self.key] = (expected_type, value) else: raise TypeError( - f"Expected value of type {expected_type}" - f" for attribute '{self.key}', got {type(value)}" + f"Expected value of type {expected_type} for attribute" + f" '{self.key}', got {type(value)}" ) else: raise KeyError( - f"Key '{self.key}' not found" f" in {self.parent.attr_map[dtype]}" + f"Key '{self.key}' not found in {self.parent.attr_map[dtype]}" ) def __getattr__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor": return self.SignalAccessor(self, key) - def add_attribute(self, name: str, dtype: Type[Any], value: Any) -> None: - if not isinstance(value, dtype): + def add_attribute(self, name: str, dtype: Type[Any], value: Any = None) -> None: + if value is not None and not isinstance(value, dtype): raise TypeError( - f"Expected value of type {dtype} for attribute" - f" '{name}', got {type(value)}" + f"Expected value of type {dtype} for attribute '{name}'," + f" got {type(value)}" ) self.signals[dtype][name] = (dtype, value) def __setattr__(self, key: str, value: Any) -> None: if "attr_map" in self.__dict__: raise AttributeError( - f"Cannot set attribute '{key}'" - f" directly. Use 'add_attribute' method." + f"Cannot set attribute '{key}' directly. Use 'add_attribute' method." ) else: super().__setattr__(key, value) diff --git a/tests/core/test_readable.py b/tests/core/test_readable.py index 297bb06c5c..3f3d2a75bf 100644 --- a/tests/core/test_readable.py +++ b/tests/core/test_readable.py @@ -307,6 +307,14 @@ def test_invalid_type(readable_device_config, name, dtype, value): readable_device_config.add_attribute(name, dtype, "invalid_type") +@pytest.mark.parametrize("name,dtype,value", test_data) +def test_add_attribute_default_value(readable_device_config, name, dtype, value): + readable_device_config.add_attribute(name, dtype) + assert name in readable_device_config.signals[dtype] + # Check that the default value is of the correct type + assert readable_device_config.signals[dtype][name][1] is None + + @pytest.mark.asyncio async def test_readable_device_prepare(readable_device_config): sr = StandardReadable() @@ -330,3 +338,26 @@ async def test_readable_device_prepare(readable_device_config): readable_device_config.add_attribute("d", int, 1) with pytest.raises(TypeError): await sr.prepare(readable_device_config) + + +def test_get_config(): + sr = StandardReadable() + + hinted = SignalRW(name="hinted", backend=SoftSignalBackend(datatype=int)) + configurable = SignalRW( + name="configurable", backend=SoftSignalBackend(datatype=int) + ) + normal = SignalRW(name="normal", backend=SoftSignalBackend(datatype=int)) + + sr.add_readables([configurable], ConfigSignal) + sr.add_readables([hinted], HintedSignal) + sr.add_readables([normal]) + + config = sr.get_config() + + # Check that configurable is in the config + assert config["configurable"][int] is None + with pytest.raises(AttributeError): + config["hinted"][int] + with pytest.raises(AttributeError): + config["normal"][int]