Skip to content

Commit

Permalink
Merge branch 'stream-spin-ui' of github.com:alan-turing-institute/reg…
Browse files Browse the repository at this point in the history
…inald into stream-spin-ui
  • Loading branch information
rchan26 committed Jun 12, 2024
2 parents 7efa703 + 7b2347c commit 8430c66
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions reginald/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
REGINAL_PROMPT: Final[str] = "Reginald: "


def stream_progress_wrapper(
streamer: Generator | list | tuple | Callable | chain,
def stream_iter_progress_wrapper(
streamer: Iterable | Callable | chain,
task_str: str = REGINAL_PROMPT,
progress_bar: bool = True,
end: str = "\n",
end: str = "",
*args,
**kwargs,
) -> chain | Generator | list | tuple | Callable:
) -> Iterable:
"""Add a progress bar for iteration.
Examples
Expand All @@ -26,14 +26,12 @@ def stream_progress_wrapper(
... for nap in range(naps):
... sleep(1)
... yield f'nap: {nap}'
>>> tuple(stream_progress_wrapper(streamer=sleeper))
>>> tuple(stream_iter_progress_wrapper(streamer=sleeper))
<BLANKLINE>
Reginald:
('nap: 0', 'nap: 1', 'nap: 2')
>>> tuple(stream_progress_wrapper(
Reginald: ('nap: 0', 'nap: 1', 'nap: 2')
>>> tuple(stream_iter_progress_wrapper(
... streamer=sleeper, progress_bar=False))
Reginald:
('nap: 0', 'nap: 1', 'nap: 2')
Reginald: ('nap: 0', 'nap: 1', 'nap: 2')
"""
if isinstance(streamer, Callable):
streamer = streamer(*args, **kwargs)
Expand All @@ -53,6 +51,42 @@ def stream_progress_wrapper(
return streamer


def stream_progress_wrapper(
streamer: Callable,
task_str: str = REGINAL_PROMPT,
progress_bar: bool = True,
end: str = "\n",
*args,
**kwargs,
) -> chain | Generator | list | tuple | Callable:
"""Add a progress bar for iteration.
Examples
--------
>>> from time import sleep
>>> def sleeper(seconds: int = 3) -> str:
... sleep(seconds)
... return f'{seconds} seconds nap'
>>> stream_progress_wrapper(sleeper)
<BLANKLINE>
Reginald:
'3 seconds nap'
"""
if progress_bar:
with Progress(
TextColumn("{task.description}[progress.description]"),
SpinnerColumn(),
transient=True,
) as progress:
progress.add_task(task_str)
results: Any = streamer(*args, **kwargs)
print(task_str, end=end)
return results
else:
print(task_str, end=end)
return streamer(*args, **kwargs)


def get_env_var(
var: str, log: bool = True, secret_value: bool = True, default: str = None
) -> str | None:
Expand Down

0 comments on commit 8430c66

Please sign in to comment.