Skip to content

Commit

Permalink
Merge branch 'main' into merge
Browse files Browse the repository at this point in the history
  • Loading branch information
burkeds committed Sep 12, 2024
2 parents b5f9787 + 0cbfbe6 commit 528be15
Show file tree
Hide file tree
Showing 29 changed files with 828 additions and 299 deletions.
8 changes: 7 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,14 @@
soft_signal_rw,
wait_for_value,
)
from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum
from ._signal_backend import (
RuntimeSubsetEnum,
SignalBackend,
SubsetEnum,
)
from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
from ._table import Table
from ._utils import (
DEFAULT_TIMEOUT,
CalculatableTimeout,
Expand Down Expand Up @@ -152,6 +157,7 @@
"CalculateTimeout",
"NotConnected",
"ReadingValueCallback",
"Table",
"T",
"WatcherUpdate",
"get_dtype",
Expand Down
12 changes: 12 additions & 0 deletions src/ophyd_async/core/_device_save_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bluesky.plan_stubs import abs_set, wait
from bluesky.protocols import Location
from bluesky.utils import Msg
from pydantic import BaseModel

from ._device import Device
from ._signal import SignalRW
Expand All @@ -18,6 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No
)


def pydantic_model_abstraction_representer(
dumper: yaml.Dumper, model: BaseModel
) -> yaml.Node:
return dumper.represent_data(model.model_dump(mode="python"))


class OphydDumper(yaml.Dumper):
def represent_data(self, data: Any) -> Any:
if isinstance(data, Enum):
Expand Down Expand Up @@ -134,6 +141,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
"""

yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)
yaml.add_multi_representer(
BaseModel,
pydantic_model_abstraction_representer,
Dumper=yaml.Dumper,
)

with open(save_path, "w") as file:
yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)
Expand Down
8 changes: 4 additions & 4 deletions src/ophyd_async/core/_mock_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from functools import cached_property
from typing import Callable, Optional, Type
from unittest.mock import Mock
from unittest.mock import AsyncMock

from bluesky.protocols import Descriptor, Reading

Expand Down Expand Up @@ -46,8 +46,8 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:
pass

@cached_property
def put_mock(self) -> Mock:
return Mock(name="put", spec=Callable)
def put_mock(self) -> AsyncMock:
return AsyncMock(name="put", spec=Callable)

@cached_property
def put_proceeds(self) -> asyncio.Event:
Expand All @@ -56,7 +56,7 @@ def put_proceeds(self) -> asyncio.Event:
return put_proceeds

async def put(self, value: Optional[T], wait=True, timeout=None):
self.put_mock(value, wait=wait, timeout=timeout)
await self.put_mock(value, wait=wait, timeout=timeout)
await self.soft_backend.put(value, wait=wait, timeout=timeout)

if wait:
Expand Down
12 changes: 7 additions & 5 deletions src/ophyd_async/core/_mock_signal_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Callable, Iterable
from unittest.mock import Mock
from typing import Any, Awaitable, Callable, Iterable
from unittest.mock import AsyncMock

from ._mock_signal_backend import MockSignalBackend
from ._signal import Signal
Expand Down Expand Up @@ -41,7 +41,7 @@ async def mock_puts_blocked(*signals: Signal):
set_mock_put_proceeds(signal, True)


def get_mock_put(signal: Signal) -> Mock:
def get_mock_put(signal: Signal) -> AsyncMock:
"""Get the mock associated with the put call on the signal."""
return _get_mock_signal_backend(signal).put_mock

Expand Down Expand Up @@ -136,12 +136,14 @@ def set_mock_values(


@contextmanager
def _unset_side_effect_cm(put_mock: Mock):
def _unset_side_effect_cm(put_mock: AsyncMock):
yield
put_mock.side_effect = None


def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
def callback_on_mock_put(
signal: Signal[T], callback: Callable[[T], None] | Callable[[T], Awaitable[None]]
):
"""For setting a callback when a backend is put to.
Can either be used in a context, with the callback being
Expand Down
15 changes: 14 additions & 1 deletion src/ophyd_async/core/_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
ClassVar,
Generic,
Literal,
Optional,
Tuple,
Type,
)

