diff --git a/atef/bin/check.py b/atef/bin/check.py index ea9038e8..3b38ef93 100644 --- a/atef/bin/check.py +++ b/atef/bin/check.py @@ -3,6 +3,7 @@ """ import argparse +import asyncio import dataclasses import logging import pathlib @@ -14,6 +15,7 @@ import rich.console import rich.tree +from ..cache import DataCache, _SignalCache, get_signal_cache from ..check import Comparison, Result, Severity from ..config import (AnyConfiguration, Configuration, ConfigurationFile, PathItem, PreparedComparison) @@ -52,6 +54,12 @@ def build_arg_parser(argparser=None): help="Limit checkout to the named device(s) or identifiers", ) + argparser.add_argument( + "-p", "--parallel", + action="store_true", + help="Acquire data for comparisons in parallel", + ) + return argparser @@ -256,18 +264,55 @@ def log_results_rich( console.print(root) -def check_and_log( +async def check_and_log( config: AnyConfiguration, console: rich.console.Console, verbose: int = 0, client: Optional[happi.Client] = None, name_filter: Optional[Sequence[str]] = None, + parallel: bool = True, + cache: Optional[DataCache] = None, ): - """Check a configuration and log the results.""" + """ + Check a configuration and log the results. + + Parameters + ---------- + config : AnyConfiguration + The configuration to check. + console : rich.console.Console + The rich console to write output to. + verbose : int, optional + The verbosity level for the output. + client : happi.Client, optional + The happi client, if available. + name_filter : Sequence[str], optional + A filter for names. + parallel : bool, optional + Pre-fill cache in parallel when possible. + cache : DataCache + The data cache instance. + """ items = [] name_filter = list(name_filter or []) severities = [] - for prepared in PreparedComparison.from_config(config, client=client): + + if cache is None: + cache = DataCache() + + all_prepared = list( + PreparedComparison.from_config(config, client=client, cache=cache) + ) + + cache_fill_tasks = [] + if parallel: + for prepared in all_prepared: + if isinstance(prepared, PreparedComparison): + cache_fill_tasks.append( + asyncio.create_task(prepared.get_data_async()) + ) + + for prepared in all_prepared: if isinstance(prepared, PreparedComparison): if name_filter: device_name = getattr(prepared.device, "name", None) @@ -285,7 +330,7 @@ def check_and_log( ) continue - prepared.result = prepared.compare() + prepared.result = await prepared.compare() if prepared.result is not None: items.append(prepared) severities.append(prepared.result.severity) @@ -312,12 +357,14 @@ def check_and_log( ) -def main( +async def main( filename: str, name_filter: Optional[Sequence[str]] = None, verbose: int = 0, + parallel: bool = False, *, - cleanup: bool = True + cleanup: bool = True, + signal_cache: Optional[_SignalCache] = None, ): path = pathlib.Path(filename) if path.suffix.lower() == ".json": @@ -333,15 +380,18 @@ def main( client = None console = rich.console.Console() + cache = DataCache(signals=signal_cache or get_signal_cache()) try: with console.status("[bold green] Performing checks..."): for config in config_file.configs: - check_and_log( + await check_and_log( config, console=console, verbose=verbose, client=client, name_filter=name_filter, + parallel=parallel, + cache=cache, ) finally: if cleanup: diff --git a/atef/bin/main.py b/atef/bin/main.py index 049620ae..23dd5eda 100644 --- a/atef/bin/main.py +++ b/atef/bin/main.py @@ -6,8 +6,10 @@ """ import argparse +import asyncio import importlib import logging +from inspect import 
iscoroutinefunction import atef @@ -91,7 +93,10 @@ def main(): if hasattr(args, 'func'): func = kwargs.pop('func') logger.debug('%s(**%r)', func.__name__, kwargs) - func(**kwargs) + if iscoroutinefunction(func): + asyncio.run(func(**kwargs)) + else: + func(**kwargs) else: top_parser.print_help() diff --git a/atef/cache.py b/atef/cache.py index 934f0502..2cf6c736 100644 --- a/atef/cache.py +++ b/atef/cache.py @@ -1,26 +1,51 @@ from __future__ import annotations +import asyncio +import concurrent.futures +import dataclasses import logging +import typing from dataclasses import dataclass, field -from typing import Dict, Mapping, TypeVar +from typing import (Any, Dict, Hashable, Iterable, Mapping, Optional, Type, + TypeVar, cast) import ophyd -_CacheSignalType = TypeVar("_CacheSignalType") +from .reduce import ReduceMethod, get_data_for_signal_async +from .type_hints import Number + +if typing.TYPE_CHECKING: + from . import tools + +_CacheSignalType = TypeVar("_CacheSignalType", bound=ophyd.Signal) logger = logging.getLogger(__name__) +_signal_cache: Optional[_SignalCache[ophyd.EpicsSignalRO]] = None + + +def get_signal_cache() -> _SignalCache[ophyd.EpicsSignalRO]: + """Get the global EpicsSignal cache.""" + global _signal_cache + if _signal_cache is None: + _signal_cache = _SignalCache[ophyd.EpicsSignalRO](ophyd.EpicsSignalRO) + return _signal_cache @dataclass class _SignalCache(Mapping[str, _CacheSignalType]): - signal_type_cls: _CacheSignalType + signal_type_cls: Type[ophyd.EpicsSignalRO] pv_to_signal: Dict[str, _CacheSignalType] = field(default_factory=dict) def __getitem__(self, pv: str) -> _CacheSignalType: """Get a PV from the cache.""" if pv not in self.pv_to_signal: - self.pv_to_signal[pv] = self.signal_type_cls(pv, name=pv) + signal = cast( + _CacheSignalType, + self.signal_type_cls(pv, name=pv) + ) + self.pv_to_signal[pv] = signal + return signal return self.pv_to_signal[pv] @@ -40,12 +65,276 @@ def clear(self) -> None: self.pv_to_signal.clear() -_signal_cache = None +@dataclass(frozen=True, eq=True) +class DataKey: + period: Optional[Number] = None + method: ReduceMethod = ReduceMethod.average + string: bool = False -def get_signal_cache() -> _SignalCache[ophyd.EpicsSignalRO]: - """Get the global EpicsSignal cache.""" - global _signal_cache - if _signal_cache is None: - _signal_cache = _SignalCache(ophyd.EpicsSignalRO) - return _signal_cache +@dataclass(frozen=True, eq=True) +class ToolKey: + tool_cls: Type[tools.Tool] + settings: Optional[Hashable] + + @classmethod + def from_tool(cls, tool: tools.Tool) -> ToolKey: + settings = cast(Hashable, _freeze(dataclasses.asdict(tool))) + return cls( + tool_cls=type(tool), + settings=settings, + ) + + +def _freeze(data): + """ + Freeze ``data`` such that it can be used as a hashable key. + + Parameters + ---------- + data : Any + The data to be frozen. + + Returns + ------- + Any + Hopefully hashable version of ``data``. 
+ """ + if isinstance(data, str): + return data + if isinstance(data, Mapping): + return frozenset( + (_freeze(key), _freeze(value)) + for key, value in data.items() + ) + if isinstance(data, Iterable): + return tuple(_freeze(part) for part in data) + return data + + +@dataclass +class DataCache: + signal_data: Dict[ophyd.Signal, Dict[DataKey, Any]] = field(default_factory=dict) + signals: _SignalCache[ophyd.EpicsSignalRO] = field( + default_factory=get_signal_cache + ) + tool_data: Dict[ToolKey, Any] = field( + default_factory=dict + ) + + def clear(self) -> None: + """Clear the data cache.""" + for data in list(self.signal_data.values()): + data.clear() + self.tool_data.clear() + + async def get_pv_data( + self, + pv: str, + reduce_period: Optional[Number] = None, + reduce_method: ReduceMethod = ReduceMethod.average, + string: bool = False, + executor: Optional[concurrent.futures.Executor] = None, + ) -> Optional[Any]: + """ + Get EPICS PV data with the provided data reduction settings. + + Utilizes cached data if already available. Multiple calls + with the same cache key will be batched. + + Parameters + ---------- + pv : str + The PV name. + reduce_period : float, optional + Period over which the comparison will occur, where multiple samples may + be acquired prior to a result being available. + reduce_method : ReduceMethod, optional + Reduce collected samples by this reduce method. Ignored if + reduce_period unset. + string : bool, optional + If applicable, request and compare string values rather than the + default specified. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. + """ + return await self.get_signal_data( + self.signals[pv], + reduce_period=reduce_period, + reduce_method=reduce_method, + string=string, + executor=executor, + ) + + async def get_signal_data( + self, + signal: ophyd.Signal, + reduce_period: Optional[Number] = None, + reduce_method: ReduceMethod = ReduceMethod.average, + string: bool = False, + executor: Optional[concurrent.futures.Executor] = None, + ) -> Optional[Any]: + """ + Get signal data with the provided data reduction settings. + + Utilizes cached data if already available. Multiple calls + with the same cache key will be batched. + + Parameters + ---------- + signal : ophyd.Signal + The signal to retrieve data from. + reduce_period : float, optional + Period over which the comparison will occur, where multiple samples may + be acquired prior to a result being available. + reduce_method : ReduceMethod, optional + Reduce collected samples by this reduce method. Ignored if + reduce_period unset. + string : bool, optional + If applicable, request and compare string values rather than the + default specified. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. 
+ """ + key = DataKey(period=reduce_period, method=reduce_method, string=string) + signal_data = self.signal_data.setdefault(signal, {}) + try: + data = signal_data[key] + except KeyError: + data = asyncio.create_task( + self._update_signal_data_by_key(signal, key, executor=executor) + ) + signal_data[key] = data + + if isinstance(data, asyncio.Future): + return await data + return data + + async def _update_signal_data_by_key( + self, + signal: ophyd.Signal, + key: DataKey, + executor: Optional[concurrent.futures.Executor] = None, + ) -> Any: + """ + Update the signal data cache given the signal and the reduction key. + + Parameters + ---------- + signal : ophyd.Signal + The signal to update. + key : DataKey + The data key corresponding to the acquisition settings. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. + """ + signal_data = self.signal_data[signal] + try: + acquired = await asyncio.shield( + get_data_for_signal_async( + signal, + reduce_period=key.period, + reduce_method=key.method, + string=key.string, + executor=executor, + ) + ) + except TimeoutError: + acquired = None + + signal_data[key] = acquired + return acquired + + async def get_tool_data( + self, + tool: tools.Tool, + ) -> Optional[Any]: + """ + Get tool data. + + Utilizes cached data if already available. Multiple calls + with the same cache key will be batched. + + Parameters + ---------- + tool : tools.Tool + The tool to run. + + Returns + ------- + Any + The acquired data. + """ + try: + key = ToolKey.from_tool(tool) + except Exception: + # Unhashable for some reason: we need to fix `_freeze`. Re-run + # the tool on demand and don't cache its results. + logger.warning( + "Internal issue with tool: %s. Caching mechanism " + "unavailable so performance may suffer.", + tool, + ) + logger.debug("Tool cache key failure: %s", tool, exc_info=True) + return await tool.run() + + try: + data = self.tool_data[key] + except KeyError: + data = asyncio.create_task(self._update_tool_by_key(tool, key)) + self.tool_data[key] = data + + if isinstance(data, asyncio.Future): + return await data + + return data + + async def _update_tool_by_key( + self, + tool: tools.Tool, + key: ToolKey, + executor: Optional[concurrent.futures.Executor] = None, + ) -> Any: + """ + Update the tool cache given the signal and the reduction key. + + Parameters + ---------- + tool : tools.Tool + The tool to run. + key : ToolKey + The hashable key according to the tool's configuration. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. 
+ """ + try: + acquired = await tool.run() + except Exception: + acquired = None + + self.tool_data[key] = acquired + return acquired diff --git a/atef/check.py b/atef/check.py index 683742ee..c4942485 100644 --- a/atef/check.py +++ b/atef/check.py @@ -1,8 +1,9 @@ from __future__ import annotations +import concurrent.futures import logging from dataclasses import dataclass, field -from typing import Any, Generator, List, Optional, Sequence +from typing import Any, Generator, Iterable, List, Optional, Sequence import numpy as np import ophyd @@ -216,7 +217,6 @@ def __str__(self) -> str: f"{self.__class__.__name__}.describe() failure " f"({ex.__class__.__name__}: {ex})" ) - # return f"{self.__class__.__name__}({desc})" def compare(self, value: Any, identifier: Optional[str] = None) -> Result: """ @@ -258,6 +258,11 @@ def compare(self, value: Any, identifier: Optional[str] = None) -> Result: if self.invert: passed = not passed + # Some comparisons may be done with array values; require that + # all match for a success here: + if isinstance(passed, Iterable): + passed = all(passed) + if passed: return success @@ -273,51 +278,64 @@ def get_data_for_signal(self, signal: ophyd.Signal) -> Any: """ Get data for the given signal, according to the string and data reduction settings. - """ - if self.reduce_period and self.reduce_period > 0: - return self.reduce_method.subscribe_and_reduce( - signal, self.reduce_period - ) - if self.string: - return signal.get(as_string=True) + Parameters + ---------- + signal : ophyd.Signal + The signal. + + Returns + ------- + Any + The acquired data. - return signal.get() + Raises + ------ + TimeoutError + If the get operation times out. + """ + return reduce.get_data_for_signal( + signal, + reduce_period=self.reduce_period, + reduce_method=self.reduce_method, + string=self.string or False, + ) - def compare_signal( - self, signal: ophyd.Signal, *, identifier: Optional[str] = None - ) -> Result: + async def get_data_for_signal_async( + self, + signal: ophyd.Signal, + *, + executor: Optional[concurrent.futures.Executor] = None + ) -> Any: """ - Compare the provided signal's value using the comparator's settings. + Get data for the given signal, according to the string and data + reduction settings. Parameters ---------- signal : ophyd.Signal - The signal to get data from and run a comparison on. - - identifier : str, optional - An identifier that goes along with the provided signal. Used for - severity result descriptions. Defaults to the signal's dotted - name. + The signal. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. + + Raises + ------ + TimeoutError + If the get operation times out. 
""" - try: - identifier = identifier or signal.dotted_name - try: - value = self.get_data_for_signal(signal) - except TimeoutError: - return Result( - severity=self.if_disconnected, - reason=f"Signal disconnected when reading: {signal}" - ) - return self.compare(value, identifier=identifier) - except Exception as ex: - return Result( - severity=Severity.internal_error, - reason=( - f"Checking if {identifier!r} {self} " - f"raised {ex.__class__.__name__}: {ex}" - ), - ) + return await reduce.get_data_for_signal_async( + signal, + reduce_period=self.reduce_period, + reduce_method=self.reduce_method, + string=self.string or False, + executor=executor, + ) @dataclass diff --git a/atef/config.py b/atef/config.py index d7fd2306..1ea87702 100644 --- a/atef/config.py +++ b/atef/config.py @@ -1,17 +1,19 @@ from __future__ import annotations +import asyncio import json import logging from dataclasses import dataclass, field -from typing import Generator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Generator, List, Optional, Sequence, Tuple, Union import apischema import happi import ophyd import yaml +from ophyd.signal import ConnectionTimeoutError -from . import serialization, util -from .cache import get_signal_cache +from . import serialization, tools, util +from .cache import DataCache from .check import Comparison, Result from .enums import Severity from .exceptions import PreparedComparisonException @@ -57,20 +59,50 @@ class Configuration: @dataclass class DeviceConfiguration(Configuration): + """ + A configuration that is built to check one or more devices. + + Identifiers are by default assumed to be attribute (component) names of the + devices. Identifiers may refer to components on the device + (``"component"`` would mean to access each device's ``.component``) or may + refer to any level of sub-device components (``"sub_device.component"`` + would mean to access each device's ``.sub_device`` and that sub-device's + ``.a`` component). + """ #: Happi device names which give meaning to self.checklist[].ids. devices: List[str] = field(default_factory=list) @dataclass class PVConfiguration(Configuration): + """ + A configuration that is built to check live EPICS PVs. + + Identifiers are by default assumed to be PV names. + """ ... -AnyConfiguration = Union[PVConfiguration, DeviceConfiguration] +@dataclass +class ToolConfiguration(Configuration): + """ + A configuration unrelated to PVs or Devices which verifies status via some + tool. + + Comparisons can optionally be run on the tool's results. + """ + tool: tools.Tool = field(default_factory=tools.Ping) + + +AnyConfiguration = Union[ + PVConfiguration, + DeviceConfiguration, + ToolConfiguration, +] PathItem = Union[ AnyConfiguration, - "IdentifierAndComparison", - "Comparison", + IdentifierAndComparison, + Comparison, str, ] @@ -81,7 +113,7 @@ class ConfigurationFile: A configuration file comprised of a number of devices/PV configurations. """ - #: configs: either PVConfiguration or DeviceConfiguration. + #: configs: PVConfiguration, DeviceConfiguration, or ToolConfiguration. configs: List[Configuration] def get_by_device(self, name: str) -> Generator[DeviceConfiguration, None, None]: @@ -140,14 +172,12 @@ class PreparedComparison: """ A unified representation of comparisons for device signals and standalone PVs. """ + #: The data cache to use for the preparation step. + cache: DataCache #: The identifier used for the comparison. identifier: str = "" #: The comparison itself. 
comparison: Comparison = field(default_factory=Comparison) - #: The device the comparison applies to, if applicable. - device: Optional[ophyd.Device] = None - #: The signal the comparison is to be run on. - signal: Optional[ophyd.Signal] = None #: The name of the associated configuration. name: Optional[str] = None #: The hierarhical path that led to this prepared comparison. @@ -155,7 +185,149 @@ class PreparedComparison: #: The last result of the comparison, if run. result: Optional[Result] = None - def compare(self) -> Result: + async def get_data_async(self) -> Any: + """ + Get the data according to the comparison's configuration. + + To be immplemented in subclass. + + Returns + ------- + data : Any + The acquired data. + """ + raise NotImplementedError() + + async def compare(self) -> Result: + """ + Run the comparison. + + To be immplemented in subclass. + """ + raise NotImplementedError() + + @classmethod + def from_config( + cls, + config: AnyConfiguration, + *, + client: Optional[happi.Client] = None, + cache: Optional[DataCache] = None, + ) -> Generator[Union[PreparedComparison, PreparedComparisonException], None, None]: + """ + Create one or more PreparedComparison instances from a PVConfiguration + or a DeviceConfiguration. + + If available, provide an instantiated happi Client and a data + cache. If unspecified, a configuration-derived happi Client will + be instantiated and a global data cache will be utilized. + + It is recommended - but not required - to manage a data cache on a + per-configuration basis. Managing the global cache is up to the user. + + Parameters + ---------- + config : PVConfiguration or DeviceConfiguration + The configuration. + client : happi.Client, optional + A happi Client instance. + cache : DataCache, optional + The data cache to use for this and other similar comparisons. + + Yields + ------ + item : PreparedComparisonException or PreparedComparison + If an error occurs during preparation, a + PreparedComparisonException will be yielded in place of the + PreparedComparison. + """ + if cache is None: + cache = DataCache() + + if isinstance(config, PVConfiguration): + yield from PreparedSignalComparison._from_pv_config(config, cache=cache) + elif isinstance(config, DeviceConfiguration): + if client is None: + client = happi.Client.from_config() + for dev_name in config.devices: + try: + device = util.get_happi_device_by_name(dev_name, client=client) + except Exception as ex: + yield PreparedComparisonException( + exception=ex, + comparison=None, # TODO + name=config.name or config.description, + identifier=dev_name, + path=[ + config, + dev_name, + ], + ) + else: + yield from PreparedSignalComparison._from_device_config( + config=config, + device=device, + cache=cache, + ) + elif isinstance(config, ToolConfiguration): + yield from PreparedToolComparison._from_tool_config(config, cache=cache) + else: + raise NotImplementedError(f"Configuration type unsupported: {type(config)}") + + +@dataclass +class PreparedSignalComparison(PreparedComparison): + """ + A unified representation of comparisons for device signals and standalone + PVs. + + Each PreparedSignalComparison has a single leaf in the configuration tree, + comprised of: + * A configuration + * The signal specification. 
This is comprised of the configuration and + "IdentifierAndComparison" + - DeviceConfiguration: Device and attribute (the "identifier") + - PVConfiguration: PV name (the "identifier") + * A comparison to run + - Including data reduction settings + """ + #: The device the comparison applies to, if applicable. + device: Optional[ophyd.Device] = None + #: The signal the comparison is to be run on. + signal: Optional[ophyd.Signal] = None + #: The value from the signal the comparison is to be run on. + data: Optional[Any] = None + + async def get_data_async(self) -> Any: + """ + Get the provided signal's data from the cache according to the + reduction configuration. + + Returns + ------- + data : Any + The acquired data. + + Raises + ------ + TimeoutError + If unable to connect or retrieve data from the signal. + """ + signal = self.signal + if signal is None: + raise ValueError("Signal instance unset") + + data = await self.cache.get_signal_data( + signal, + reduce_period=self.comparison.reduce_period, + reduce_method=self.comparison.reduce_method, + string=self.comparison.string or False, + ) + + self.data = data + return data + + async def compare(self) -> Result: """ Run the prepared comparison. @@ -164,26 +336,63 @@ def compare(self) -> Result: Result The result of the comparison. This is also set in ``self.result``. """ + try: + self.data = await self.get_data_async() + except (TimeoutError, asyncio.TimeoutError, ConnectionTimeoutError): + result = Result( + severity=self.comparison.if_disconnected, + reason=f"Signal not able to connect or read: {self.identifier}" + ) + except Exception as ex: + result = Result( + severity=Severity.internal_error, + reason=( + f"Getting data for signal {self.identifier!r} comparison " + f"{self.comparison} raised {ex.__class__.__name__}: {ex}" + ), + ) + + try: + result = self._compare() + except Exception as ex: + result = Result( + severity=Severity.internal_error, + reason=( + f"Failed to run {self.identifier!r} comparison " + f"{self.comparison} raised {ex.__class__.__name__}: {ex} " + f"with value {self.data}" + ), + ) + + self.result = result + return result + + def _compare(self) -> Result: + """ + Run the comparison with the already-acquired data in ``self.data``. 
+ """ if self.signal is None: return Result( severity=Severity.internal_error, reason="Signal not set" ) - try: - self.signal.wait_for_connection() - except TimeoutError: + + data = self.data + if data is None: + # 'None' is likely incompatible with our comparisons and should + # be raised for separately return Result( severity=self.comparison.if_disconnected, reason=( - f"Unable to connect to {self.identifier!r} ({self.name}) " - f"for comparison {self.comparison}" + f"No data available for signal {self.identifier!r} in " + f"comparison {self.comparison}" ), ) - self.result = self.comparison.compare_signal( - self.signal, + + return self.comparison.compare( + data, identifier=self.identifier ) - return self.result @classmethod def from_device( @@ -193,10 +402,14 @@ def from_device( comparison: Comparison, name: Optional[str] = None, path: Optional[List[PathItem]] = None, - ) -> PreparedComparison: + cache: Optional[DataCache] = None, + ) -> PreparedSignalComparison: """Create a PreparedComparison from a device and comparison.""" full_attr = f"{device.name}.{attr}" logger.debug("Checking %s.%s with comparison %s", full_attr, comparison) + if cache is None: + cache = DataCache() + signal = getattr(device, attr, None) if signal is None: raise AttributeError( @@ -211,6 +424,7 @@ def from_device( comparison=comparison, signal=signal, path=path or [], + cache=cache, ) @classmethod @@ -220,28 +434,30 @@ def from_pvname( comparison: Comparison, name: Optional[str] = None, path: Optional[List[PathItem]] = None, - *, - cache: Optional[Mapping[str, ophyd.Signal]] = None, - ) -> PreparedComparison: + cache: Optional[DataCache] = None, + ) -> PreparedSignalComparison: """Create a PreparedComparison from a PV name and comparison.""" if cache is None: - cache = get_signal_cache() + cache = DataCache() return cls( identifier=pvname, device=None, - signal=cache[pvname], + signal=cache.signals[pvname], comparison=comparison, name=name, path=path or [], + cache=cache, ) @classmethod def _from_pv_config( cls, config: PVConfiguration, - cache: Optional[Mapping[str, ophyd.Signal]] = None, - ) -> Generator[Union[PreparedComparisonException, PreparedComparison], None, None]: + cache: DataCache, + ) -> Generator[ + Union[PreparedComparisonException, PreparedSignalComparison], None, None + ]: """ Create one or more PreparedComparison instances from a PVConfiguration. @@ -291,7 +507,10 @@ def _from_device_config( cls, device: ophyd.Device, config: DeviceConfiguration, - ) -> Generator[Union[PreparedComparisonException, PreparedComparison], None, None]: + cache: DataCache, + ) -> Generator[ + Union[PreparedComparisonException, PreparedSignalComparison], None, None + ]: """ Create one or more PreparedComparison instances from a DeviceConfiguration. @@ -326,6 +545,7 @@ def _from_device_config( comparison=comparison, name=config.name or config.description, path=path, + cache=cache, ) except Exception as ex: yield PreparedComparisonException( @@ -336,32 +556,143 @@ def _from_device_config( path=path, ) + +@dataclass +class PreparedToolComparison(PreparedComparison): + """ + A unified representation of comparisons for device signals and standalone PVs. + + Each PreparedToolComparison has a single leaf in the configuration tree, + comprised of: + * A configuration + * The tool configuration (i.e., a :class:`tools.Tool` instance) + * Identifiers to compare are dependent on the tool type + * A comparison to run + - For example, a :class:`tools.Ping` has keys described in + :class:`tools.PingResult`. 
+ """ + #: The device the comparison applies to, if applicable. + tool: tools.Tool = field(default_factory=lambda: tools.Ping(hosts=[])) + + async def get_data_async(self) -> Any: + """ + Get the provided tool's result data from the cache. + + Returns + ------- + data : Any + The acquired data. + """ + return await self.cache.get_tool_data(self.tool) + + async def compare(self) -> Result: + """ + Run the prepared comparison. + + Returns + ------- + Result + The result of the comparison. This is also set in ``self.result``. + """ + try: + result = await self.get_data_async() + except (asyncio.TimeoutError, TimeoutError): + return Result( + severity=self.comparison.if_disconnected, + reason=( + f"Tool {self.tool} timed out {self.identifier!r} ({self.name}) " + f"for comparison {self.comparison}" + ), + ) + except Exception as ex: + logger.debug("Internal error with tool %s", self, exc_info=True) + # TODO: include some traceback information for debugging? + # Could 'Result' have optional verbose error information? + return Result( + severity=Severity.internal_error, + reason=( + f"Tool {self.tool} failed to run {self.identifier!r} ({self.name}) " + f"for comparison {self.comparison}: {ex.__class__.__name__} {ex}" + ), + ) + + try: + value = tools.get_result_value_by_key(result, self.identifier) + except KeyError as ex: + return Result( + severity=self.comparison.severity_on_failure, + reason=( + f"Provided key is invalid for tool result {self.tool} " + f"{self.identifier!r} ({self.name}): {ex} " + f"(in comparison {self.comparison})" + ), + ) + + self.result = self.comparison.compare( + value, + identifier=self.identifier + ) + return self.result + @classmethod - def from_config( + def from_tool( cls, - config: AnyConfiguration, - *, - client: Optional[happi.Client] = None, - cache: Optional[Mapping[str, ophyd.Signal]] = None, - ) -> Generator[Union[PreparedComparison, PreparedComparisonException], None, None]: + tool: tools.Tool, + result_key: str, + comparison: Comparison, + name: Optional[str] = None, + path: Optional[List[PathItem]] = None, + cache: Optional[DataCache] = None, + ) -> PreparedToolComparison: """ - Create one or more PreparedComparison instances from a PVConfiguration - or a DeviceConfiguration. - - If available, provide an instantiated happi Client and PV-to-Signal - cache. If unspecified, a configuration-derived happi Client will - be instantiated and a global PV-to-Signal cache will be utilized. + Prepare a tool-based comparison for execution. Parameters ---------- - config : PVConfiguration or DeviceConfiguration - The configuration. + tool : Tool + The tool to run. + result_key : str + The key from the result dictionary to check after running the tool. + comparison : Comparison + The comparison to perform on the tool's results (looking at the + specific result_key). + name : Optional[str], optional + The name of the comparison. + path : Optional[List[PathItem]], optional + The path that led us to this single comparison. + cache : DataCache, optional + The data cache to use for this and other similar comparisons. - client : happi.Client - A happi Client instance. + Returns + ------- + PreparedToolComparison + """ + if cache is None: + cache = DataCache() + tool.check_result_key(result_key) + return cls( + tool=tool, + comparison=comparison, + name=name, + identifier=result_key, + path=path or [], + cache=cache, + ) - cache : dict of str to type[Signal] - The PV to signal cache. 
+ @classmethod + def _from_tool_config( + cls, + config: ToolConfiguration, + cache: DataCache, + ) -> Generator[Union[PreparedComparisonException, PreparedComparison], None, None]: + """ + Create one or more PreparedComparison instances from a + ToolConfiguration. + + Parameters + ---------- + config : ToolConfiguration + The configuration. Yields ------ @@ -370,33 +701,35 @@ def from_config( PreparedComparisonException will be yielded in place of the PreparedComparison. """ - if isinstance(config, PVConfiguration): - yield from cls._from_pv_config(config, cache=cache) - elif isinstance(config, DeviceConfiguration): - if client is None: - client = happi.Client.from_config() - for dev_name in config.devices: - try: - device = util.get_happi_device_by_name(dev_name, client=client) - except Exception as ex: - yield PreparedComparisonException( - exception=ex, - comparison=None, # TODO - name=config.name or config.description, - identifier=dev_name, - path=[ - config, - dev_name, - ], - ) - else: - yield from cls._from_device_config( - config=config, - device=device, - ) + for checklist_item in config.checklist: + for comparison in checklist_item.comparisons: + for result_key in checklist_item.ids: + path = [ + config, + checklist_item, + comparison, + result_key, + ] + try: + yield cls.from_tool( + tool=config.tool, + result_key=result_key, + comparison=comparison, + name=config.name or config.description, + path=path, + cache=cache, + ) + except Exception as ex: + yield PreparedComparisonException( + exception=ex, + comparison=comparison, + name=config.name or config.description, + identifier=result_key, + path=path, + ) -def check_device( +async def check_device( device: ophyd.Device, checklist: Sequence[IdentifierAndComparison] ) -> Tuple[Severity, List[Result]]: """ @@ -427,7 +760,7 @@ def check_device( full_attr = f"{device.name}.{attr}" logger.debug("Checking %s.%s with comparison %s", full_attr, comparison) try: - prepared = PreparedComparison.from_device( + prepared = PreparedSignalComparison.from_device( device=device, attr=attr, comparison=comparison ) except AttributeError: @@ -439,7 +772,7 @@ def check_device( ), ) else: - result = prepared.compare() + result = await prepared.compare() if result.severity > overall: overall = result.severity @@ -448,10 +781,9 @@ def check_device( return overall, results -def check_pvs( +async def check_pvs( checklist: Sequence[IdentifierAndComparison], - *, - cache: Optional[Mapping[str, ophyd.Signal]] = None, + cache: Optional[DataCache] = None, ) -> Tuple[Severity, List[Result]]: """ Check a PVConfiguration. @@ -461,6 +793,8 @@ def check_pvs( checklist : sequence of IdentifierAndComparison Comparisons to run on the given device. Multiple PVs may share the same checks. + cache : DataCache, optional + The data cache to use for this and other similar comparisons. 
Returns ------- @@ -472,7 +806,8 @@ def check_pvs( """ overall = Severity.success results = [] - cache = cache or get_signal_cache() + if cache is None: + cache = DataCache() def get_comparison_and_pvname(): for checklist_item in checklist: @@ -480,17 +815,81 @@ def get_comparison_and_pvname(): for pvname in checklist_item.ids: yield comparison, pvname - for comparison, pvname in get_comparison_and_pvname(): + prepared_comparisons = [ + PreparedSignalComparison.from_pvname( + pvname=pvname, comparison=comparison, cache=cache + ) + for comparison, pvname in get_comparison_and_pvname() + ] + + cache_fill_tasks = [] + for prepared in prepared_comparisons: # Pre-fill the cache with PVs, connecting in the background - _ = cache[pvname] + cache_fill_tasks.append( + asyncio.create_task( + prepared.get_data_async() + ) + ) - for comparison, pvname in get_comparison_and_pvname(): - logger.debug("Checking %s.%s with comparison %s", pvname, comparison) + for prepared in prepared_comparisons: + logger.debug( + "Checking %r with comparison %s", prepared.identifier, prepared.comparison + ) - prepared = PreparedComparison.from_pvname( - pvname=pvname, comparison=comparison, cache=cache + result = await prepared.compare() + + if result.severity > overall: + overall = result.severity + results.append(result) + + return overall, results + + +async def check_tool( + tool: tools.Tool, + checklist: Sequence[IdentifierAndComparison], + cache: Optional[DataCache] = None, +) -> Tuple[Severity, List[Result]]: + """ + Check a PVConfiguration. + + Parameters + ---------- + tool : Tool + The tool instance defining which tool to run and with what arguments. + checklist : sequence of IdentifierAndComparison + Comparisons to run on the given device. Multiple PVs may share the + same checks. + cache : DataCache, optional + The data cache to use for this tool and other similar comparisons. + + Returns + ------- + overall_severity : Severity + Maximum severity found when running comparisons. + + results : list of Result + Individual comparison results. + """ + overall = Severity.success + results = [] + + if cache is None: + cache = DataCache() + + def get_comparison_and_key(): + for checklist_item in checklist: + for comparison in checklist_item.comparisons: + for key in checklist_item.ids: + yield comparison, key + + for comparison, key in get_comparison_and_key(): + logger.debug("Checking %r with comparison %s", key, comparison) + + prepared = PreparedToolComparison.from_tool( + tool, result_key=key, comparison=comparison, cache=cache, ) - result = prepared.compare() + result = await prepared.compare() if result.severity > overall: overall = result.severity diff --git a/atef/exceptions.py b/atef/exceptions.py index 6ff57398..b4d7ddc1 100644 --- a/atef/exceptions.py +++ b/atef/exceptions.py @@ -52,8 +52,8 @@ class PreparedComparisonException(Exception): exception: Exception #: The identifier used for the comparison. identifier: str - #: The comparison itself. - comparison: Comparison + #: The comparison related to the exception, if applicable. + comparison: Optional[Comparison] #: The hierarhical path that led to this prepared comparison. path: List[PathItem] #: The name of the associated configuration. 
@@ -63,7 +63,7 @@ def __init__( self, exception: Exception, identifier: str, - comparison: Comparison, + comparison: Optional[Comparison] = None, name: Optional[str] = None, path: Optional[List[PathItem]] = None, ): @@ -73,3 +73,11 @@ def __init__( self.comparison = comparison self.name = name self.path = path or [] + + +class ToolException(Exception): + """Base exception for tool-related errors.""" + + +class ToolDependencyMissingException(Exception): + """Required dependency for a tool to work is unavailable.""" diff --git a/atef/reduce.py b/atef/reduce.py index 285c9795..94d89c24 100644 --- a/atef/reduce.py +++ b/atef/reduce.py @@ -1,13 +1,16 @@ from __future__ import annotations +import concurrent.futures import enum -from typing import Protocol, Sequence +from dataclasses import dataclass +from typing import Any, Optional, Protocol, Sequence import numpy as np import ophyd from .ophyd_helpers import acquire_async, acquire_blocking from .type_hints import Number, PrimitiveType +from .util import run_in_executor class _ReduceMethodType(Protocol): @@ -63,3 +66,171 @@ async def subscribe_and_reduce_async( """ data = await acquire_async(signal, duration) return self.reduce_values(data) + + +@dataclass(frozen=True, eq=True) +class ReductionKey: + period: Optional[Number] + method: ReduceMethod + + def get_data_for_signal(self, signal: ophyd.Signal, string: bool = False) -> Any: + """ + Get data for the given signal, according to the string and data + reduction settings. + + Parameters + ---------- + signal : ophyd.Signal + The signal. + + Returns + ------- + Any + The acquired data. + + Raises + ------ + TimeoutError + If the get operation times out. + """ + return get_data_for_signal( + signal, + reduce_period=self.period, + reduce_method=self.method, + string=string or False, + ) + + async def get_data_for_signal_async( + self, + signal: ophyd.Signal, + string: bool = False, + *, + executor: Optional[concurrent.futures.Executor] = None + ) -> Any: + """ + Get data for the given signal, according to the string and data + reduction settings. + + Parameters + ---------- + signal : ophyd.Signal + The signal. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. + + Raises + ------ + TimeoutError + If the get operation times out. + """ + return await get_data_for_signal_async( + signal, + reduce_period=self.period, + reduce_method=self.method, + string=string or False, + executor=executor, + ) + + +def get_data_for_signal( + signal: ophyd.Signal, + reduce_period: Optional[Number] = None, + reduce_method: ReduceMethod = ReduceMethod.average, + string: bool = False, +) -> Any: + """ + Get data for the given signal, according to the string and data reduction + settings. + + Parameters + ---------- + signal : ophyd.Signal + The signal. + reduce_period : float, optional + Period over which the comparison will occur, where multiple samples may + be acquired prior to a result being available. + reduce_method : ReduceMethod, optional + Reduce collected samples by this reduce method. Ignored if + reduce_period unset. + string : bool, optional + If applicable, request and compare string values rather than the + default specified. + + Returns + ------- + Any + The acquired data. + + Raises + ------ + TimeoutError + If the get operation times out. 
+ """ + if reduce_period is not None and reduce_period > 0: + return reduce_method.subscribe_and_reduce( + signal, reduce_period + ) + + if string: + return signal.get(as_string=True) + + return signal.get() + + +async def get_data_for_signal_async( + signal: ophyd.Signal, + reduce_period: Optional[Number] = None, + reduce_method: ReduceMethod = ReduceMethod.average, + string: bool = False, + *, + executor: Optional[concurrent.futures.Executor] = None +) -> Any: + """ + Get data for the given signal, according to the string and data + reduction settings. + + Parameters + ---------- + signal : ophyd.Signal + The signal. + reduce_period : float, optional + Period over which the comparison will occur, where multiple samples may + be acquired prior to a result being available. + reduce_method : ReduceMethod, optional + Reduce collected samples by this reduce method. Ignored if + reduce_period unset. + string : bool, optional + If applicable, request and compare string values rather than the + default specified. + executor : concurrent.futures.Executor, optional + The executor to run the synchronous call in. Defaults to + the loop-defined default executor. + + Returns + ------- + Any + The acquired data. + + Raises + ------ + TimeoutError + If the get operation times out. + """ + if reduce_period is not None and reduce_period > 0: + return await reduce_method.subscribe_and_reduce_async( + signal, reduce_period + ) + + def inner_sync_get(): + if string: + return signal.get(as_string=True) + + return signal.get() + + return await run_in_executor(executor, inner_sync_get) diff --git a/atef/tests/configs/device_based.yml b/atef/tests/configs/device_based.yml index fb5373a2..64b27e94 100644 --- a/atef/tests/configs/device_based.yml +++ b/atef/tests/configs/device_based.yml @@ -61,6 +61,7 @@ configs: description: Filter status unknown invert: true value: 3 + name: Checklist ids: - blade_01.state.state - blade_02.state.state diff --git a/atef/tests/configs/ping_localhost.json b/atef/tests/configs/ping_localhost.json new file mode 100644 index 00000000..1f9632ff --- /dev/null +++ b/atef/tests/configs/ping_localhost.json @@ -0,0 +1,40 @@ +{ + "configs": [ + { + "ToolConfiguration": { + "name": "Host alive check", + "description": "Tool", + "tags": null, + "tool": { + "Ping": { + "hosts": ["127.0.0.1", "localhost"], + "count": 1 + } + }, + "checklist": [ + { + "name": "Check that the ping time is good", + "ids": [ + "max_time", + "times.localhost" + ], + "comparisons": [ + { + "Less": { + "name": "Ping time OK", + "description": "Is the ping less than x ms?", + "invert": false, + "reduce_period": null, + "reduce_method": "average", + "string": null, + "severity_on_failure": 2, + "if_disconnected": 2, + "value": 1000 + } + } + ] + } + ] + } + }] +} diff --git a/atef/tests/configs/pv_based.yml b/atef/tests/configs/pv_based.yml index d3ab8b01..0aba8010 100644 --- a/atef/tests/configs/pv_based.yml +++ b/atef/tests/configs/pv_based.yml @@ -27,9 +27,11 @@ configs: - PVConfiguration: + name: PV Config 1 checklist: - comparisons: - Equals: + name: Equality check 1 value: 1 ids: - simple:A @@ -38,11 +40,14 @@ configs: tags: - a - PVConfiguration: + name: PV Config 2 checklist: - comparisons: - Equals: + name: Equality check 2 value: 1 ids: + - simple:A - simple:C name: pv_checks2 tags: diff --git a/atef/tests/test_commandline.py b/atef/tests/test_commandline.py index 4fc36d5f..20775678 100644 --- a/atef/tests/test_commandline.py +++ b/atef/tests/test_commandline.py @@ -4,8 +4,9 @@ import pytest import 
atef.bin.main as atef_main +from atef.bin import check as bin_check -from .. import config, util +from .. import util from .conftest import CONFIG_PATH from .test_comparison_device import at2l0, mock_signal_cache # noqa: F401 @@ -22,18 +23,23 @@ def test_help_module(monkeypatch, subcommand): atef_main.main() -def test_check_pv_smoke(monkeypatch, mock_signal_cache): # noqa: F811 - from atef.bin.check import main as check_main - monkeypatch.setattr(config, "get_signal_cache", lambda: mock_signal_cache) - check_main(filename=str(CONFIG_PATH / "pv_based.yml")) +@pytest.mark.asyncio +async def test_check_pv_smoke(mock_signal_cache): # noqa: F811 + await bin_check.main( + filename=str(CONFIG_PATH / "pv_based.yml"), signal_cache=mock_signal_cache + ) -def test_check_device_smoke(monkeypatch, at2l0): # noqa: F811 - from atef.bin import check as bin_check - +@pytest.mark.asyncio +async def test_check_device_smoke(monkeypatch, at2l0): # noqa: F811 def get_happi_device_by_name(name, client=None): return at2l0 monkeypatch.setattr(util, "get_happi_device_by_name", get_happi_device_by_name) monkeypatch.setattr(happi.Client, "from_config", lambda: None) - bin_check.main(filename=str(CONFIG_PATH / "device_based.yml")) + await bin_check.main(filename=str(CONFIG_PATH / "device_based.yml")) + + +@pytest.mark.asyncio +async def test_check_ping_localhost_smoke(): # noqa: F811 + await bin_check.main(filename=str(CONFIG_PATH / "ping_localhost.json"), verbose=2) diff --git a/atef/tests/test_comparison_device.py b/atef/tests/test_comparison_device.py index c350a607..67c34b0f 100644 --- a/atef/tests/test_comparison_device.py +++ b/atef/tests/test_comparison_device.py @@ -120,10 +120,11 @@ def test_serializable(config: DeviceConfiguration, severity: Severity): @config_and_severity -def test_result_severity( +@pytest.mark.asyncio +async def test_result_severity( device, config: DeviceConfiguration, severity: Severity ): - overall, _ = check_device(device=device, checklist=config.checklist) + overall, _ = await check_device(device=device, checklist=config.checklist) assert overall == severity @@ -159,7 +160,8 @@ def __getattr__(self, attr): return FakeAT2L0() -def test_at2l0_standin(at2l0): +@pytest.mark.asyncio +async def test_at2l0_standin(at2l0): state1: ophyd.Signal = getattr(at2l0, "blade_01.state.state") severity = { 0: Severity.error, @@ -196,12 +198,13 @@ def test_at2l0_standin(at2l0): ), ] - overall, results = check_device(at2l0, checklist=checklist) + overall, results = await check_device(at2l0, checklist=checklist) print("\n".join(res.reason or "n/a" for res in results)) assert overall == severity -def test_at2l0_standin_reduce(at2l0): +@pytest.mark.asyncio +async def test_at2l0_standin_reduce(at2l0): state1: ophyd.Signal = getattr(at2l0, "blade_01.state.state") state1.put(1.0) checklist = [ @@ -219,12 +222,13 @@ def test_at2l0_standin_reduce(at2l0): ), ] - overall, results = check_device(at2l0, checklist=checklist) + overall, results = await check_device(at2l0, checklist=checklist) print("\n".join(res.reason or "n/a" for res in results)) assert overall == Severity.success -def test_at2l0_standin_value_map(at2l0): +@pytest.mark.asyncio +async def test_at2l0_standin_value_map(at2l0): state1: ophyd.Signal = getattr(at2l0, "blade_01.state.state") value_to_severity = { 0: Severity.error, @@ -261,7 +265,7 @@ def test_at2l0_standin_value_map(at2l0): ) ] - overall, results = check_device(at2l0, checklist=checklist) + overall, results = await check_device(at2l0, checklist=checklist) print("\n".join(res.reason or 
"n/a" for res in results)) assert overall == severity @@ -275,6 +279,15 @@ def mock_signal_cache() -> cache._SignalCache[ophyd.sim.FakeEpicsSignalRO]: return mock_cache +@pytest.fixture +def data_cache( + mock_signal_cache: cache._SignalCache[ophyd.sim.FakeEpicsSignalRO], +) -> cache.DataCache: + return cache.DataCache( + signals=mock_signal_cache, + ) + + @pytest.mark.parametrize( "checklist, expected_severity", [ @@ -304,12 +317,13 @@ def mock_signal_cache() -> cache._SignalCache[ophyd.sim.FakeEpicsSignalRO]: ), ], ) -def test_pv_config( - mock_signal_cache: cache._SignalCache[ophyd.sim.FakeEpicsSignalRO], +@pytest.mark.asyncio +async def test_pv_config( + data_cache: cache.DataCache, checklist: List[IdentifierAndComparison], expected_severity: check.Severity ): - overall, _ = check_pvs(checklist, cache=mock_signal_cache) + overall, _ = await check_pvs(checklist, cache=data_cache) assert overall == expected_severity diff --git a/atef/tests/test_comparison_tools.py b/atef/tests/test_comparison_tools.py new file mode 100644 index 00000000..40fecd0f --- /dev/null +++ b/atef/tests/test_comparison_tools.py @@ -0,0 +1,175 @@ +from dataclasses import dataclass +from typing import Any, ClassVar, Optional + +import apischema +import pytest + +from .. import check, config, tools +from ..cache import DataCache +from ..check import Result, Severity +from ..config import IdentifierAndComparison, ToolConfiguration + +config_and_severity = pytest.mark.parametrize( + "conf, severity", + [ + pytest.param( + ToolConfiguration( + tool=tools.Ping( + hosts=["127.0.0.1"], + count=1, + ), + checklist=[ + IdentifierAndComparison( + ids=["max_time"], + comparisons=[check.LessOrEqual(value=1)] + ), + ] + ), + Severity.success, + id="all_good", + ), + pytest.param( + ToolConfiguration( + tool=tools.Ping( + hosts=["127.0.0.1"], + count=1, + ), + checklist=[ + IdentifierAndComparison( + ids=["max_time"], + comparisons=[check.Less(value=0.0)] + ), + ] + ), + Severity.error, + id="must_fail", + ), + ] +) + + +@config_and_severity +def test_serializable(conf: ToolConfiguration, severity: Severity): + serialized = apischema.serialize(conf) + assert apischema.deserialize(ToolConfiguration, serialized) == conf + + +@config_and_severity +@pytest.mark.asyncio +async def test_result_severity( + conf: ToolConfiguration, severity: Severity +): + overall, results = await config.check_tool(conf.tool, conf.checklist) + assert overall == severity + + +@pytest.mark.parametrize( + "tool, key, valid", + [ + (tools.Ping(), "max_time", True), + (tools.Ping(), "max_time.abc", False), + (tools.Ping(), "times.hostname", True), + (tools.Ping(), "badkey", False), + ] +) +def test_result_keys( + tool: tools.Tool, key: str, valid: bool +): + if not valid: + with pytest.raises(ValueError) as ex: + tool.check_result_key(key) + print("Failed check, as expected:\n", ex) + else: + tool.check_result_key(key) + + +@dataclass +class CustomToolResult(tools.ToolResult): + run_count: int + + +@dataclass +class CustomTool(tools.Tool): + result: ClassVar[Optional[CustomToolResult]] = None + + async def run(self) -> CustomToolResult: + print("Running custom tool...") + if CustomTool.result is None: + CustomTool.result = CustomToolResult(result=Result(), run_count=0) + + CustomTool.result.run_count += 1 + return CustomTool.result + + +@pytest.mark.asyncio +async def test_tool_cache(): + cache = DataCache() + tool = CustomTool() + first_data = await cache.get_tool_data(tool) + assert isinstance(first_data, CustomToolResult) + assert first_data.run_count == 1 
+ + second_data = await cache.get_tool_data(tool) + assert first_data is second_data + assert isinstance(second_data, CustomToolResult) + assert second_data.run_count == 1 + + +class _TestItem: + a = { + "b": [1, 2, 3] + } + + +@pytest.mark.parametrize( + "value, key, expected", + [ + # abc[1] = "b" + ("abc", "1", "b"), + # [1, 2, 3][1] = 2 + ([1, 2, 3], "1", 2), + # dict(a=dict(b="c"))["a"]["b"] = "c" + ({"a": {"b": "c"}}, "a.b", "c"), + # dict(a=dict(b="c"))["a"]["b"][1] = 2 + ({"a": {"b": [1, 2, 3]}}, "a.b.1", 2), + # _TestItem.a.b[1] + (_TestItem, "a.b.1", 2), + ] +) +def test_get_result_value_by_key( + value: Any, key: str, expected: Any +): + assert tools.get_result_value_by_key(value, key) == expected + + +@pytest.mark.parametrize( + "output, expected", + [ + pytest.param( + "Reply from 127.0.0.1: bytes=32 time<1ms TTL=128", + 1.0e-3, + id="win32_less", + ), + pytest.param( + "Reply from 127.0.0.1: bytes=32 time=10ms TTL=128", + 10e-3, + id="win32_equal", + ), + pytest.param( + "64 bytes from 1.1.1.1: icmp_seq=0 ttl=55 time=11.000 ms", + 11e-3, + id="macos", + ), + pytest.param( + "64 bytes from 1.1.1.1: icmp_seq=1 ttl=50 time=3.00 ms", + 3e-3, + id="linux", + ), + ], +) +def test_ping_regex( + output: str, + expected: float, +): + result = tools.PingResult.from_output("", output) + assert abs(result.max_time - expected) < 1e-6 diff --git a/atef/tools.py b/atef/tools.py new file mode 100644 index 00000000..40bcc32a --- /dev/null +++ b/atef/tools.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import re +import shutil +import sys +import typing +from dataclasses import dataclass, field +from typing import Any, ClassVar, Dict, List, Mapping, Sequence, TypeVar, Union + +from . import serialization +from .check import Result, Severity +from .exceptions import ToolDependencyMissingException + +T = TypeVar("T", bound="Tool") + + +@dataclass +class ToolResult: + """ + The base result dictionary of any tool. + """ + result: Result + + +@dataclass +class PingResult(ToolResult): + """ + The result dictionary of the 'ping' tool. + """ + #: Host(s) that are alive + alive: List[str] = field(default_factory=list) + #: Number of hosts that are alive. + num_alive: int = 0 + + #: Host(s) that are unresponsvie + unresponsive: List[str] = field(default_factory=list) + #: Number of hosts that are unresponsive. + num_unresponsive: int = 0 + + #: Host name to time taken. + times: Dict[str, float] = field(default_factory=dict) + #: Minimum time in seconds from ``times``. + min_time: float = 0.0 + #: Maximum time in seconds from ``times``. + max_time: float = 0.0 + + #: Time pattern for matching the ping output. + _time_re: ClassVar[re.Pattern] = re.compile(r"time[=<](.*)\s?ms") + + def add_host_result( + self, + host: str, + result: Union[PingResult, Exception], + *, + failure_time: float = 100.0 + ) -> None: + """ + Add a new per-host result to this aggregate one. + + Parameters + ---------- + host : str + The hostname or IP address. + result : Union[PingResult, Exception] + The result to add. Caught exceptions will be interpreted as a ping + failure for the given host. + failure_time : float, optional + The time to use when failures happen. 
+ """ + if isinstance(result, Exception): + self.result = Result( + severity=Severity.error, + ) + self.unresponsive.append(host) + self.times[host] = failure_time + else: + self.unresponsive.extend(result.unresponsive) + self.alive.extend(result.alive) + self.times.update(result.times) + + times = self.times.values() + self.min_time = min(times) if times else 0.0 + self.max_time = max(times) if times else failure_time + + self.num_unresponsive = len(self.unresponsive) + self.num_alive = len(self.alive) + + @classmethod + def from_output( + cls, host: str, output: str, unresponsive_time: float = 100.0 + ) -> PingResult: + """ + Fill a PingResult from the results of the ping program. + + Parameters + ---------- + host : str + The hostname that ``ping`` was called with. + output : str + The decoded output of the subprocess call. + unresponsive_time : float, optional + Time to use for unresponsive or errored hosts. + + Returns + ------- + PingResult + """ + # NOTE: lazily ignoring non-millisecond-level results here; 1 second+ + # is the same as non-responsive if you ask me... + times = [float(ms) / 1000.0 for ms in PingResult._time_re.findall(output)] + + if not times: + return cls( + result=Result(severity=Severity.error), + alive=[], + unresponsive=[host], + min_time=unresponsive_time, + max_time=unresponsive_time, + times={host: unresponsive_time}, + ) + + return cls( + result=Result(severity=Severity.success), + alive=[host], + unresponsive=[], + min_time=min(times), + max_time=max(times), + times={host: sum(times) / len(times)}, + ) + + +def get_result_value_by_key(result: ToolResult, key: str) -> Any: + """ + Retrieve the value indicated by the dotted key name from the ToolResult. + + Supports attributes of generic types, items (for mappings as in + dictionaries), and iterables (by numeric index). + + Parameters + ---------- + result : object + The result dataclass instance. + key : str + The (optionally) dotted key name. + + Raises + ------ + KeyError + If the key is blank or otherwise invalid. + + Returns + ------- + Any + The data found by the key. + """ + if not key: + raise KeyError("No key provided") + + item = result + path = [] + key_parts = key.split(".") + + while key_parts: + key = key_parts.pop(0) + path.append(key) + try: + if isinstance(item, Mapping): + item = item[key] + elif isinstance(item, Sequence): + item = item[int(key)] + else: + item = getattr(item, key) + except KeyError: + path_str = ".".join(path) + raise KeyError( + f"{item} does not have key {key!r} ({path_str})" + ) from None + except AttributeError: + path_str = ".".join(path) + raise KeyError( + f"{item} does not have attribute {key!r} ({path_str})" + ) from None + except Exception: + path_str = ".".join(path) + raise KeyError( + f"{item} does not have {key!r} ({path_str})" + ) + + return item + + +@dataclass +@serialization.as_tagged_union +class Tool: + """ + Base class for atef tool checks. + """ + + def check_result_key(self, key: str) -> None: + """ + Check that the result ``key`` is valid for the given tool. + + For example, ``PingResult`` keys can include ``"min_time"``, + ``"max_time"``, and so on. + + Parameters + ---------- + key : str + The key to check. + + Raises + ------ + ValueError + If the key is invalid. 
+ """ + top_level_key, *parts = key.split(".", 1) + # Use the return type of the tool's run() method to tell us the + # ToolResult type: + run_type: ToolResult = typing.get_type_hints(self.run)["return"] + # And then the keys that are defined in its definition: + result_type_hints = typing.get_type_hints(run_type) + valid_keys = list(result_type_hints) + if top_level_key not in valid_keys: + raise ValueError( + f"Invalid result key for tool {self}: {top_level_key!r}. Valid " + f"keys are: {', '.join(valid_keys)}" + ) + + if parts: + top_level_type = result_type_hints[top_level_key] + origin = typing.get_origin(top_level_type) + if origin is None or not issubclass(origin, (Mapping, Sequence)): + raise ValueError( + f"Invalid result key for tool {self}: {top_level_key!r} does " + f"not have sub-keys because it is of type {top_level_type}." + ) + + async def run(self, *args, **kwargs) -> ToolResult: + raise NotImplementedError("To be implemented by subclass") + + +@dataclass +class Ping(Tool): + """ + Tool for pinging one or more hosts and summarizing the results. + """ + #: The hosts to ping. + hosts: List[str] = field(default_factory=list) + #: The number of ping attempts to make per host. + count: int = 3 + #: The assumed output encoding of the 'ping' command. + encoding: str = "utf-8" + + #: Time to report when unresponsive [sec] + _unresponsive_time: ClassVar[float] = 100.0 + + async def ping(self, host: str) -> PingResult: + """ + Ping the given host. + + Parameters + ---------- + host : str + The host to ping. + + Returns + ------- + PingResult + """ + + # Ensure we don't ping forever: + count = min(self.count, 1) + + if sys.platform == "win32": + args = ("/n", str(count)) + else: + args = ("-c", str(count)) + + ping = shutil.which("ping") + + if ping is None: + raise ToolDependencyMissingException( + "The 'ping' binary is unavailable on the currently-defined " + "PATH" + ) + + proc = await asyncio.create_subprocess_exec( + ping, + *args, + host, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + assert proc.stdout is not None + output = await proc.stdout.read() + await proc.wait() + return PingResult.from_output(host, output.decode(self.encoding)) + + async def run(self) -> PingResult: + """ + Run the "Ping" tool with the current settings. 
+ + Returns + ------- + PingResult + """ + result = PingResult(result=Result()) + + if not self.hosts: + return result + + ping_by_host: Dict[str, Union[Exception, PingResult]] = {} + + async def _ping(host: str) -> None: + try: + ping_by_host[host] = await self.ping(host) + except Exception as ex: + ping_by_host[host] = ex + + tasks = [asyncio.create_task(_ping(host)) for host in self.hosts] + + try: + await asyncio.wait(tasks) + except KeyboardInterrupt: + for task in tasks: + task.cancel() + raise + + for host, host_result in ping_by_host.items(): + result.add_host_result( + host, host_result, failure_time=self._unresponsive_time + ) + + return result diff --git a/atef/util.py b/atef/util.py index e7e57733..c5491856 100644 --- a/atef/util.py +++ b/atef/util.py @@ -1,7 +1,9 @@ +import asyncio +import concurrent.futures import functools import logging import pathlib -from typing import Optional, Sequence +from typing import Callable, Optional, Sequence import happi import ophyd @@ -85,3 +87,35 @@ def regex_for_devices(names: Optional[Sequence[str]]) -> str: """Get a regular expression that matches all the given device names.""" names = list(names or []) return "|".join(f"^{name}$" for name in names) + + +async def run_in_executor( + executor: Optional[concurrent.futures.Executor], + func: Callable, + *args, **kwargs +): + """ + Using the provided executor, run the function and return its value. + + Parameters + ---------- + executor : concurrent.futures.Executor or None + The executor to use. Defaults to the one from the running loop. + func : Callable + The function to run. + *args : + Arguments to pass. + **kwargs : + Keyword arguments to pass. + + Returns + ------- + Any + The value returned from func(). + """ + @functools.wraps(func) + def wrapped(): + return func(*args, **kwargs) + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(executor, wrapped) diff --git a/atef/yaml_support.py b/atef/yaml_support.py index d6109c6b..af5237e8 100644 --- a/atef/yaml_support.py +++ b/atef/yaml_support.py @@ -21,7 +21,8 @@ def str_enum_representer(dumper, data): return dumper.represent_str(data.value) # The ugliness of this makes me think we should use a different library - from . import enums, reduce + from . import enums, reduce, tools yaml.add_representer(enums.Severity, int_enum_representer) yaml.add_representer(reduce.ReduceMethod, str_enum_representer) + yaml.add_representer(tools.SupportedTool, str_enum_representer) diff --git a/dev-requirements.txt b/dev-requirements.txt index 8d18be6d..2a102c93 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,5 +12,7 @@ sphinx_rtd_theme doctr docs-versions-menu recommonmark +# test suite-specific +pytest-asyncio # for happi loading at lcls pcdsdevices
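
Usage sketches for the asynchronous APIs introduced above. These are not part of the patch; signal names, hosts, values, and file paths are illustrative.

The DataCache added in atef/cache.py deduplicates reads that share the same reduction settings: the first request stores an asyncio task under a DataKey, and any later request with the same key either awaits that task or returns the stored value. A minimal sketch using a plain ophyd.Signal, so no EPICS connection is needed:

    import asyncio

    import ophyd

    from atef.cache import DataCache


    async def main() -> None:
        cache = DataCache()
        signal = ophyd.Signal(name="example", value=3)

        # Two concurrent requests with identical reduction settings share a
        # single underlying get(); the second simply awaits the first's task.
        first, second = await asyncio.gather(
            cache.get_signal_data(signal),
            cache.get_signal_data(signal),
        )
        assert first == second == 3

        # A later call returns the stored value without re-reading the signal.
        assert await cache.get_signal_data(signal) == 3


    asyncio.run(main())

Tool results are plain dataclasses, and checklist identifiers are dotted keys resolved by tools.get_result_value_by_key through attributes, mapping keys, and sequence indices. Identically-configured tools also hash to the same ToolKey, which is what lets DataCache reuse a single run() result. A sketch with a hand-built PingResult standing in for an actual ping run:

    from atef import tools
    from atef.cache import ToolKey
    from atef.check import Result

    # A hand-built result, as if "ping localhost" had already completed.
    result = tools.PingResult(
        result=Result(),
        alive=["localhost"],
        num_alive=1,
        times={"localhost": 0.001},
        min_time=0.001,
        max_time=0.001,
    )

    # Dotted identifiers walk attributes, then mappings/sequences.
    assert tools.get_result_value_by_key(result, "max_time") == 0.001
    assert tools.get_result_value_by_key(result, "times.localhost") == 0.001

    # Identifiers are validated up front against the tool's result type;
    # an unknown key such as "badkey" would raise ValueError here.
    tools.Ping().check_result_key("times.localhost")

    # Identically-configured tools produce equal cache keys, so a shared
    # DataCache runs the tool once and reuses the result.
    assert (
        ToolKey.from_tool(tools.Ping(hosts=["localhost"], count=1))
        == ToolKey.from_tool(tools.Ping(hosts=["localhost"], count=1))
    )

With atef/bin/main.py now dispatching coroutine subcommands through asyncio.run, a checkout can also be driven programmatically; this is roughly what the check subcommand does when --parallel is passed. The file path below is a placeholder for a real configuration:

    import asyncio

    from atef.bin.check import main as check_main

    asyncio.run(
        check_main(
            filename="checkout.yml",  # placeholder; point at a real config file
            parallel=True,            # pre-fill the shared DataCache concurrently
            verbose=1,
        )
    )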