Skip to content

Commit

Permalink
improved eta
Browse files Browse the repository at this point in the history
  • Loading branch information
willmcgugan committed Mar 7, 2024
1 parent f55610e commit cb6b1fb
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 89 deletions.
96 changes: 96 additions & 0 deletions src/textual/eta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import bisect
from operator import itemgetter
from time import monotonic


class ETA:
"""Calculate speed and estimate time to arrival."""

def __init__(self, estimation_period: float = 30) -> None:
"""Create an ETA.
Args:
estimation_period: Period in seconds, used to calculate speed. Defaults to 30.
"""
self.estimation_period = estimation_period
self._start_time = monotonic()
self._samples: list[tuple[float, float]] = [(0.0, 0.0)]

def reset(self) -> None:
"""Start ETA calculations from current time."""
del self._samples[:]
self._start_time = monotonic()

@property
def _current_time(self) -> float:
return monotonic() - self._start_time

def add_sample(self, progress: float) -> None:
"""Add a new sample.
Args:
progress: Progress ratio (0 is start, 1 is complete).
"""
if self._samples and self._samples[-1][1] > progress:
# If progress goes backwards, we need to reset calculations
self.reset()
return
current_time = self._current_time
self._samples.append((current_time, progress))
if not (len(self._samples) % 100):
# Prune periodically so we don't accumulate vast amounts of samples
self._prune()

def _prune(self) -> None:
"""Prune old samples."""
if len(self._samples) <= 10:
# Keep at least 10 samples
return
prune_time = self._samples[-1][0] - self.estimation_period
index = bisect.bisect_left(self._samples, prune_time, key=itemgetter(0))
del self._samples[:index]

def _get_progress_at(self, time: float) -> tuple[float, float]:
"""Get the progress at a specific time."""
index = bisect.bisect_left(self._samples, time, key=itemgetter(0))
if index >= len(self._samples):
return self._samples[-1]
if index == 0:
return self._samples[0]
time1, progress1 = self._samples[index]
time2, progress2 = self._samples[index + 1]
factor = (time - time1) / (time2 - time1)
intermediate_progress = progress1 + (progress2 - progress1) * factor
return time, intermediate_progress

@property
def speed(self) -> float | None:
"""The current speed, or `None` if it couldn't be calculated."""

if len(self._samples) <= 2:
# Need at less 2 samples to calculate speed
return None

recent_sample_time, progress2 = self._samples[-1]
progress_start_time, progress1 = self._get_progress_at(
recent_sample_time - self.estimation_period
)
time_delta = recent_sample_time - progress_start_time
distance = progress2 - progress1
speed = distance / time_delta
return speed

@property
def eta(self) -> float | None:
"""Estimated seconds until completion, or `None` if no estimate can be made."""
current_time = self._current_time
if not self._samples:
return None
speed = self.speed
if not speed:
return None
recent_time, recent_progress = self._samples[-1]
time_since_sample = current_time - recent_time
remaining = 1.0 - (recent_progress + speed * time_since_sample)
eta = max(0, remaining / speed)
return eta
132 changes: 43 additions & 89 deletions src/textual/widgets/_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

from math import ceil
from time import monotonic
from typing import Callable, Optional
from typing import Optional

from rich.style import Style

from .._types import UnusedParameter
from ..app import ComposeResult, RenderResult
from ..eta import ETA
from ..geometry import clamp
from ..reactive import reactive
from ..renderables.bar import Bar as BarRenderable
from ..timer import Timer
from ..widget import Widget
from ..widgets import Label

Expand Down Expand Up @@ -80,7 +80,7 @@ def watch__percentage(self, percentage: float | None) -> None:
if percentage is not None:
self.auto_refresh = None
else:
self.auto_refresh = 1 / 15
self.auto_refresh = 1 / 5

def render(self) -> RenderResult:
"""Render the bar with the correct portion filled."""
Expand All @@ -105,6 +105,8 @@ def render_indeterminate(self) -> RenderResult:
# Width used to enable the visual effect of the bar going into the corners.
total_imaginary_width = width + highlighted_bar_width

