Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster gradients #4739

Merged
merged 12 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

- `TextArea.line_number_start` reactive attribute https://github.com/Textualize/textual/pull/4471
- Added "quality" parameter to `textual.color.Gradient` https://github.com/Textualize/textual/pull/4739
- Added `textual.color.Gradient.get_rich_color` https://github.com/Textualize/textual/pull/4739

### Fixed

Expand Down
96 changes: 77 additions & 19 deletions src/textual/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,11 @@
from rich.color_triplet import ColorTriplet
from typing_extensions import Final

from textual.css.scalar import percentage_string_to_float
from textual.css.tokenize import CLOSE_BRACE, COMMA, DECIMAL, OPEN_BRACE, PERCENT
from textual.suggestions import get_suggestion

from ._color_constants import COLOR_NAME_TO_RGB
from .css.scalar import percentage_string_to_float
from .css.tokenize import CLOSE_BRACE, COMMA, DECIMAL, OPEN_BRACE, PERCENT
from .geometry import clamp
from .suggestions import get_suggestion

_TRUECOLOR = ColorType.TRUECOLOR

Expand Down Expand Up @@ -224,6 +223,7 @@ def clamped(self) -> Color:
return color

@property
@lru_cache(1024)
def rich_color(self) -> RichColor:
"""This color encoded in Rich's Color class.

Expand Down Expand Up @@ -551,25 +551,74 @@ def get_contrast_text(self, alpha: float = 0.95) -> Color:
class Gradient:
"""Defines a color gradient."""

def __init__(self, *stops: tuple[float, Color]) -> None:
def __init__(self, *stops: tuple[float, Color | str], quality: int = 200) -> None:
"""Create a color gradient that blends colors to form a spectrum.

A gradient is defined by a sequence of "stops" consisting of a float and a color.
The stop indicate the color at that point on a spectrum between 0 and 1.
A gradient is defined by a sequence of "stops" consisting of a tuple containing a float and a color.
The stop indicates the color at that point on a spectrum between 0 and 1.
Colors may be given as a [Color][textual.color.Color] instance, or a string that
can be parsed into a Color (with [Color.parse][textual.color.Color.parse]).

The quality of the argument defines the number of _steps_ in the gradient.
200 was chosen so that there was no obvious banding in [LinearGradient][textual.renderables.gradient.LinearGradient].
Higher values are unlikely to yield any benefit, but lower values may result in quicker rendering.

Args:
stops: A colors stop.
stops: Color stops.
quality: The number of steps in the gradient.

Raises:
ValueError: If any stops are missing (must be at least a stop for 0 and 1).
"""
self._stops = sorted(stops)
parse = Color.parse
self._stops = sorted(
[
(
(position, parse(color))
if isinstance(color, str)
else (position, color)
)
for position, color in stops
]
)
if len(stops) < 2:
raise ValueError("At least 2 stops required.")
if self._stops[0][0] != 0.0:
raise ValueError("First stop must be 0.")
if self._stops[-1][0] != 1.0:
raise ValueError("Last stop must be 1.")
self._quality = quality
self._colors: list[Color] | None = None
self._rich_colors: list[RichColor] | None = None

@property
def colors(self) -> list[Color]:
"""A list of colors in the gradient."""
position = 0
quality = self._quality

if self._colors is None:
colors: list[Color] = []
add_color = colors.append
(stop1, color1), (stop2, color2) = self._stops[0:2]
for step_position in range(quality):
step = step_position / (quality - 1)
while step > stop2:
position += 1
(stop1, color1), (stop2, color2) = self._stops[
position : position + 2
]
add_color(color1.blend(color2, (step - stop1) / (stop2 - stop1)))
self._colors = colors
assert len(self._colors) == self._quality
return self._colors

@property
def rich_colors(self) -> list[RichColor]:
"""A list of colors in the gradient (for the Rich library)."""
if self._rich_colors is None:
self._rich_colors = [color.rich_color for color in self.colors]
return self._rich_colors

def get_color(self, position: float) -> Color:
"""Get a color from the gradient at a position between 0 and 1.
Expand All @@ -580,17 +629,26 @@ def get_color(self, position: float) -> Color:
position: A number between 0 and 1, where 0 is the first stop, and 1 is the last.

