From 7b2347c29a6167265ea0f31c0a1674bc866528cd Mon Sep 17 00:00:00 2001 From: Dr Griffith Rees Date: Wed, 12 Jun 2024 17:17:33 +0100 Subject: [PATCH] feat: add `stream_iter_progress_wrapper` and refactor `stream_progress_wrapper` --- reginald/utils.py | 54 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/reginald/utils.py b/reginald/utils.py index e021cafe..00e9ad1e 100644 --- a/reginald/utils.py +++ b/reginald/utils.py @@ -9,14 +9,14 @@ REGINAL_PROMPT: Final[str] = "Reginald: " -def stream_progress_wrapper( - streamer: Generator | list | tuple | Callable | chain, +def stream_iter_progress_wrapper( + streamer: Iterable | Callable | chain, task_str: str = REGINAL_PROMPT, progress_bar: bool = True, - end: str = "\n", + end: str = "", *args, **kwargs, -) -> chain | Generator | list | tuple | Callable: +) -> Iterable: """Add a progress bar for iteration. Examples @@ -26,14 +26,12 @@ def stream_progress_wrapper( ... for nap in range(naps): ... sleep(1) ... yield f'nap: {nap}' - >>> tuple(stream_progress_wrapper(streamer=sleeper)) + >>> tuple(stream_iter_progress_wrapper(streamer=sleeper)) - Reginald: - ('nap: 0', 'nap: 1', 'nap: 2') - >>> tuple(stream_progress_wrapper( + Reginald: ('nap: 0', 'nap: 1', 'nap: 2') + >>> tuple(stream_iter_progress_wrapper( ... streamer=sleeper, progress_bar=False)) - Reginald: - ('nap: 0', 'nap: 1', 'nap: 2') + Reginald: ('nap: 0', 'nap: 1', 'nap: 2') """ if isinstance(streamer, Callable): streamer = streamer(*args, **kwargs) @@ -53,6 +51,42 @@ def stream_progress_wrapper( return streamer +def stream_progress_wrapper( + streamer: Callable, + task_str: str = REGINAL_PROMPT, + progress_bar: bool = True, + end: str = "\n", + *args, + **kwargs, +) -> chain | Generator | list | tuple | Callable: + """Add a progress bar for iteration. + + Examples + -------- + >>> from time import sleep + >>> def sleeper(seconds: int = 3) -> str: + ... sleep(seconds) + ... return f'{seconds} seconds nap' + >>> stream_progress_wrapper(sleeper) + + Reginald: + '3 seconds nap' + """ + if progress_bar: + with Progress( + TextColumn("{task.description}[progress.description]"), + SpinnerColumn(), + transient=True, + ) as progress: + progress.add_task(task_str) + results: Any = streamer(*args, **kwargs) + print(task_str, end=end) + return results + else: + print(task_str, end=end) + return streamer(*args, **kwargs) + + def get_env_var( var: str, log: bool = True, secret_value: bool = True, default: str = None ) -> str | None: