Skip to content

Commit

Permalink
Make table subclass enums be sequence enum rather than numpy string (b…
Browse files Browse the repository at this point in the history
…luesky#579)


---------

Co-authored-by: Eva Lott <[email protected]>
Co-authored-by: Tom Cobb <[email protected]>
  • Loading branch information
3 people authored Sep 20, 2024
1 parent 1007da7 commit 8af94c9
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 105 deletions.
119 changes: 101 additions & 18 deletions src/ophyd_async/core/_table.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,125 @@
from typing import TypeVar
from enum import Enum
from typing import TypeVar, get_args, get_origin

import numpy as np
from pydantic import BaseModel, ConfigDict, model_validator

TableSubclass = TypeVar("TableSubclass", bound="Table")


def _concat(value1, value2):
if isinstance(value1, np.ndarray):
return np.concatenate((value1, value2))
else:
return value1 + value2


class Table(BaseModel):
"""An abstraction of a Table of str to numpy array."""

model_config = ConfigDict(validate_assignment=True, strict=False)

@staticmethod
def row(cls: type[TableSubclass], **kwargs) -> TableSubclass: # type: ignore
arrayified_kwargs = {
field_name: np.concatenate(
(
(default_arr := field_value.default_factory()), # type: ignore
np.array([kwargs[field_name]], dtype=default_arr.dtype),
arrayified_kwargs = {}
for field_name, field_value in cls.model_fields.items():
value = kwargs.pop(field_name)
if field_value.default_factory is None:
raise ValueError(
"`Table` models should have default factories for their "
"mutable empty columns."
)
default_array = field_value.default_factory()
if isinstance(default_array, np.ndarray):
arrayified_kwargs[field_name] = np.array(
[value], dtype=default_array.dtype
)
elif issubclass(type(value), Enum) and isinstance(value, str):
arrayified_kwargs[field_name] = [value]
else:
raise TypeError(
"Row column should be numpy arrays or sequence of string `Enum`."
)
if kwargs:
raise TypeError(
f"Unexpected keyword arguments {kwargs.keys()} for {cls.__name__}."
)
for field_name, field_value in cls.model_fields.items()
}
return cls(**arrayified_kwargs)

def __add__(self, right: TableSubclass) -> TableSubclass:
"""Concatenate the arrays in field values."""

assert type(right) is type(self), (
f"{right} is not a `Table`, or is not the same "
f"type of `Table` as {self}."
)
if type(right) is not type(self):
raise RuntimeError(
f"{right} is not a `Table`, or is not the same "
f"type of `Table` as {self}."
)

return type(right)(
**{
field_name: np.concatenate(
(getattr(self, field_name), getattr(right, field_name))
field_name: _concat(
getattr(self, field_name), getattr(right, field_name)
)
for field_name in self.model_fields
}
)

def numpy_dtype(self) -> np.dtype:
dtype = []
for field_name, field_value in self.model_fields.items():
if np.ndarray in (
get_origin(field_value.annotation),
field_value.annotation,
):
dtype.append((field_name, getattr(self, field_name).dtype))
else:
enum_type = get_args(field_value.annotation)[0]
assert issubclass(enum_type, Enum)
enum_values = [element.value for element in enum_type]
max_length_in_enum = max(len(value) for value in enum_values)
dtype.append((field_name, np.dtype(f"<U{max_length_in_enum}")))

return np.dtype(dtype)

def numpy_table(self):
# It would be nice to be able to use np.transpose for this,
# but it defaults to the largest dtype for everything.
dtype = self.numpy_dtype()
transposed_list = [
np.array(tuple(row), dtype=dtype)
for row in zip(*self.numpy_columns(), strict=False)
]
transposed = np.array(transposed_list, dtype=dtype)
return transposed

def numpy_columns(self) -> list[np.ndarray]:
"""Columns in the table can be lists of string enums or numpy arrays.
This method returns the columns, converting the string enums to numpy arrays.
"""

columns = []
for field_name, field_value in self.model_fields.items():
if np.ndarray in (
get_origin(field_value.annotation),
field_value.annotation,
):
columns.append(getattr(self, field_name))
else:
enum_type = get_args(field_value.annotation)[0]
assert issubclass(enum_type, Enum)
enum_values = [element.value for element in enum_type]
max_length_in_enum = max(len(value) for value in enum_values)
dtype = np.dtype(f"<U{max_length_in_enum}")

columns.append(
np.array(
[enum.value for enum in getattr(self, field_name)], dtype=dtype
)
)

return columns

@model_validator(mode="after")
def validate_arrays(self) -> "Table":
first_length = len(next(iter(self))[1])
Expand All @@ -49,11 +128,15 @@ def validate_arrays(self) -> "Table":
), "Rows should all be of equal size."

if not all(
np.issubdtype(
self.model_fields[field_name].default_factory().dtype, # type: ignore
field_value.dtype,
# Checks if the values are numpy subtypes if the array is a numpy array,
# or if the value is a string enum.
np.issubdtype(getattr(self, field_name).dtype, default_array.dtype)
if isinstance(
default_array := self.model_fields[field_name].default_factory(), # type: ignore
np.ndarray,
)
for field_name, field_value in self
else issubclass(get_args(field_value.annotation)[0], Enum)
for field_name, field_value in self.model_fields.items()
):
raise ValueError(
f"Cannot construct a `{type(self).__name__}`, "
Expand Down
40 changes: 3 additions & 37 deletions src/ophyd_async/fastcs/panda/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import numpy.typing as npt
from pydantic import Field, field_validator, model_validator
from pydantic import Field, model_validator
from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
from typing_extensions import TypedDict

Expand Down Expand Up @@ -51,13 +51,7 @@ class SeqTrigger(str, Enum):
),
Field(default_factory=lambda: np.array([], dtype=np.bool_)),
]
TriggerStr = Annotated[
np.ndarray[tuple[int], np.dtype[np.unicode_]],
NpArrayPydanticAnnotation.factory(
data_type=np.unicode_, dimensions=1, strict_data_typing=False
),
Field(default_factory=lambda: np.array([], dtype=np.dtype("<U32"))),
]
TriggerStr = Annotated[Sequence[SeqTrigger], Field(default_factory=list)]


class SeqTable(Table):
Expand Down Expand Up @@ -101,35 +95,7 @@ def row( # type: ignore
oute2: bool = False,
outf2: bool = False,
) -> "SeqTable":
if isinstance(trigger, SeqTrigger):
trigger = trigger.value
return super().row(**locals())

@field_validator("trigger", mode="before")
@classmethod
def trigger_to_np_array(cls, trigger_column):
"""
The user can provide a list of SeqTrigger enum elements instead of a numpy str.
"""
if isinstance(trigger_column, Sequence) and all(
isinstance(trigger, SeqTrigger) for trigger in trigger_column
):
trigger_column = np.array(
[trigger.value for trigger in trigger_column], dtype=np.dtype("<U32")
)
elif isinstance(trigger_column, Sequence) or isinstance(
trigger_column, np.ndarray
):
for trigger in trigger_column:
SeqTrigger(
trigger
) # To check all the given strings are actually `SeqTrigger`s
else:
raise ValueError(
"Expected a numpy array or a sequence of `SeqTrigger`, got "
f"{type(trigger_column)}."
)
return trigger_column
return Table.row(**locals())

@model_validator(mode="after")
def validate_max_length(self) -> "SeqTable":
Expand Down
Loading

0 comments on commit 8af94c9

Please sign in to comment.