from ._protocol import DataKey, Reading
from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
Expand All @@ -11,6 +19,11 @@ class SignalBackend(Generic[T]):
#: Datatype of the signal value
datatype: Optional[Type[T]] = None

@classmethod
@abstractmethod
def datatype_allowed(cls, dtype: type):
"""Check if a given datatype is acceptable for this signal backend."""

#: Like ca://PV_PREFIX:SIGNAL
@abstractmethod
def source(self, name: str) -> str:
Expand Down
32 changes: 30 additions & 2 deletions src/ophyd_async/core/_soft_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

import inspect
import time
from abc import ABCMeta
from collections import abc
from enum import Enum
from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin

import numpy as np
from bluesky.protocols import DataKey, Dtype, Reading
from pydantic import BaseModel
from typing_extensions import TypedDict

from ._signal_backend import RuntimeSubsetEnum, SignalBackend
from ._signal_backend import (
RuntimeSubsetEnum,
SignalBackend,
)
from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype

primitive_dtypes: Dict[type, Dtype] = {
Expand Down Expand Up @@ -94,7 +99,7 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
class SoftEnumConverter(SoftConverter):
choices: Tuple[str, ...]

def __init__(self, datatype: Union[RuntimeSubsetEnum, Enum]):
def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]):
if issubclass(datatype, Enum):
self.choices = tuple(v.value for v in datatype)
else:
Expand Down Expand Up @@ -122,17 +127,36 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
return cast(T, self.choices[0])


class SoftPydanticModelConverter(SoftConverter):
def __init__(self, datatype: Type[BaseModel]):
self.datatype = datatype

def write_value(self, value):
if isinstance(value, dict):
return self.datatype(**value)
return value


def make_converter(datatype):
is_array = get_dtype(datatype) is not None
is_sequence = get_origin(datatype) == abc.Sequence
is_enum = inspect.isclass(datatype) and (
issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum)
)

is_pydantic_model = (
inspect.isclass(datatype)
# Necessary to avoid weirdness in ABCMeta.__subclasscheck__
and isinstance(datatype, ABCMeta)
and issubclass(datatype, BaseModel)
)

if is_array or is_sequence:
return SoftArrayConverter()
if is_enum:
return SoftEnumConverter(datatype)
if is_pydantic_model:
return SoftPydanticModelConverter(datatype)

return SoftConverter()

Expand All @@ -145,6 +169,10 @@ class SoftSignalBackend(SignalBackend[T]):
_timestamp: float
_severity: int

@classmethod
def datatype_allowed(cls, datatype: Type) -> bool:
return True # Any value allowed in a soft signal

