Skip to content

Commit

Permalink
Merge pull request #114 from klauer/enh_tools
Browse files Browse the repository at this point in the history
ENH: "tools" such as ping in passive checkout procedures
  • Loading branch information
ZLLentz authored Jul 23, 2022
2 parents 03daedd + d6d4e04 commit a7c875f
Show file tree
Hide file tree
Showing 17 changed files with 1,721 additions and 168 deletions.
64 changes: 57 additions & 7 deletions atef/bin/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import argparse
import asyncio
import dataclasses
import logging
import pathlib
Expand All @@ -14,6 +15,7 @@
import rich.console
import rich.tree

from ..cache import DataCache, _SignalCache, get_signal_cache
from ..check import Comparison, Result, Severity
from ..config import (AnyConfiguration, Configuration, ConfigurationFile,
PathItem, PreparedComparison)
Expand Down Expand Up @@ -52,6 +54,12 @@ def build_arg_parser(argparser=None):
help="Limit checkout to the named device(s) or identifiers",
)

argparser.add_argument(
"-p", "--parallel",
action="store_true",
help="Acquire data for comparisons in parallel",
)

return argparser


Expand Down Expand Up @@ -256,18 +264,55 @@ def log_results_rich(
console.print(root)


def check_and_log(
async def check_and_log(
config: AnyConfiguration,
console: rich.console.Console,
verbose: int = 0,
client: Optional[happi.Client] = None,
name_filter: Optional[Sequence[str]] = None,
parallel: bool = True,
cache: Optional[DataCache] = None,
):
"""Check a configuration and log the results."""
"""
Check a configuration and log the results.
Parameters
----------
config : AnyConfiguration
The configuration to check.
console : rich.console.Console
The rich console to write output to.
verbose : int, optional
The verbosity level for the output.
client : happi.Client, optional
The happi client, if available.
name_filter : Sequence[str], optional
A filter for names.
parallel : bool, optional
Pre-fill cache in parallel when possible.
cache : DataCache
The data cache instance.
"""
items = []
name_filter = list(name_filter or [])
severities = []
for prepared in PreparedComparison.from_config(config, client=client):

if cache is None:
cache = DataCache()

all_prepared = list(
PreparedComparison.from_config(config, client=client, cache=cache)
)

cache_fill_tasks = []
if parallel:
for prepared in all_prepared:
if isinstance(prepared, PreparedComparison):
cache_fill_tasks.append(
asyncio.create_task(prepared.get_data_async())
)

for prepared in all_prepared:
if isinstance(prepared, PreparedComparison):
if name_filter:
device_name = getattr(prepared.device, "name", None)
Expand All @@ -285,7 +330,7 @@ def check_and_log(
)
continue

prepared.result = prepared.compare()
prepared.result = await prepared.compare()
if prepared.result is not None:
items.append(prepared)
severities.append(prepared.result.severity)
Expand All @@ -312,12 +357,14 @@ def check_and_log(
)


def main(
async def main(
filename: str,
name_filter: Optional[Sequence[str]] = None,
verbose: int = 0,
parallel: bool = False,
*,
cleanup: bool = True
cleanup: bool = True,
signal_cache: Optional[_SignalCache] = None,
):
path = pathlib.Path(filename)
if path.suffix.lower() == ".json":
Expand All @@ -333,15 +380,18 @@ def main(
client = None

console = rich.console.Console()
cache = DataCache(signals=signal_cache or get_signal_cache())
try:
with console.status("[bold green] Performing checks..."):
for config in config_file.configs:
check_and_log(
await check_and_log(
config,
console=console,
verbose=verbose,
client=client,
name_filter=name_filter,
parallel=parallel,
cache=cache,
)
finally:
if cleanup:
Expand Down
7 changes: 6 additions & 1 deletion atef/bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
"""

import argparse
import asyncio
import importlib
import logging
from inspect import iscoroutinefunction

import atef

Expand Down Expand Up @@ -91,7 +93,10 @@ def main():
if hasattr(args, 'func'):
func = kwargs.pop('func')
logger.debug('%s(**%r)', func.__name__, kwargs)
func(**kwargs)
if iscoroutinefunction(func):
asyncio.run(func(**kwargs))
else:
func(**kwargs)
else:
top_parser.print_help()

Expand Down
Loading

0 comments on commit a7c875f

Please sign in to comment.