Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add time remaining column to progress bars #7273

Merged
merged 10 commits into from
Apr 26, 2024
4 changes: 2 additions & 2 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,8 @@ def apply_function_over_dataset(
out_dict = _DefaultTrace(n_pts)
indices = range(n_pts)

with Progress(console=Console(theme=progressbar_theme)) as progress:
task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress:
task = progress.add_task("Computinng ...", total=n_pts, visible=progressbar)
fonnesbeck marked this conversation as resolved.
Show resolved Hide resolved
for idx in indices:
out = fn(posterior_pts[idx])
fn.f.trust_input = True # If we arrive here the dtypes are valid
Expand Down
4 changes: 3 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
fonnesbeck marked this conversation as resolved.
Show resolved Hide resolved
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
for idx in np.arange(samples):
if nchain > 1:
Expand Down
4 changes: 2 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,8 @@ def _sample(
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, advance=1)
progress.update(task, advance=1, completed=True)
progress.update(task, refresh=True, advance=1)
progress.update(task, refresh=True, advance=1, completed=True)
except KeyboardInterrupt:
pass

Expand Down
6 changes: 5 additions & 1 deletion pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from rich.console import Console
from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme

from pymc.blocking import DictToArrayBijection
Expand Down Expand Up @@ -428,7 +428,10 @@ def __init__(
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self._show_progress = progressbar
self._divergences = 0
Expand Down Expand Up @@ -465,6 +468,7 @@ def __iter__(self):
self._divergences += 1
progress.update(
task,
refresh=True,
completed=self._completed_draws,
total=self._total_draws,
description=self._desc.format(self),
Expand Down
6 changes: 4 additions & 2 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cloudpickle
import numpy as np

from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from pymc.backends.base import BaseTrace
from pymc.initial_point import PointType
Expand Down Expand Up @@ -104,7 +104,7 @@ def _sample_population(
task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar)

for _ in sampling:
progress.update(task, advance=1)
progress.update(task, advance=1, refresh=True)

return

Expand Down Expand Up @@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
) as self._progress:
for c, stepper in enumerate(steppers):
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
Expand Down
14 changes: 12 additions & 2 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
import numpy as np

from arviz import InferenceData
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

import pymc

Expand Down Expand Up @@ -366,6 +372,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
with Progress(
TextColumn("{task.description}"),
SpinnerColumn(),
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
TextColumn("{task.fields[status]}"),
) as progress:
Expand Down Expand Up @@ -403,6 +411,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
stage = update_data["stage"]
beta = update_data["beta"]
# update the progress bar for this task:
progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id)
progress.update(
status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True
)

return tuple(cloudpickle.loads(r.result()) for r in futures)
3 changes: 2 additions & 1 deletion pymc/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def find_MAP(
if isinstance(e, StopIteration):
pm._log.info(e)
finally:
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval)
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True)
print(file=sys.stdout)

mx0 = RaveledVars(mx0, x0.point_map_info)
Expand Down Expand Up @@ -223,6 +223,7 @@ def __init__(
*Progress.get_default_columns(),
TextColumn("{task.fields[loss]}"),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="")

Expand Down
4 changes: 3 additions & 1 deletion pymc/variational/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def fit(
def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks):
i = 0
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Fitting", total=n, visible=progressbar)
for i in range(n):
step_func()
Expand Down
Loading