diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py
index 6b9c6ccac3..1928c7aba4 100644
--- a/src/ophyd_async/core/__init__.py
+++ b/src/ophyd_async/core/__init__.py
@@ -61,9 +61,14 @@
     soft_signal_rw,
     wait_for_value,
 )
-from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum
+from ._signal_backend import (
+    RuntimeSubsetEnum,
+    SignalBackend,
+    SubsetEnum,
+)
 from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
 from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
+from ._table import Table
 from ._utils import (
     DEFAULT_TIMEOUT,
     CalculatableTimeout,
@@ -152,6 +157,7 @@
     "CalculateTimeout",
     "NotConnected",
     "ReadingValueCallback",
+    "Table",
     "T",
     "WatcherUpdate",
     "get_dtype",
diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py
index 5b81228264..d847caff69 100644
--- a/src/ophyd_async/core/_device_save_loader.py
+++ b/src/ophyd_async/core/_device_save_loader.py
@@ -7,6 +7,7 @@
 from bluesky.plan_stubs import abs_set, wait
 from bluesky.protocols import Location
 from bluesky.utils import Msg
+from pydantic import BaseModel
 
 from ._device import Device
 from ._signal import SignalRW
@@ -18,6 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No
     )
 
 
+def pydantic_model_abstraction_representer(
+    dumper: yaml.Dumper, model: BaseModel
+) -> yaml.Node:
+    return dumper.represent_data(model.model_dump(mode="python"))
+
+
 class OphydDumper(yaml.Dumper):
     def represent_data(self, data: Any) -> Any:
         if isinstance(data, Enum):
@@ -134,6 +141,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
     """
     yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)
+    yaml.add_multi_representer(
+        BaseModel,
+        pydantic_model_abstraction_representer,
+        Dumper=yaml.Dumper,
+    )
 
     with open(save_path, "w") as file:
         yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)
diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py
index 221645fe0b..97aae72a39 100644
--- a/src/ophyd_async/core/_mock_signal_backend.py
+++ b/src/ophyd_async/core/_mock_signal_backend.py
@@ -1,7 +1,7 @@
 import asyncio
 from functools import cached_property
 from typing import Callable, Optional, Type
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from bluesky.protocols import Descriptor, Reading
 
@@ -46,8 +46,8 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:
         pass
 
     @cached_property
-    def put_mock(self) -> Mock:
-        return Mock(name="put", spec=Callable)
+    def put_mock(self) -> AsyncMock:
+        return AsyncMock(name="put", spec=Callable)
 
     @cached_property
     def put_proceeds(self) -> asyncio.Event:
@@ -56,7 +56,7 @@ def put_proceeds(self) -> asyncio.Event:
         return put_proceeds
 
     async def put(self, value: Optional[T], wait=True, timeout=None):
-        self.put_mock(value, wait=wait, timeout=timeout)
+        await self.put_mock(value, wait=wait, timeout=timeout)
        await self.soft_backend.put(value, wait=wait, timeout=timeout)
 
         if wait:
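For reference, a minimal sketch (not part of the patch) of what the `Mock` to `AsyncMock` change enables in tests; the signal name is invented, and the call pattern mirrors the tests later in this diff:

```python
import asyncio

from ophyd_async.core import SignalRW, SoftSignalBackend, get_mock_put


async def main():
    sig = SignalRW(SoftSignalBackend(int), name="sig")  # name is arbitrary
    await sig.connect(mock=True)
    await sig.set(42, wait=True, timeout=1)
    # put_mock is now an AsyncMock, so awaited-call assertions work alongside
    # the usual call assertions.
    get_mock_put(sig).assert_awaited_once_with(42, wait=True, timeout=1)


asyncio.run(main())
```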
diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py
index 767ebfa125..76d1a04c12 100644
--- a/src/ophyd_async/core/_mock_signal_utils.py
+++ b/src/ophyd_async/core/_mock_signal_utils.py
@@ -1,6 +1,6 @@
 from contextlib import asynccontextmanager, contextmanager
-from typing import Any, Callable, Iterable
-from unittest.mock import Mock
+from typing import Any, Awaitable, Callable, Iterable
+from unittest.mock import AsyncMock
 
 from ._mock_signal_backend import MockSignalBackend
 from ._signal import Signal
@@ -41,7 +41,7 @@ async def mock_puts_blocked(*signals: Signal):
         set_mock_put_proceeds(signal, True)
 
 
-def get_mock_put(signal: Signal) -> Mock:
+def get_mock_put(signal: Signal) -> AsyncMock:
     """Get the mock associated with the put call on the signal."""
     return _get_mock_signal_backend(signal).put_mock
 
@@ -136,12 +136,14 @@ def set_mock_values(
 
 
 @contextmanager
-def _unset_side_effect_cm(put_mock: Mock):
+def _unset_side_effect_cm(put_mock: AsyncMock):
     yield
     put_mock.side_effect = None
 
 
-def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
+def callback_on_mock_put(
+    signal: Signal[T], callback: Callable[[T], None] | Callable[[T], Awaitable[None]]
+):
     """For setting a callback when a backend is put to.
 
     Can either be used in a context, with the callback being
diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py
index 41e9fbcbd3..594863ef2a 100644
--- a/src/ophyd_async/core/_signal_backend.py
+++ b/src/ophyd_async/core/_signal_backend.py
@@ -1,5 +1,13 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Generic,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+)
 
 from ._protocol import DataKey, Reading
 from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
@@ -11,6 +19,11 @@ class SignalBackend(Generic[T]):
     #: Datatype of the signal value
     datatype: Optional[Type[T]] = None
 
+    @classmethod
+    @abstractmethod
+    def datatype_allowed(cls, dtype: type):
+        """Check if a given datatype is acceptable for this signal backend."""
+
     #: Like ca://PV_PREFIX:SIGNAL
     @abstractmethod
     def source(self, name: str) -> str:
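For context, a hedged sketch of how a downstream backend might satisfy the new abstract classmethod; `MyTransportBackend` and its allow-list are hypothetical, standing in for a real `SignalBackend[T]` subclass:

```python
from typing import Optional, Type

import numpy as np


class MyTransportBackend:  # hypothetical SignalBackend[T] subclass
    _ALLOWED_DATATYPES = (bool, int, float, str, np.ndarray)

    @classmethod
    def datatype_allowed(cls, datatype: Optional[Type]) -> bool:
        # None means "infer the type from the control system", so accept it
        if datatype is None:
            return True
        return isinstance(datatype, type) and issubclass(
            datatype, cls._ALLOWED_DATATYPES
        )
```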
-    def __init__(self, datatype: Union[RuntimeSubsetEnum, Enum]):
+    def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]):
         if issubclass(datatype, Enum):
             self.choices = tuple(v.value for v in datatype)
         else:
@@ -122,6 +127,16 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
         return cast(T, self.choices[0])
 
 
+class SoftPydanticModelConverter(SoftConverter):
+    def __init__(self, datatype: Type[BaseModel]):
+        self.datatype = datatype
+
+    def write_value(self, value):
+        if isinstance(value, dict):
+            return self.datatype(**value)
+        return value
+
+
 def make_converter(datatype):
     is_array = get_dtype(datatype) is not None
     is_sequence = get_origin(datatype) == abc.Sequence
@@ -129,10 +144,19 @@ def make_converter(datatype):
         issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum)
     )
+    is_pydantic_model = (
+        inspect.isclass(datatype)
+        # Necessary to avoid weirdness in ABCMeta.__subclasscheck__
+        and isinstance(datatype, ABCMeta)
+        and issubclass(datatype, BaseModel)
+    )
 
     if is_array or is_sequence:
         return SoftArrayConverter()
     if is_enum:
         return SoftEnumConverter(datatype)
+    if is_pydantic_model:
+        return SoftPydanticModelConverter(datatype)
     return SoftConverter()
 
 
@@ -145,6 +169,10 @@ class SoftSignalBackend(SignalBackend[T]):
     _timestamp: float
     _severity: int
 
+    @classmethod
+    def datatype_allowed(cls, datatype: Type) -> bool:
+        return True  # Any value allowed in a soft signal
+
     def __init__(
         self,
         datatype: Optional[Type[T]],
diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py
new file mode 100644
index 0000000000..bdb619a3b9
--- /dev/null
+++ b/src/ophyd_async/core/_table.py
@@ -0,0 +1,58 @@
+import numpy as np
+from pydantic import BaseModel, ConfigDict, model_validator
+
+
+class Table(BaseModel):
+    """An abstraction of a Table of str to numpy array."""
+
+    model_config = ConfigDict(validate_assignment=True, strict=False)
+
+    @classmethod
+    def row(cls, sub_cls, **kwargs) -> "Table":
+        arrayified_kwargs = {
+            field_name: np.concatenate(
+                (
+                    (default_arr := field_value.default_factory()),
+                    np.array([kwargs[field_name]], dtype=default_arr.dtype),
+                )
+            )
+            for field_name, field_value in sub_cls.model_fields.items()
+        }
+        return sub_cls(**arrayified_kwargs)
+
+    def __add__(self, right: "Table") -> "Table":
+        """Concatenate the arrays in field values."""
+
+        assert isinstance(right, type(self)), (
+            f"{right} is not a `Table`, or is not the same "
+            f"type of `Table` as {self}."
+        )
+
+        return type(self)(
+            **{
+                field_name: np.concatenate(
+                    (getattr(self, field_name), getattr(right, field_name))
+                )
+                for field_name in self.model_fields
+            }
+        )
+
+    @model_validator(mode="after")
+    def validate_arrays(self) -> "Table":
+        first_length = len(next(iter(self))[1])
+        assert all(
+            len(field_value) == first_length for _, field_value in self
+        ), "Rows should all be of equal size."
+
+        if not all(
+            np.issubdtype(
+                self.model_fields[field_name].default_factory().dtype, field_value.dtype
+            )
+            for field_name, field_value in self
+        ):
+            raise ValueError(
+                f"Cannot construct a `{type(self).__name__}`, "
+                "some rows have incorrect types."
+            )
+
+        return self
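To illustrate the new base class, a sketch of a two-column subclass; `XYTable` and `Int32Col` are invented names, but the `pydantic_numpy` annotation pattern is the same one `SeqTable` uses later in this diff:

```python
from typing import Annotated

import numpy as np
from pydantic import Field
from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation

from ophyd_async.core import Table

Int32Col = Annotated[
    np.ndarray[tuple[int], np.int32],
    NpArrayPydanticAnnotation.factory(
        data_type=np.int32, dimensions=1, strict_data_typing=False
    ),
    Field(default_factory=lambda: np.array([], dtype=np.int32)),
]


class XYTable(Table):
    x: Int32Col
    y: Int32Col


# Table.row appends one element per column; __add__ concatenates column-wise.
t = Table.row(XYTable, x=1, y=2) + Table.row(XYTable, x=3, y=4)
assert t.x.tolist() == [1, 3] and t.y.tolist() == [2, 4]
```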
diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py
index f5098ce717..d081ed008f 100644
--- a/src/ophyd_async/core/_utils.py
+++ b/src/ophyd_async/core/_utils.py
@@ -145,7 +145,7 @@ def get_dtype(typ: Type) -> Optional[np.dtype]:
 
 
 def get_unique(values: Dict[str, T], types: str) -> T:
-    """If all values are the same, return that value, otherwise return TypeError
+    """If all values are the same, return that value, otherwise raise TypeError
 
     >>> get_unique({"a": 1, "b": 1}, "integers")
     1
diff --git a/src/ophyd_async/epics/adsimdetector/_sim.py b/src/ophyd_async/epics/adsimdetector/_sim.py
index b69937705f..c007c72ffc 100644
--- a/src/ophyd_async/epics/adsimdetector/_sim.py
+++ b/src/ophyd_async/epics/adsimdetector/_sim.py
@@ -12,14 +12,15 @@ class SimDetector(StandardDetector):
 
     def __init__(
         self,
-        drv: adcore.ADBaseIO,
-        hdf: adcore.NDFileHDFIO,
+        prefix: str,
         path_provider: PathProvider,
+        drv_suffix="cam1:",
+        hdf_suffix="HDF1:",
         name: str = "",
         config_sigs: Sequence[SignalR] = (),
     ):
-        self.drv = drv
-        self.hdf = hdf
+        self.drv = adcore.ADBaseIO(prefix + drv_suffix)
+        self.hdf = adcore.NDFileHDFIO(prefix + hdf_suffix)
 
         super().__init__(
             SimController(self.drv),
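Usage sketch of the new constructor (the PV prefix and device name are placeholders): callers now pass one prefix instead of prebuilt `drv`/`hdf` IO objects.

```python
from ophyd_async.core import DeviceCollector, StaticFilenameProvider, StaticPathProvider
from ophyd_async.epics import adsimdetector


async def make_det(tmp_path):
    dp = StaticPathProvider(StaticFilenameProvider("data"), tmp_path)
    async with DeviceCollector(mock=True):
        # drv_suffix and hdf_suffix default to "cam1:" and "HDF1:"
        det = adsimdetector.SimDetector("BL01T-DI-CAM-01:", dp, name="det")
    return det
```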
diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py
index 78052d448d..ef8a5693e2 100644
--- a/src/ophyd_async/epics/signal/_aioca.py
+++ b/src/ophyd_async/epics/signal/_aioca.py
@@ -1,9 +1,10 @@
+import inspect
 import logging
 import sys
 from dataclasses import dataclass
 from enum import Enum
 from math import isnan, nan
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin
 
 import numpy as np
 from aioca import (
@@ -24,6 +25,7 @@
     DEFAULT_TIMEOUT,
     NotConnected,
     ReadingValueCallback,
+    RuntimeSubsetEnum,
     SignalBackend,
     T,
     get_dtype,
@@ -211,7 +213,8 @@ def make_converter(
             raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]")
         return CaArrayConverter(pv_dbr, None)
     elif pv_dbr == dbr.DBR_ENUM and datatype is bool:
-        # Database can't do bools, so are often representated as enums, CA can do int
+        # Database can't do bools, so are often representated as enums,
+        # CA can do int
         pv_choices_len = get_unique(
             {k: len(v.enums) for k, v in values.items()}, "number of choices"
         )
@@ -240,7 +243,7 @@ def make_converter(
                 f"{pv} has type {type(value).__name__.replace('ca_', '')} "
                 + f"not {datatype.__name__}"
             )
-        return CaConverter(pv_dbr, None)
+    return CaConverter(pv_dbr, None)
 
 
 _tried_pyepics = False
@@ -256,8 +259,31 @@ def _use_pyepics_context_if_imported():
 
 
 class CaSignalBackend(SignalBackend[T]):
+    _ALLOWED_DATATYPES = (
+        bool,
+        int,
+        float,
+        str,
+        Sequence,
+        Enum,
+        RuntimeSubsetEnum,
+        np.ndarray,
+    )
+
+    @classmethod
+    def datatype_allowed(cls, datatype: Optional[Type]) -> bool:
+        stripped_origin = get_origin(datatype) or datatype
+        if datatype is None:
+            return True
+
+        return inspect.isclass(stripped_origin) and issubclass(
+            stripped_origin, cls._ALLOWED_DATATYPES
+        )
+
     def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str):
         self.datatype = datatype
+        if not CaSignalBackend.datatype_allowed(self.datatype):
+            raise TypeError(f"Given datatype {self.datatype} unsupported in CA.")
         self.read_pv = read_pv
         self.write_pv = write_pv
         self.initial_values: Dict[str, AugmentedValue] = {}
diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py
index 28ec8fe6ab..c7d0b5240d 100644
--- a/src/ophyd_async/epics/signal/_p4p.py
+++ b/src/ophyd_async/epics/signal/_p4p.py
@@ -3,14 +3,17 @@
 import inspect
 import logging
 import time
+from abc import ABCMeta
 from dataclasses import dataclass
 from enum import Enum
 from math import isnan, nan
-from typing import Any, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin
 
+import numpy as np
 from bluesky.protocols import DataKey, Dtype, Reading
 from p4p import Value
 from p4p.client.asyncio import Context, Subscription
+from pydantic import BaseModel
 
 from ophyd_async.core import (
     DEFAULT_TIMEOUT,
@@ -253,6 +256,19 @@ def get_datakey(self, source: str, value) -> DataKey:
         return _data_key_from_value(source, value, dtype="object")
 
 
+class PvaPydanticModelConverter(PvaConverter):
+    def __init__(self, datatype: BaseModel):
+        self.datatype = datatype
+
+    def value(self, value: Value):
+        return self.datatype(**value.todict())
+
+    def write_value(self, value: Union[BaseModel, Dict[str, Any]]):
+        if isinstance(value, self.datatype):
+            return value.model_dump(mode="python")
+        return value
+
+
 class PvaDictConverter(PvaConverter):
     def reading(self, value):
         ts = time.time()
@@ -348,6 +364,15 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve
             raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}")
         return PvaConverter()
     elif "NTTable" in typeid:
+        if (
+            datatype
+            and inspect.isclass(datatype)
+            and
+            # Necessary to avoid weirdness in ABCMeta.__subclasscheck__
+            isinstance(datatype, ABCMeta)
+            and issubclass(datatype, BaseModel)
+        ):
+            return PvaPydanticModelConverter(datatype)
         return PvaTableConverter()
     elif "structure" in typeid:
         return PvaDictConverter()
@@ -358,8 +383,33 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve
 
 class PvaSignalBackend(SignalBackend[T]):
     _ctxt: Optional[Context] = None
 
+    _ALLOWED_DATATYPES = (
+        bool,
+        int,
+        float,
+        str,
+        Sequence,
+        np.ndarray,
+        Enum,
+        RuntimeSubsetEnum,
+        BaseModel,
+        dict,
+    )
+
+    @classmethod
+    def datatype_allowed(cls, datatype: Optional[Type]) -> bool:
+        stripped_origin = get_origin(datatype) or datatype
+        if datatype is None:
+            return True
+        return inspect.isclass(stripped_origin) and issubclass(
+            stripped_origin, cls._ALLOWED_DATATYPES
+        )
+
     def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str):
         self.datatype = datatype
+        if not PvaSignalBackend.datatype_allowed(self.datatype):
+            raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.")
+
         self.read_pv = read_pv
         self.write_pv = write_pv
         self.initial_values: Dict[str, Any] = {}
diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py
index 9d1c1d429f..0dbe7222b0 100644
--- a/src/ophyd_async/fastcs/panda/__init__.py
+++ b/src/ophyd_async/fastcs/panda/__init__.py
@@ -15,10 +15,7 @@
     DatasetTable,
     PandaHdf5DatasetType,
     SeqTable,
-    SeqTableRow,
     SeqTrigger,
-    seq_table_from_arrays,
-    seq_table_from_rows,
 )
 from ._trigger import (
     PcompInfo,
@@ -45,10 +42,7 @@
     "DatasetTable",
     "PandaHdf5DatasetType",
     "SeqTable",
-    "SeqTableRow",
     "SeqTrigger",
-    "seq_table_from_arrays",
-    "seq_table_from_rows",
     "PcompInfo",
     "SeqTableInfo",
     "StaticPcompTriggerLogic",
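A rough illustration of the new fail-fast datatype gate (PV names are placeholders): `BaseModel` subclasses are accepted by the PVA backend but rejected by CA at construction time.

```python
from pydantic import BaseModel

from ophyd_async.epics.signal import epics_signal_rw


class MyModel(BaseModel):  # hypothetical table-like datatype
    some_field: int = 0


# Allowed: BaseModel is in PvaSignalBackend._ALLOWED_DATATYPES
sig = epics_signal_rw(MyModel, "pva://DEVICE:TABLE")

try:
    epics_signal_rw(MyModel, "ca://DEVICE:TABLE")  # not in the CA allow-list
except TypeError as err:
    print(err)  # Given datatype <class '...MyModel'> unsupported in CA.
```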
diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py
index ec2c1a5b8b..ee6df7522f 100644
--- a/src/ophyd_async/fastcs/panda/_table.py
+++ b/src/ophyd_async/fastcs/panda/_table.py
@@ -1,11 +1,14 @@
-from dataclasses import dataclass
+import inspect
 from enum import Enum
-from typing import Optional, Sequence, Type, TypeVar
+from typing import Annotated, Sequence
 
 import numpy as np
 import numpy.typing as npt
-import pydantic_numpy.typing as pnd
-from typing_extensions import NotRequired, TypedDict
+from pydantic import Field, field_validator, model_validator
+from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
+from typing_extensions import TypedDict
+
+from ophyd_async.core import Table
 
 
 class PandaHdf5DatasetType(str, Enum):
@@ -34,137 +37,113 @@ class SeqTrigger(str, Enum):
     POSC_LT = "POSC<=POSITION"
 
 
-@dataclass
-class SeqTableRow:
-    repeats: int = 1
-    trigger: SeqTrigger = SeqTrigger.IMMEDIATE
-    position: int = 0
-    time1: int = 0
-    outa1: bool = False
-    outb1: bool = False
-    outc1: bool = False
-    outd1: bool = False
-    oute1: bool = False
-    outf1: bool = False
-    time2: int = 0
-    outa2: bool = False
-    outb2: bool = False
-    outc2: bool = False
-    outd2: bool = False
-    oute2: bool = False
-    outf2: bool = False
-
-
-class SeqTable(TypedDict):
-    repeats: NotRequired[pnd.Np1DArrayUint16]
-    trigger: NotRequired[Sequence[SeqTrigger]]
-    position: NotRequired[pnd.Np1DArrayInt32]
-    time1: NotRequired[pnd.Np1DArrayUint32]
-    outa1: NotRequired[pnd.Np1DArrayBool]
-    outb1: NotRequired[pnd.Np1DArrayBool]
-    outc1: NotRequired[pnd.Np1DArrayBool]
-    outd1: NotRequired[pnd.Np1DArrayBool]
-    oute1: NotRequired[pnd.Np1DArrayBool]
-    outf1: NotRequired[pnd.Np1DArrayBool]
-    time2: NotRequired[pnd.Np1DArrayUint32]
-    outa2: NotRequired[pnd.Np1DArrayBool]
-    outb2: NotRequired[pnd.Np1DArrayBool]
-    outc2: NotRequired[pnd.Np1DArrayBool]
-    outd2: NotRequired[pnd.Np1DArrayBool]
-    oute2: NotRequired[pnd.Np1DArrayBool]
-    outf2: NotRequired[pnd.Np1DArrayBool]
-
-
-def seq_table_from_rows(*rows: SeqTableRow):
-    """
-    Constructs a sequence table from a series of rows.
- """ - return seq_table_from_arrays( - repeats=np.array([row.repeats for row in rows], dtype=np.uint16), - trigger=[row.trigger for row in rows], - position=np.array([row.position for row in rows], dtype=np.int32), - time1=np.array([row.time1 for row in rows], dtype=np.uint32), - outa1=np.array([row.outa1 for row in rows], dtype=np.bool_), - outb1=np.array([row.outb1 for row in rows], dtype=np.bool_), - outc1=np.array([row.outc1 for row in rows], dtype=np.bool_), - outd1=np.array([row.outd1 for row in rows], dtype=np.bool_), - oute1=np.array([row.oute1 for row in rows], dtype=np.bool_), - outf1=np.array([row.outf1 for row in rows], dtype=np.bool_), - time2=np.array([row.time2 for row in rows], dtype=np.uint32), - outa2=np.array([row.outa2 for row in rows], dtype=np.bool_), - outb2=np.array([row.outb2 for row in rows], dtype=np.bool_), - outc2=np.array([row.outc2 for row in rows], dtype=np.bool_), - outd2=np.array([row.outd2 for row in rows], dtype=np.bool_), - oute2=np.array([row.oute2 for row in rows], dtype=np.bool_), - outf2=np.array([row.outf2 for row in rows], dtype=np.bool_), - ) - - -T = TypeVar("T", bound=np.generic) - - -def seq_table_from_arrays( - *, - repeats: Optional[npt.NDArray[np.uint16]] = None, - trigger: Optional[Sequence[SeqTrigger]] = None, - position: Optional[npt.NDArray[np.int32]] = None, - time1: Optional[npt.NDArray[np.uint32]] = None, - outa1: Optional[npt.NDArray[np.bool_]] = None, - outb1: Optional[npt.NDArray[np.bool_]] = None, - outc1: Optional[npt.NDArray[np.bool_]] = None, - outd1: Optional[npt.NDArray[np.bool_]] = None, - oute1: Optional[npt.NDArray[np.bool_]] = None, - outf1: Optional[npt.NDArray[np.bool_]] = None, - time2: npt.NDArray[np.uint32], - outa2: Optional[npt.NDArray[np.bool_]] = None, - outb2: Optional[npt.NDArray[np.bool_]] = None, - outc2: Optional[npt.NDArray[np.bool_]] = None, - outd2: Optional[npt.NDArray[np.bool_]] = None, - oute2: Optional[npt.NDArray[np.bool_]] = None, - outf2: Optional[npt.NDArray[np.bool_]] = None, -) -> SeqTable: - """ - Constructs a sequence table from a series of columns as arrays. - time2 is the only required argument and must not be None. - All other provided arguments must be of equal length to time2. 
-    If any other argument is not given, or else given as None or empty,
-    an array of length len(time2) filled with the following is defaulted:
-    repeats: 1
-    trigger: SeqTrigger.IMMEDIATE
-    all others: 0/False as appropriate
-    """
-    assert time2 is not None, "time2 must be provided"
-    length = len(time2)
-    assert 0 < length < 4096, f"Length {length} not in range"
-
-    def or_default(
-        value: Optional[npt.NDArray[T]], dtype: Type[T], default_value: int = 0
-    ) -> npt.NDArray[T]:
-        if value is None or len(value) == 0:
-            return np.full(length, default_value, dtype=dtype)
-        return value
-
-    table = SeqTable(
-        repeats=or_default(repeats, np.uint16, 1),
-        trigger=trigger or [SeqTrigger.IMMEDIATE] * length,
-        position=or_default(position, np.int32),
-        time1=or_default(time1, np.uint32),
-        outa1=or_default(outa1, np.bool_),
-        outb1=or_default(outb1, np.bool_),
-        outc1=or_default(outc1, np.bool_),
-        outd1=or_default(outd1, np.bool_),
-        oute1=or_default(oute1, np.bool_),
-        outf1=or_default(outf1, np.bool_),
-        time2=time2,
-        outa2=or_default(outa2, np.bool_),
-        outb2=or_default(outb2, np.bool_),
-        outc2=or_default(outc2, np.bool_),
-        outd2=or_default(outd2, np.bool_),
-        oute2=or_default(oute2, np.bool_),
-        outf2=or_default(outf2, np.bool_),
-    )
-    for k, v in table.items():
-        size = len(v)  # type: ignore
-        if size != length:
-            raise ValueError(f"{k}: has length {size} not {length}")
-    return table
+PydanticNp1DArrayInt32 = Annotated[
+    np.ndarray[tuple[int], np.int32],
+    NpArrayPydanticAnnotation.factory(
+        data_type=np.int32, dimensions=1, strict_data_typing=False
+    ),
+    Field(default_factory=lambda: np.array([], np.int32)),
+]
+PydanticNp1DArrayBool = Annotated[
+    np.ndarray[tuple[int], np.bool_],
+    NpArrayPydanticAnnotation.factory(
+        data_type=np.bool_, dimensions=1, strict_data_typing=False
+    ),
+    Field(default_factory=lambda: np.array([], dtype=np.bool_)),
+]
+TriggerStr = Annotated[
+    np.ndarray[tuple[int], np.unicode_],
+    NpArrayPydanticAnnotation.factory(
+        data_type=np.unicode_, dimensions=1, strict_data_typing=False
+    ),
+    Field(default_factory=lambda: np.array([], dtype=np.dtype("<U32"))),
+]
+
+
+class SeqTable(Table):
+    repeats: PydanticNp1DArrayInt32
+    trigger: TriggerStr
+    position: PydanticNp1DArrayInt32
+    time1: PydanticNp1DArrayInt32
+    outa1: PydanticNp1DArrayBool
+    outb1: PydanticNp1DArrayBool
+    outc1: PydanticNp1DArrayBool
+    outd1: PydanticNp1DArrayBool
+    oute1: PydanticNp1DArrayBool
+    outf1: PydanticNp1DArrayBool
+    time2: PydanticNp1DArrayInt32
+    outa2: PydanticNp1DArrayBool
+    outb2: PydanticNp1DArrayBool
+    outc2: PydanticNp1DArrayBool
+    outd2: PydanticNp1DArrayBool
+    oute2: PydanticNp1DArrayBool
+    outf2: PydanticNp1DArrayBool
+
+    @classmethod
+    def row(  # type: ignore
+        cls,
+        *,
+        repeats: int = 1,
+        trigger: str = SeqTrigger.IMMEDIATE,
+        position: int = 0,
+        time1: int = 0,
+        outa1: bool = False,
+        outb1: bool = False,
+        outc1: bool = False,
+        outd1: bool = False,
+        oute1: bool = False,
+        outf1: bool = False,
+        time2: int = 0,
+        outa2: bool = False,
+        outb2: bool = False,
+        outc2: bool = False,
+        outd2: bool = False,
+        oute2: bool = False,
+        outf2: bool = False,
+    ) -> "SeqTable":
+        sig = inspect.signature(cls.row)
+        kwargs = {k: v for k, v in locals().items() if k in sig.parameters}
+
+        if isinstance(kwargs["trigger"], SeqTrigger):
+            kwargs["trigger"] = kwargs["trigger"].value
+        elif isinstance(kwargs["trigger"], str):
+            SeqTrigger(kwargs["trigger"])
+
+        return Table.row(cls, **kwargs)
+
+    @field_validator("trigger", mode="before")
+    @classmethod
+    def trigger_to_np_array(cls, trigger_column):
+        """
+        The user can provide a list of SeqTrigger enum elements instead of a numpy str.
+        """
+        if isinstance(trigger_column, Sequence) and all(
+            isinstance(trigger, SeqTrigger) for trigger in trigger_column
+        ):
+            trigger_column = np.array(
+                [trigger.value for trigger in trigger_column], dtype=np.dtype("<U32")
+            )
+        return trigger_column
+
+    @model_validator(mode="after")
+    def validate_max_length(self) -> "SeqTable":
+        """
+        Used to check max_length. Unfortunately trying the `max_length` arg in
+        the pydantic field doesn't work
+        """
+
+        first_length = len(next(iter(self))[1])
+        assert 0 <= first_length < 4096, f"Length {first_length} not in range."
+        return self
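Migration sketch for the removed helpers: what `seq_table_from_rows(SeqTableRow(...), ...)` looks like with the new API (field values below are arbitrary).

```python
from ophyd_async.fastcs.panda import SeqTable

table = (
    SeqTable.row(time1=100, time2=50, outa2=True)
    + SeqTable.row(repeats=5, time1=200, outa1=True)
)
assert table.repeats.tolist() == [1, 5]
assert table.trigger.tolist() == ["Immediate", "Immediate"]  # default trigger
```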
diff --git a/src/ophyd_async/fastcs/panda/_writer.py b/src/ophyd_async/fastcs/panda/_writer.py
index cf4c6c514a..65ca186fe3 100644
--- a/src/ophyd_async/fastcs/panda/_writer.py
+++ b/src/ophyd_async/fastcs/panda/_writer.py
@@ -106,6 +106,16 @@ async def _update_datasets(self) -> None:
             for dataset_name in capture_table["name"]
         ]
 
+        # Warn user if dataset table is empty in PandA,
+        # i.e. no stream resources will be generated
+        if len(self._datasets) == 0:
+            self.panda_data_block.log.warning(
+                f"PandA {self._name_provider()} DATASETS table is empty! "
+                "No stream resource docs will be generated. "
+                "Make sure captured positions have their corresponding "
+                "*:DATASET PV set to a scientifically relevant name."
+            )
+
     # Next few functions are exactly the same as AD writer. Could move as default
     # StandardDetector behavior
     async def wait_for_index(
diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py
index 087ec62dd1..daa686b477 100644
--- a/src/ophyd_async/plan_stubs/_fly.py
+++ b/src/ophyd_async/plan_stubs/_fly.py
@@ -15,8 +15,6 @@
     PcompInfo,
     SeqTable,
     SeqTableInfo,
-    SeqTableRow,
-    seq_table_from_rows,
 )
 
 
@@ -74,24 +72,26 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger(
     trigger_time = number_of_frames * (exposure + deadtime)
     pre_delay = max(period - 2 * shutter_time - trigger_time, 0)
 
-    table: SeqTable = seq_table_from_rows(
+    table = (
         # Wait for pre-delay then open shutter
-        SeqTableRow(
+        SeqTable.row(
             time1=in_micros(pre_delay),
             time2=in_micros(shutter_time),
             outa2=True,
-        ),
+        )
+        +
         # Keeping shutter open, do N triggers
-        SeqTableRow(
+        SeqTable.row(
             repeats=number_of_frames,
             time1=in_micros(exposure),
             outa1=True,
             outb1=True,
             time2=in_micros(deadtime),
             outa2=True,
-        ),
+        )
+        +
         # Add the shutter close
-        SeqTableRow(time2=in_micros(shutter_time)),
+        SeqTable.row(time2=in_micros(shutter_time))
     )
 
     table_info = SeqTableInfo(sequence_table=table, repeats=repeats)
diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py
index aa60be9802..b265b86137 100644
--- a/tests/core/test_device_save_loader.py
+++ b/tests/core/test_device_save_loader.py
@@ -8,6 +8,8 @@
 import pytest
 import yaml
 from bluesky.run_engine import RunEngine
+from pydantic import BaseModel, Field
+from pydantic_numpy.typing import NpNDArrayFp16, NpNDArrayInt32
 
 from ophyd_async.core import (
     Device,
@@ -54,6 +56,16 @@ class MyEnum(str, Enum):
     three = "three"
 
 
+class SomePvaPydanticModel(BaseModel):
+    some_int_field: int = Field(default=1)
+    some_pydantic_numpy_field_float: NpNDArrayFp16 = Field(
+        default_factory=lambda: np.array([1, 2, 3])
+    )
+    some_pydantic_numpy_field_int: NpNDArrayInt32 = Field(
+        default_factory=lambda: np.array([1, 2, 3])
+    )
+
+
 class DummyDeviceGroupAllTypes(Device):
     def __init__(self, name: str):
         self.pv_int: SignalRW = epics_signal_rw(int, "PV1")
@@ -73,6 +85,9 @@ def __init__(self, name: str):
         self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14")
         self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15")
         self.pv_array_str = epics_signal_rw(Sequence[str], "PV16")
+        self.pv_protocol_device_abstraction = epics_signal_rw(
+            SomePvaPydanticModel, "pva://PV17"
+        )
 
 
 @pytest.fixture
@@ -155,6 +170,7 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path):
     await device_all_types.pv_array_str.set(
         ["one", "two", "three"],
     )
+    await device_all_types.pv_protocol_device_abstraction.set(SomePvaPydanticModel())
 
     # Create save plan from utility functions
     def save_my_device():
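A quick demonstration (not from the patch) of what the new multi-representer buys: any `BaseModel` value in a saved phase round-trips through YAML as a plain mapping. The model and file name here are invented.

```python
import yaml
from pydantic import BaseModel, Field

from ophyd_async.core import save_to_yaml


class Model(BaseModel):  # invented stand-in for a device's table datatype
    some_int_field: int = 1
    some_list_field: list = Field(default_factory=lambda: [1, 2, 3])


save_to_yaml([{"pv_int": 1, "pv_model": Model()}], "phase.yaml")
with open("phase.yaml") as f:
    assert yaml.safe_load(f) == [
        {"pv_int": 1, "pv_model": {"some_int_field": 1, "some_list_field": [1, 2, 3]}}
    ]
```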
diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py
index 5aa1b03613..00c11a2708 100644
--- a/tests/core/test_mock_signal_backend.py
+++ b/tests/core/test_mock_signal_backend.py
@@ -1,7 +1,7 @@
 import asyncio
 import re
 from itertools import repeat
-from unittest.mock import ANY, MagicMock, call
+from unittest.mock import ANY, AsyncMock, MagicMock, call
 
 import pytest
 
@@ -65,6 +65,7 @@ async def test_set_mock_value():
     mock_signal = SignalRW(SoftSignalBackend(int))
     await mock_signal.connect(mock=True)
     assert await mock_signal.get_value() == 0
+    assert mock_signal._backend
     assert await mock_signal._backend.get_value() == 0
     set_mock_value(mock_signal, 10)
     assert await mock_signal.get_value() == 10
@@ -200,7 +201,7 @@ async def test_blocks_during_put(mock_signals):
     assert await signal2._backend.get_value() == "second_value"
 
 
-async def test_callback_on_mock_put_ctxt(mock_signals):
+async def test_callback_on_mock_put_as_context_manager(mock_signals):
     signal1_callbacks = MagicMock()
     signal2_callbacks = MagicMock()
     signal1, signal2 = mock_signals
@@ -213,7 +214,7 @@ async def test_callback_on_mock_put_ctxt(
     signal2_callbacks.assert_called_once_with("second_value", wait=True, timeout=1)
 
 
-async def test_callback_on_mock_put_no_ctx():
+async def test_callback_on_mock_put_not_as_context_manager():
     mock_signal = SignalRW(SoftSignalBackend(float))
     await mock_signal.connect(mock=True)
     calls = []
@@ -230,6 +231,19 @@
     ]
 
 
+async def test_async_callback_on_mock_put(mock_signals):
+    signal1_callbacks = AsyncMock()
+    signal2_callbacks = AsyncMock()
+    signal1, signal2 = mock_signals
+    with callback_on_mock_put(signal1, signal1_callbacks):
+        await signal1.set("second_value", wait=True, timeout=1)
+    with callback_on_mock_put(signal2, signal2_callbacks):
+        await signal2.set("second_value", wait=True, timeout=1)
+
+    signal1_callbacks.assert_awaited_once_with("second_value", wait=True, timeout=1)
+    signal2_callbacks.assert_awaited_once_with("second_value", wait=True, timeout=1)
+
+
 async def test_callback_on_mock_put_fails_if_args_are_not_correct():
     mock_signal = SignalRW(SoftSignalBackend(float))
     await mock_signal.connect(mock=True)
@@ -392,4 +406,4 @@ async def test_when_put_mock_called_with_typo_then_fails_but_calling_directly_pa
     mock = mock_signal._backend.put_mock
     with pytest.raises(AttributeError):
         mock.asssert_called_once()  # Note typo here is deliberate!
-    mock()
+    await mock()
diff --git a/tests/core/test_protocol.py b/tests/core/test_protocol.py
index d71c4cce09..637d287213 100644
--- a/tests/core/test_protocol.py
+++ b/tests/core/test_protocol.py
@@ -9,7 +9,7 @@
     StaticFilenameProvider,
     StaticPathProvider,
 )
-from ophyd_async.epics import adcore, adsimdetector
+from ophyd_async.epics import adsimdetector
 from ophyd_async.sim.demo import SimMotor
 
 
@@ -18,11 +18,9 @@ async def make_detector(prefix: str, name: str, tmp_path: Path):
     dp = StaticPathProvider(fp, tmp_path)
 
     async with DeviceCollector(mock=True):
-        drv = adcore.ADBaseIO(f"{prefix}DRV:")
-        hdf = adcore.NDFileHDFIO(f"{prefix}HDF:")
-        det = adsimdetector.SimDetector(
-            drv, hdf, dp, config_sigs=[drv.acquire_time, drv.acquire], name=name
-        )
+        det = adsimdetector.SimDetector(prefix, dp, name=name)
+
+    det._config_sigs = [det.drv.acquire_time, det.drv.acquire]
 
     return det
diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py
index 3b4c4934f4..ab5c02cffe 100644
--- a/tests/core/test_signal.py
+++ b/tests/core/test_signal.py
@@ -403,3 +403,28 @@ async def test_subscription_logs(caplog):
     assert "Making subscription" in caplog.text
     mock_signal_rw.clear_sub(cbs.append)
     assert "Closing subscription on source" in caplog.text
+
+
+async def test_signal_unknown_datatype():
+    class SomeClass:
+        def __init__(self):
+            self.some_attribute = "some_attribute"
+
+        def some_function(self):
+            pass
+
+    err_str = (
+        "Given datatype <class '...SomeClass'>"
+        " unsupported in %s."
+    )
+    with pytest.raises(TypeError, match=err_str % ("PVA",)):
+        epics_signal_rw(SomeClass, "pva://mock_signal", name="mock_signal")
+    with pytest.raises(TypeError, match=err_str % ("CA",)):
+        epics_signal_rw(SomeClass, "ca://mock_signal", name="mock_signal")
+
+    # Any dtype allowed in soft signal
+    signal = soft_signal_rw(SomeClass, SomeClass(), "soft_signal")
+    assert isinstance((await signal.get_value()), SomeClass)
+    await signal.set(1)
+    assert (await signal.get_value()) == 1
diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py
index 5e55507626..16bf23567e 100644
--- a/tests/core/test_soft_signal_backend.py
+++ b/tests/core/test_soft_signal_backend.py
@@ -94,7 +94,7 @@ async def test_soft_signal_backend_get_put_monitor(
     descriptor: Callable[[Any], dict],
     dtype_numpy: str,
 ):
-    backend = SoftSignalBackend(datatype)
+    backend = SoftSignalBackend(datatype=datatype)
 
     await backend.connect()
     q = MonitorQueue(backend)
diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py
index 41af248aac..8c638d2770 100644
--- a/tests/core/test_subset_enum.py
+++ b/tests/core/test_subset_enum.py
@@ -7,8 +7,8 @@
 from ophyd_async.epics.signal import epics_signal_rw
 
 # Allow these imports from private modules for tests
-from ophyd_async.epics.signal._aioca import make_converter as aioca_make_converter
-from ophyd_async.epics.signal._p4p import make_converter as p4p_make_converter
+from ophyd_async.epics.signal._aioca import make_converter as ca_make_converter
+from ophyd_async.epics.signal._p4p import make_converter as pva_make_converter
 
 
 async def test_runtime_enum_behaviour():
@@ -52,7 +52,7 @@ def __init__(self):
     epics_value = EpicsValue()
     rt_enum = SubsetEnum["A", "B"]
-    converter = aioca_make_converter(
+    converter = ca_make_converter(
         rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value}
     )
     assert converter.choices == {"A": "A", "B": "B", "C": "C"}
@@ -68,7 +68,7 @@ async def test_pva_runtime_enum_converter():
         },
     )
     rt_enum = SubsetEnum["A", "B"]
-    converter = p4p_make_converter(
+    converter = pva_make_converter(
         rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value}
     )
     assert {"A", "B"}.issubset(set(converter.choices))
diff --git a/tests/epics/adsimdetector/test_adsim_controller.py b/tests/epics/adsimdetector/test_adsim_controller.py
index b2610fe31b..8a7c33516b 100644
--- a/tests/epics/adsimdetector/test_adsim_controller.py
+++ b/tests/epics/adsimdetector/test_adsim_controller.py
@@ -1,5 +1,3 @@
-from unittest.mock import patch
-
 import pytest
 
 from ophyd_async.core import DeviceCollector
@@ -16,15 +14,13 @@ async def ad(RE) -> adsimdetector.SimController:
 
 
 async def test_ad_controller(RE, ad: adsimdetector.SimController):
-    with patch("ophyd_async.core._signal.wait_for_value", return_value=None):
-        await ad.arm(num=1)
+    await ad.arm(num=1)
 
     driver = ad.driver
     assert await driver.num_images.get_value() == 1
     assert await driver.image_mode.get_value() == adcore.ImageMode.multiple
     assert await driver.acquire.get_value() is True
 
-    with patch("ophyd_async.epics.adcore._utils.wait_for_value", return_value=None):
-        await ad.disarm()
+    await ad.disarm()
 
     assert await driver.acquire.get_value() is False
diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py
index 17494e1bb5..149835aa49 100644
--- a/tests/epics/adsimdetector/test_sim.py
+++ b/tests/epics/adsimdetector/test_sim.py
@@ -31,16 +31,13 @@ async def make_detector(prefix: str, name: str, tmp_path: Path):
     dp = StaticPathProvider(fp, tmp_path)
 
     async with DeviceCollector(mock=True):
-        drv = adcore.ADBaseIO(f"{prefix}DRV:", name="drv")
-        hdf = adcore.NDFileHDFIO(f"{prefix}HDF:")
-        det = adsimdetector.SimDetector(
-            drv, hdf, dp, config_sigs=[drv.acquire_time, drv.acquire], name=name
-        )
+        det = adsimdetector.SimDetector(prefix, dp, name=name)
+        det._config_sigs = [det.drv.acquire_time, det.drv.acquire]
 
         def _set_full_file_name(val, *args, **kwargs):
-            set_mock_value(hdf.full_file_name, str(tmp_path / val))
+            set_mock_value(det.hdf.full_file_name, str(tmp_path / val))
 
-        callback_on_mock_put(hdf.file_name, _set_full_file_name)
+        callback_on_mock_put(det.hdf.file_name, _set_full_file_name)
 
     return det
 
@@ -60,7 +57,7 @@ def count_sim(dets: List[StandardDetector], times: int = 1):
         for det in dets:
             yield from bps.trigger(det, wait=False, group="wait_for_trigger")
 
-        yield from bps.sleep(0.1)
+        yield from bps.sleep(0.2)
 
         [
             set_mock_value(
                 cast(adcore.ADHDFWriter, det.writer).hdf.num_captured,
@@ -284,13 +281,13 @@ async def test_read_and_describe_detector(single_detector: StandardDetector):
     read = await single_detector.read_configuration()
     assert describe == {
         "test-drv-acquire_time": {
-            "source": "mock+ca://TEST:DRV:AcquireTime_RBV",
+            "source": "mock+ca://TEST:cam1:AcquireTime_RBV",
             "dtype": "number",
             "dtype_numpy": "<f8",
[...]
diff --git a/tests/fastcs/panda/test_writer.py b/tests/fastcs/panda/test_writer.py
[...]
@@ ... @@ ... -> PandaHDFWriter:
 
 
 @pytest.mark.parametrize("table", TABLES)
 async def test_open_returns_correct_descriptors(
-    mock_writer: PandaHDFWriter, table: DatasetTable
+    mock_writer: PandaHDFWriter, table: DatasetTable, caplog
 ):
     assert hasattr(mock_writer, "panda_data_block")
     set_mock_value(
         mock_writer.panda_data_block.datasets,
         table,
     )
-    description = await mock_writer.open()  # to make capturing status not time out
+
+    with caplog.at_level(logging.WARNING):
+        description = await mock_writer.open()  # to make capturing status not time out
+
+    # Check if empty datasets table leads to warning log message
+    if len(table["name"]) == 0:
+        assert "DATASETS table is empty!" in caplog.text
 
     for key, entry, expected_key in zip(
         description.keys(), description.values(), table["name"]
diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml
index fc3e1ebd95..17dc9e61a9 100644
--- a/tests/test_data/test_yaml_save.yml
+++ b/tests/test_data/test_yaml_save.yml
@@ -19,4 +19,8 @@
   pv_enum_str: two
   pv_float: 1.234
   pv_int: 1
+  pv_protocol_device_abstraction:
+    some_int_field: 1
+    some_pydantic_numpy_field_float: [1, 2, 3]
+    some_pydantic_numpy_field_int: [1, 2, 3]
   pv_str: test_string