Skip to content

Commit

Permalink
Added get_config method to StandardReadable which returns a ReadableD…
Browse files Browse the repository at this point in the history
…eviceConfig instance containing entries for each signal in _configurables.
  • Loading branch information
burkeds committed Aug 28, 2024
1 parent 135626f commit 957e0e1
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
13 changes: 12 additions & 1 deletion src/ophyd_async/core/_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ async def prepare(self, value: ReadableDeviceConfig) -> None:
if isinstance(attr, (HintedSignal, ConfigSignal)):
attr = attr.signal
if attr._backend.datatype == expected_dtype: # noqa: SLF001
tasks.append(attr.set(val))
if val is not None:
tasks.append(attr.set(val))
else:
raise TypeError(
f"Expected value of type {expected_dtype} for attribute"
Expand All @@ -231,6 +232,16 @@ async def prepare(self, value: ReadableDeviceConfig) -> None:
)
await asyncio.gather(*tasks)

def get_config(self) -> ReadableDeviceConfig:
config = ReadableDeviceConfig()
for readable in self._configurables:
if isinstance(readable, Union[ConfigSignal, HintedSignal]):
readable = readable.signal
name = readable.name.split("-")[-1]
dtype = readable._backend.datatype # noqa: SLF001
config.add_attribute(name, dtype)
return config


class ConfigSignal(AsyncConfigurable):
def __init__(self, signal: ReadableChild) -> None:
Expand Down
23 changes: 10 additions & 13 deletions src/ophyd_async/core/_readable_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np


@dataclass
@dataclass
class ReadableDeviceConfig:
int_signals: Dict[str, Tuple[int, Any]] = field(default_factory=dict)
Expand Down Expand Up @@ -52,43 +51,41 @@ def __getitem__(self, dtype: Type[Any]) -> Any:
if signals and self.key in signals:
return signals[self.key][1]
raise AttributeError(
f"'{self.parent.__class__.__name__}'"
f" object has no attribute"
f"'{self.parent.__class__.__name__}' object has no attribute"
f" '{self.key}[{dtype.__name__}]'"
)

def __setitem__(self, dtype: Type[Any], value: Any) -> None:
signals = self.parent.signals.get(dtype)
if signals and self.key in signals:
expected_type, _ = signals[self.key]
if isinstance(value, expected_type):
if isinstance(value, expected_type) or value is None:
signals[self.key] = (expected_type, value)
else:
raise TypeError(
f"Expected value of type {expected_type}"
f" for attribute '{self.key}', got {type(value)}"
f"Expected value of type {expected_type} for attribute"
f" '{self.key}', got {type(value)}"
)
else:
raise KeyError(
f"Key '{self.key}' not found" f" in {self.parent.attr_map[dtype]}"
f"Key '{self.key}' not found in {self.parent.attr_map[dtype]}"
)

def __getattr__(self, key: str) -> "ReadableDeviceConfig.SignalAccessor":
return self.SignalAccessor(self, key)

def add_attribute(self, name: str, dtype: Type[Any], value: Any) -> None:
if not isinstance(value, dtype):
def add_attribute(self, name: str, dtype: Type[Any], value: Any = None) -> None:
if value is not None and not isinstance(value, dtype):
raise TypeError(
f"Expected value of type {dtype} for attribute"
f" '{name}', got {type(value)}"
f"Expected value of type {dtype} for attribute '{name}',"
f" got {type(value)}"
)
self.signals[dtype][name] = (dtype, value)

def __setattr__(self, key: str, value: Any) -> None:
if "attr_map" in self.__dict__:
raise AttributeError(
f"Cannot set attribute '{key}'"
f" directly. Use 'add_attribute' method."
f"Cannot set attribute '{key}' directly. Use 'add_attribute' method."
)
else:
super().__setattr__(key, value)
Expand Down
31 changes: 31 additions & 0 deletions tests/core/test_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ def test_invalid_type(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype, "invalid_type")


@pytest.mark.parametrize("name,dtype,value", test_data)
def test_add_attribute_default_value(readable_device_config, name, dtype, value):
readable_device_config.add_attribute(name, dtype)
assert name in readable_device_config.signals[dtype]
# Check that the default value is of the correct type
assert readable_device_config.signals[dtype][name][1] is None


@pytest.mark.asyncio
async def test_readable_device_prepare(readable_device_config):
sr = StandardReadable()
Expand All @@ -330,3 +338,26 @@ async def test_readable_device_prepare(readable_device_config):
readable_device_config.add_attribute("d", int, 1)
with pytest.raises(TypeError):
await sr.prepare(readable_device_config)


def test_get_config():
sr = StandardReadable()

hinted = SignalRW(name="hinted", backend=SoftSignalBackend(datatype=int))
configurable = SignalRW(
name="configurable", backend=SoftSignalBackend(datatype=int)
)
normal = SignalRW(name="normal", backend=SoftSignalBackend(datatype=int))

sr.add_readables([configurable], ConfigSignal)
sr.add_readables([hinted], HintedSignal)
sr.add_readables([normal])

config = sr.get_config()

# Check that configurable is in the config
assert config["configurable"][int] is None
with pytest.raises(AttributeError):
config["hinted"][int]
with pytest.raises(AttributeError):
config["normal"][int]

0 comments on commit 957e0e1

Please sign in to comment.