start: float
end: float
if self.app.animation_level == "none":
start = 0
end = width
Expand Down Expand Up @@ -188,15 +190,7 @@ class ETAStatus(Label):
content-align-horizontal: right;
}
"""

_label_text: reactive[str] = reactive("", repaint=False)
"""This is used as an auxiliary reactive to only refresh the label when needed."""
_percentage: reactive[float | None] = reactive[Optional[float]](None)
"""The percentage of progress that has been completed."""
_refresh_timer: Timer | None
"""Timer to update ETA status even when progress stalls."""
_start_time: float | None
"""The time when the widget started tracking progress."""
eta: reactive[float | None] = reactive[Optional[float]](None)

def __init__(
self,
Expand All @@ -205,62 +199,24 @@ def __init__(
classes: str | None = None,
disabled: bool = False,
):
super().__init__(name=name, id=id, classes=classes, disabled=disabled)
self._percentage = None
self._label_text = "--:--:--"
self._start_time = None
self._refresh_timer = None

def on_mount(self) -> None:
"""Periodically refresh the countdown so that the ETA is always up to date."""
self._refresh_timer = self.set_interval(1 / 2, self.update_eta, pause=True)
super().__init__(
"--:--:--", name=name, id=id, classes=classes, disabled=disabled
)

def watch__percentage(self, percentage: float | None) -> None:
if percentage is None:
self._label_text = "--:--:--"
else:
if self._refresh_timer is not None:
self._refresh_timer.reset()
self.update_eta()

def update_eta(self) -> None:
"""Update the ETA display."""
percentage = self._percentage
delta = self._get_elapsed_time()
# We display --:--:-- if we haven't started, if we are done,
# or if we don't know when we started keeping track of time.
if not percentage or percentage >= 1 or not delta:
self._label_text = "--:--:--"
# If we are done, we can delete the timer that periodically refreshes
# the countdown display.
if percentage is not None and percentage >= 1:
self.auto_refresh = None
# Render a countdown timer with hh:mm:ss, unless it's a LONG time.
def render(self) -> RenderResult:
"""Render the ETA display."""
eta = self.eta
if eta is None:
return "--:--:--"
else:
left = ceil((delta / percentage) * (1 - percentage))
minutes, seconds = divmod(left, 60)
minutes, seconds = divmod(ceil(eta), 60)
hours, minutes = divmod(minutes, 60)
if hours > 999999:
self._label_text = "+999999h"
return "+999999h"
elif hours > 99:
self._label_text = f"{hours}h"
return f"{hours}h"
else:
self._label_text = f"{hours:02}:{minutes:02}:{seconds:02}"

def _get_elapsed_time(self) -> float:
"""Get time to estimate time to progress completion.
Returns:
The time elapsed since the bar started being animated.
"""
if self._start_time is None:
self._start_time = monotonic()
return 0
return monotonic() - self._start_time

def watch__label_text(self, label_text: str) -> None:
"""If the ETA label changed, update the renderable (which also refreshes)."""
self.update(label_text)
return f"{hours:02}:{minutes:02}:{seconds:02}"


class ProgressBar(Widget, can_focus=False):
Expand Down Expand Up @@ -296,6 +252,7 @@ class ProgressBar(Widget, can_focus=False):
print(progress_bar.percentage) # 0.5
```
"""
_display_eta: reactive[float | None] = reactive[Optional[float]](None)

def __init__(
self,
Expand Down Expand Up @@ -334,39 +291,28 @@ def key_space(self):
disabled: Whether the widget is disabled or not.
"""
super().__init__(name=name, id=id, classes=classes, disabled=disabled)
self.total = total
self.show_bar = show_bar
self.show_percentage = show_percentage
self.show_eta = show_eta
self._eta = ETA()

self.total = total

def compose(self) -> ComposeResult:
# We create a closure so that we can determine what are the sub-widgets
# that are present and, therefore, will need to be notified about changes
# to the percentage.
def update_percentage(
widget: Bar | PercentageStatus | ETAStatus,
) -> Callable[[float | None], None]:
"""Closure to allow updating the percentage of a given widget."""

def updater(percentage: float | None) -> None:
"""Update the percentage reactive of the enclosed widget."""
widget._percentage = percentage
def on_mount(self) -> None:
def refresh_eta() -> None:
"""Refresh eta display."""
self._display_eta = self._eta.eta

return updater
self.set_interval(1 / 2, refresh_eta)

def compose(self) -> ComposeResult:
if self.show_bar:
bar = Bar(id="bar")
self.watch(self, "percentage", update_percentage(bar))
yield bar
yield Bar(id="bar").data_bind(_percentage=ProgressBar.percentage)
if self.show_percentage:
percentage_status = PercentageStatus(id="percentage")
self.watch(self, "percentage", update_percentage(percentage_status))
yield percentage_status
PercentageStatus(id="percentage").data_bind(
_percentage=ProgressBar.percentage
)
if self.show_eta:
eta_status = ETAStatus(id="eta")
self.watch(self, "percentage", update_percentage(eta_status))
yield eta_status
yield ETAStatus(id="eta").data_bind(eta=ProgressBar._display_eta)

def _validate_progress(self, progress: float) -> float:
"""Clamp the progress between 0 and the maximum total."""
Expand Down Expand Up @@ -406,7 +352,7 @@ def advance(self, advance: float = 1) -> None:
Args:
advance: Number of steps to advance progress by.
"""
self.progress += advance
self.update(advance=advance)

def update(
self,
Expand All @@ -431,8 +377,16 @@ def update(
advance: Advance the progress by this number of steps.
"""
if not isinstance(total, UnusedParameter):
if total != self.total:
self._eta.reset()
self.total = total
if not isinstance(progress, UnusedParameter):

elif not isinstance(progress, UnusedParameter):
self.progress = progress
if not isinstance(advance, UnusedParameter):
if self.progress is not None and self.total is not None:
self._eta.add_sample(self.progress / self.total)

elif not isinstance(advance, UnusedParameter):
self.progress += advance
if self.progress is not None and self.total is not None:
self._eta.add_sample(self.progress / self.total)

0 comments on commit cb6b1fb

Please sign in to comment.