Returns:
A color.
A Textual color.
"""
# TODO: consider caching
position = clamp(position, 0.0, 1.0)
for (stop1, color1), (stop2, color2) in zip(self._stops, self._stops[1:]):
if stop2 >= position >= stop1:
return color1.blend(
color2,
(position - stop1) / (stop2 - stop1),
)
raise AssertionError("Can't get here if `_stops` is valid")
quality = self._quality - 1
color_index = int(clamp(position * quality, 0, quality))
return self.colors[color_index]

def get_rich_color(self, position: float) -> RichColor:
"""Get a (Rich) color from the gradient at a position between 0 and 1.

Positions that are between stops will return a blended color.

Args:
position: A number between 0 and 1, where 0 is the first stop, and 1 is the last.

Returns:
A (Rich) color.
"""
quality = self._quality - 1
color_index = int(clamp(position * quality, 0, quality))
return self.rich_colors[color_index]


# Color constants
Expand Down
39 changes: 10 additions & 29 deletions src/textual/renderables/gradient.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations

from functools import lru_cache
from math import cos, pi, sin
from typing import Sequence

from rich.color import Color as RichColor
from rich.console import Console, ConsoleOptions, RenderResult
from rich.segment import Segment
from rich.style import Style
Expand Down Expand Up @@ -59,6 +57,7 @@ def __init__(
(stop, Color.parse(color) if isinstance(color, str) else color)
for stop, color in stops
]
self._color_gradient = Gradient(*self._stops)

def __rich_console__(
self, console: Console, options: ConsoleOptions
Expand All @@ -75,47 +74,29 @@ def __rich_console__(

new_line = Segment.line()

color_gradient = Gradient(*self._stops)

_Segment = Segment
get_color = color_gradient.get_color
get_color = self._color_gradient.get_rich_color
from_color = Style.from_color

@lru_cache(maxsize=1024)
def get_rich_color(color_offset: int) -> RichColor:
"""Get a Rich color in the gradient.

Args:
color_index: A offset within the color gradient normalized between 0 and 255.

Returns:
A Rich color.
"""
return get_color(color_offset / 255).rich_color

for line_y in range(height):
point_y = float(line_y) * 2 - center_y
point_x = 0 - center_x

x1 = (center_x + (point_x * cos_angle - point_y * sin_angle)) / width * 255
x1 = (center_x + (point_x * cos_angle - point_y * sin_angle)) / width
x2 = (
(center_x + (point_x * cos_angle - (point_y + 1.0) * sin_angle))
/ width
* 255
)
center_x + (point_x * cos_angle - (point_y + 1.0) * sin_angle)
) / width
point_x = width - center_x
end_x1 = (
(center_x + (point_x * cos_angle - point_y * sin_angle)) / width * 255
)
end_x1 = (center_x + (point_x * cos_angle - point_y * sin_angle)) / width
delta_x = (end_x1 - x1) / width

if abs(delta_x) < 0.0001:
# Special case for verticals
yield _Segment(
"▀" * width,
from_color(
get_rich_color(int(x1)),
get_rich_color(int(x2)),
get_color(x1),
get_color(x2),
),
)

Expand All @@ -124,8 +105,8 @@ def get_rich_color(color_offset: int) -> RichColor:
_Segment(
"▀",
from_color(
get_rich_color(int(x1 + x * delta_x)),
get_rich_color(int(x2 + x * delta_x)),
get_color(x1 + x * delta_x),
get_color(x2 + x * delta_x),
),
)
for x in range(width)
Expand Down
7 changes: 2 additions & 5 deletions tests/test_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ def test_gradient_errors():
def test_gradient():
gradient = Gradient(
(0, Color(255, 0, 0)),
(0.5, Color(0, 0, 255)),
(0.5, "blue"),
(1, Color(0, 255, 0)),
quality=11,
)

assert gradient.get_color(-1) == Color(255, 0, 0)
Expand All @@ -255,7 +256,3 @@ def test_gradient():
assert gradient.get_color(1.2) == Color(0, 255, 0)
assert gradient.get_color(0.5) == Color(0, 0, 255)
assert gradient.get_color(0.7) == Color(0, 101, 153)

gradient._stops.pop()
with pytest.raises(AssertionError):
gradient.get_color(1.0)
Loading