Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new method of specifying seed and using warnings.warn #3908

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions manim/_config/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def make_logger(
keywords=HIGHLIGHTED_KEYWORDS,
)

# redirect warnings.warn to logging.warn
logging.captureWarnings(True)
py_warning = logging.getLogger("py.warnings")
py_warning.addHandler(rich_handler)

# finally, the logger
logger = logging.getLogger("manim")
logger.addHandler(rich_handler)
Expand Down
35 changes: 29 additions & 6 deletions manim/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import threading
import time
import types
import warnings
from queue import Queue

import srt
Expand Down Expand Up @@ -96,19 +97,41 @@ class MyScene(Scene):
def construct(self):
self.play(Write(Text("Hello World!")))

You can specify the seed for random functions inside Manim via :attr:`Scene.random_seed`.

.. code-block:: python

class MyScene(Scene):
random_seed = 3

def construct(self): ...

"""

random_seed: int | None = None

def __init__(
self,
renderer=None,
camera_class=Camera,
always_update_mobjects=False,
random_seed=None,
skip_animations=False,
renderer: OpenGLRenderer | CairoRenderer | None = None,
camera_class: type[Camera] = Camera,
always_update_mobjects: bool = False,
random_seed: int | None = None,
skip_animations: bool = False,
):
self.camera_class = camera_class
self.always_update_mobjects = always_update_mobjects
self.random_seed = random_seed

if random_seed is not None:
warnings.warn(
"Setting the random seed in the Scene constructor is deprecated. "
"Please set the class attribute random_seed instead.",
category=FutureWarning,
# assuming no scene subclassing, this should refer
# to instantiation of the Scene
stacklevel=4,
)
self.random_seed = random_seed

self.skip_animations = skip_animations

self.animations = None
Expand Down
64 changes: 47 additions & 17 deletions manim/utils/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
import inspect
import logging
import re
from collections.abc import Iterable
from typing import Any, Callable
import warnings
from collections.abc import Callable, Iterable
from typing import Any, TypeVar, cast, overload

from decorator import decorate, decorator
from typing_extensions import ParamSpec

logger = logging.getLogger("manim")

T = TypeVar("T")
P = ParamSpec("P")


def _get_callable_info(callable_: Callable, /) -> tuple[str, str]:
"""Returns type and name of a callable.
Expand Down Expand Up @@ -70,13 +75,33 @@ def _deprecation_text_component(
return f"deprecated {since}and {until}.{msg}"


@overload
def deprecated(
func: None = None,
since: str | None = None,
until: str | None = None,
replacement: str | None = None,
message: str = "",
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
Dismissed Show dismissed Hide dismissed


@overload
def deprecated(
func: Callable = None,
func: Callable[P, T],
since: str | None = None,
until: str | None = None,
replacement: str | None = None,
message: str | None = "",
) -> Callable:
message: str = "",
) -> Callable[P, T]: ...
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed


def deprecated(
func: Callable[P, T] | None = None,
since: str | None = None,
until: str | None = None,
replacement: str | None = None,
message: str = "",
) -> Callable[P, T] | Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator to mark a callable as deprecated.

The decorated callable will cause a warning when used. The docstring of the
Expand Down Expand Up @@ -160,7 +185,11 @@ def foo():
"""
# If used as factory:
if func is None:
return lambda func: deprecated(func, since, until, replacement, message)

def wrapper(f: Callable[P, T]) -> Callable[P, T]:
return deprecated(f, since, until, replacement, message)

return wrapper

what, name = _get_callable_info(func)

Expand All @@ -187,7 +216,7 @@ def warning_msg(for_docs: bool = False) -> str:
deprecated = _deprecation_text_component(since, until, msg)
return f"The {what} {name} has been {deprecated}"

def deprecate_docs(func: Callable):
def deprecate_docs(func: Callable[..., object]) -> None:
"""Adjust docstring to indicate the deprecation.

Parameters
Expand All @@ -199,7 +228,7 @@ def deprecate_docs(func: Callable):
doc_string = func.__doc__ or ""
func.__doc__ = f"{doc_string}\n\n.. attention:: Deprecated\n {warning}"

def deprecate(func: Callable, *args, **kwargs):
def deprecate(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""The actual decorator used to extend the callables behavior.

Logs a warning message.
Expand All @@ -219,12 +248,12 @@ def deprecate(func: Callable, *args, **kwargs):
The return value of the given callable when being passed the given
arguments.
"""
logger.warning(warning_msg())
warnings.warn(warning_msg(), category=FutureWarning, stacklevel=1)
return func(*args, **kwargs)

if type(func).__name__ != "function":
deprecate_docs(func)
func.__init__ = decorate(func.__init__, deprecate)
func.__init__ = decorate(func.__init__, deprecate) # type: ignore[misc]
return func

func = decorate(func, deprecate)
Expand All @@ -236,10 +265,10 @@ def deprecated_params(
params: str | Iterable[str] | None = None,
since: str | None = None,
until: str | None = None,
message: str | None = "",
message: str = "",
redirections: None
| (Iterable[tuple[str, str] | Callable[..., dict[str, Any]]]) = None,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator to mark parameters of a callable as deprecated.

It can also be used to automatically redirect deprecated parameter values to their
Expand Down Expand Up @@ -426,7 +455,7 @@ def foo(**kwargs):

redirections = list(redirections)

def warning_msg(func: Callable, used: list[str]):
def warning_msg(func: Callable, used: list[str]) -> str:
"""Generate the deprecation warning message.

Parameters
Expand All @@ -449,7 +478,7 @@ def warning_msg(func: Callable, used: list[str]):
deprecated = _deprecation_text_component(since, until, message)
return f"The parameter{parameter_s} {used_} of {what} {name} {has_have_been} {deprecated}"

def redirect_params(kwargs: dict, used: list[str]):
def redirect_params(kwargs: dict[str, Any], used: list[str]) -> None:
"""Adjust the keyword arguments as defined by the redirections.

Parameters
Expand All @@ -473,7 +502,7 @@ def redirect_params(kwargs: dict, used: list[str]):
if len(redirector_args) > 0:
kwargs.update(redirector(**redirector_args))

def deprecate_params(func, *args, **kwargs):
def deprecate_params(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""The actual decorator function used to extend the callables behavior.

Logs a warning message when a deprecated parameter is used and redirects it if
Expand Down Expand Up @@ -501,8 +530,9 @@ def deprecate_params(func, *args, **kwargs):
used.append(param)

if len(used) > 0:
logger.warning(warning_msg(func, used))
warnings.warn(warning_msg(func, used), category=FutureWarning, stacklevel=1)
redirect_params(kwargs, used)
return func(*args, **kwargs)

return decorator(deprecate_params)
# unfortunately, decorator has really bad stubs that involve Any
return cast(Callable, decorator(deprecate_params))
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ ignore_errors = True
[mypy-manim.utils.*]
ignore_errors = True

[mypy-manim.utils.deprecation]
ignore_errors = False

[mypy-manim.utils.iterables]
ignore_errors = False
warn_return_any = False
Expand Down
24 changes: 24 additions & 0 deletions tests/module/scene/test_scene.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import datetime
import random

import numpy as np
import pytest

from manim import Circle, FadeIn, Group, Mobject, Scene, Square
Expand Down Expand Up @@ -101,3 +103,25 @@ def assert_names(mobjs, names):
scene.replace(second, beta)
assert_names(scene.mobjects, ["alpha", "group", "fourth"])
assert_names(scene.mobjects[1], ["beta", "third"])


def test_random_seed():
class GoodScene(Scene):
random_seed = 3

class BadScene(Scene):
def __init__(self):
super().__init__(random_seed=3)

good = GoodScene()
assert good.random_seed == 3
random_random = random.random()
np_random = np.random.random()

with pytest.warns(FutureWarning):
bad = BadScene()
assert bad.random_seed == 3

# check that they both actually set the seed
assert random.random() == random_random
assert np.random.random() == np_random
Loading
Loading