def __init__(
self,
datatype: Optional[Type[T]],
Expand Down
58 changes: 58 additions & 0 deletions src/ophyd_async/core/_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
from pydantic import BaseModel, ConfigDict, model_validator


class Table(BaseModel):
"""An abstraction of a Table of str to numpy array."""

model_config = ConfigDict(validate_assignment=True, strict=False)

@classmethod
def row(cls, sub_cls, **kwargs) -> "Table":
arrayified_kwargs = {
field_name: np.concatenate(
(
(default_arr := field_value.default_factory()),
np.array([kwargs[field_name]], dtype=default_arr.dtype),
)
)
for field_name, field_value in sub_cls.model_fields.items()
}
return sub_cls(**arrayified_kwargs)

def __add__(self, right: "Table") -> "Table":
"""Concatenate the arrays in field values."""

assert isinstance(right, type(self)), (
f"{right} is not a `Table`, or is not the same "
f"type of `Table` as {self}."
)

return type(self)(
**{
field_name: np.concatenate(
(getattr(self, field_name), getattr(right, field_name))
)
for field_name in self.model_fields
}
)

@model_validator(mode="after")
def validate_arrays(self) -> "Table":
first_length = len(next(iter(self))[1])
assert all(
len(field_value) == first_length for _, field_value in self
), "Rows should all be of equal size."

if not all(
np.issubdtype(
self.model_fields[field_name].default_factory().dtype, field_value.dtype
)
for field_name, field_value in self
):
raise ValueError(
f"Cannot construct a `{type(self).__name__}`, "
"some rows have incorrect types."
)

return self
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_dtype(typ: Type) -> Optional[np.dtype]:


def get_unique(values: Dict[str, T], types: str) -> T:
"""If all values are the same, return that value, otherwise return TypeError
"""If all values are the same, return that value, otherwise raise TypeError
>>> get_unique({"a": 1, "b": 1}, "integers")
1
Expand Down
9 changes: 5 additions & 4 deletions src/ophyd_async/epics/adsimdetector/_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ class SimDetector(StandardDetector):

def __init__(
self,
drv: adcore.ADBaseIO,
hdf: adcore.NDFileHDFIO,
prefix: str,
path_provider: PathProvider,
drv_suffix="cam1:",
hdf_suffix="HDF1:",
name: str = "",
config_sigs: Sequence[SignalR] = (),
):
self.drv = drv
self.hdf = hdf
self.drv = adcore.ADBaseIO(prefix + drv_suffix)
self.hdf = adcore.NDFileHDFIO(prefix + hdf_suffix)

super().__init__(
SimController(self.drv),
Expand Down
32 changes: 29 additions & 3 deletions src/ophyd_async/epics/signal/_aioca.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect
import logging
import sys
from dataclasses import dataclass
from enum import Enum
from math import isnan, nan
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin

import numpy as np
from aioca import (
Expand All @@ -24,6 +25,7 @@
DEFAULT_TIMEOUT,
NotConnected,
ReadingValueCallback,
RuntimeSubsetEnum,
SignalBackend,
T,
get_dtype,
Expand Down Expand Up @@ -211,7 +213,8 @@ def make_converter(
raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]")
return CaArrayConverter(pv_dbr, None)
elif pv_dbr == dbr.DBR_ENUM and datatype is bool:
# Database can't do bools, so are often representated as enums, CA can do int
# Database can't do bools, so are often representated as enums,
# CA can do int
pv_choices_len = get_unique(
{k: len(v.enums) for k, v in values.items()}, "number of choices"
)
Expand Down Expand Up @@ -240,7 +243,7 @@ def make_converter(
f"{pv} has type {type(value).__name__.replace('ca_', '')} "
+ f"not {datatype.__name__}"
)
return CaConverter(pv_dbr, None)
return CaConverter(pv_dbr, None)


_tried_pyepics = False
Expand All @@ -256,8 +259,31 @@ def _use_pyepics_context_if_imported():


class CaSignalBackend(SignalBackend[T]):
_ALLOWED_DATATYPES = (
bool,
int,
float,
str,
Sequence,
Enum,
RuntimeSubsetEnum,
np.ndarray,
)

@classmethod
def datatype_allowed(cls, datatype: Optional[Type]) -> bool:
stripped_origin = get_origin(datatype) or datatype
if datatype is None:
return True

return inspect.isclass(stripped_origin) and issubclass(
stripped_origin, cls._ALLOWED_DATATYPES
)

def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str):
self.datatype = datatype
if not CaSignalBackend.datatype_allowed(self.datatype):
raise TypeError(f"Given datatype {self.datatype} unsupported in CA.")
self.read_pv = read_pv
self.write_pv = write_pv
self.initial_values: Dict[str, AugmentedValue] = {}
Expand Down
Loading

0 comments on commit 528be15

Please sign in to comment.