diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f8419466..4814c681 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.5 + rev: v18.1.6 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.7 + rev: v0.4.9 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -43,7 +43,7 @@ repos: hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 + rev: v3.16.0 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -52,7 +52,7 @@ repos: ^examples/ ) - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 7.1.0 hooks: - id: flake8 additional_dependencies: diff --git a/pyproject.toml b/pyproject.toml index ed93944a..d343e04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,6 +235,7 @@ extend-exclude = ["examples"] select = [ "E", "W", # pycodestyle "F", # pyflakes + "C90", # mccabe "UP", # pyupgrade "ANN", # flake8-annotations "S", # flake8-bandit @@ -243,7 +244,10 @@ select = [ "COM", # flake8-commas "C4", # flake8-comprehensions "EXE", # flake8-executable + "FA", # flake8-future-annotations + "LOG", # flake8-logging "ISC", # flake8-implicit-str-concat + "INP", # flake8-no-pep420 "PIE", # flake8-pie "PYI", # flake8-pyi "Q", # flake8-quotes @@ -251,6 +255,10 @@ select = [ "RET", # flake8-return "SIM", # flake8-simplify "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "PERF", # perflint + "FURB", # refurb + "TRY", # tryceratops "RUF", # ruff ] ignore = [ @@ -268,9 +276,9 @@ ignore = [ # S101: use of `assert` detected # internal use and may never raise at runtime "S101", - # PLR0402: use from {module} import {name} in lieu of alias - # use alias for import convention (e.g., `import torch.nn as nn`) - "PLR0402", + # TRY003: avoid specifying long messages outside the exception class + # long messages are necessary for clarity + "TRY003", ] typing-modules = ["torchopt.typing"] @@ -296,6 +304,9 @@ typing-modules = ["torchopt.typing"] "F401", # unused-import "F811", # redefined-while-unused ] +"docs/source/conf.py" = [ + "INP001", # flake8-no-pep420 +] [tool.ruff.lint.flake8-annotations] allow-star-arg-any = true diff --git a/tests/helpers.py b/tests/helpers.py index 0dc415d4..ca5aa443 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,7 +20,7 @@ import itertools import os import random -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import numpy as np import pytest @@ -30,7 +30,10 @@ from torch.utils import data from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree BATCH_SIZE = 64 diff --git a/tests/test_alias.py b/tests/test_alias.py index 58b5a328..3c42d7c8 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import functorch import pytest @@ -26,7 +26,10 @@ import torchopt from torchopt import pytree from torchopt.alias.utils import _set_use_chain_flat -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree @helpers.parametrize( diff --git a/tests/test_implicit.py b/tests/test_implicit.py index ff0ba15c..6cccb716 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -18,7 +18,7 @@ import copy import re 
from collections import OrderedDict -from types import FunctionType +from typing import TYPE_CHECKING import functorch import numpy as np @@ -47,6 +47,10 @@ HAS_JAX = False +if TYPE_CHECKING: + from types import FunctionType + + BATCH_SIZE = 8 NUM_UPDATES = 3 @@ -123,7 +127,7 @@ def get_rr_dataset_torch() -> data.DataLoader: inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml_solve_normal_cg( +def test_imaml_solve_normal_cg( # noqa: C901 dtype: torch.dtype, lr: float, inner_lr: float, @@ -251,7 +255,7 @@ def outer_level(p, xs, ys): inner_update=[20, 50, 100], ns=[False, True], ) -def test_imaml_solve_inv( +def test_imaml_solve_inv( # noqa: C901 dtype: torch.dtype, lr: float, inner_lr: float, @@ -375,7 +379,12 @@ def outer_level(p, xs, ys): inner_lr=[2e-2, 2e-3], inner_update=[20, 50, 100], ) -def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None: +def test_imaml_module( # noqa: C901 + dtype: torch.dtype, + lr: float, + inner_lr: float, + inner_update: int, +) -> None: np_dtype = helpers.dtype_torch2numpy(dtype) jax_model, jax_params = get_model_jax(dtype=np_dtype) @@ -763,7 +772,7 @@ def solve(self): make_optimality_from_objective(MyModule2) -def test_module_abstract_methods() -> None: +def test_module_abstract_methods() -> None: # noqa: C901 class MyModule1(torchopt.nn.ImplicitMetaGradientModule): def objective(self): return torch.tensor(0.0) @@ -809,7 +818,7 @@ def solve(self): class MyModule5(torchopt.nn.ImplicitMetaGradientModule): @classmethod - def optimality(self): + def optimality(cls): return () def solve(self): @@ -846,7 +855,7 @@ def solve(self): class MyModule8(torchopt.nn.ImplicitMetaGradientModule): @classmethod - def objective(self): + def objective(cls): return () def solve(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index 5215e7b3..57c35e47 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,6 +13,8 @@ # limitations under the License. 
# ============================================================================== +import operator + import torch import torchopt @@ -80,7 +82,7 @@ def test_module_clone() -> None: assert y.is_cuda -def test_extract_state_dict(): +def test_extract_state_dict(): # noqa: C901 fc = torch.nn.Linear(1, 1) state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta')) for param_dict in state_dict.params: @@ -121,7 +123,7 @@ def test_extract_state_dict(): loss = fc(torch.ones(1, 1)).sum() optim.step(loss) state_dict = torchopt.extract_state_dict(optim) - same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups)) + same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups)) assert all(pytree.tree_flatten(same)[0]) diff --git a/torchopt/__init__.py b/torchopt/__init__.py index 5e568526..830072e3 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -81,50 +81,50 @@ __all__ = [ - 'accelerated_op_available', - 'adam', - 'adamax', - 'adadelta', - 'radam', - 'adamw', - 'adagrad', - 'rmsprop', - 'sgd', - 'clip_grad_norm', - 'nan_to_num', - 'register_hook', - 'chain', - 'Optimizer', 'SGD', - 'Adam', - 'AdaMax', - 'Adamax', 'AdaDelta', - 'Adadelta', - 'RAdam', - 'AdamW', 'AdaGrad', + 'AdaMax', + 'Adadelta', 'Adagrad', - 'RMSProp', - 'RMSprop', - 'MetaOptimizer', - 'MetaSGD', - 'MetaAdam', - 'MetaAdaMax', - 'MetaAdamax', + 'Adam', + 'AdamW', + 'Adamax', + 'FuncOptimizer', 'MetaAdaDelta', - 'MetaAdadelta', - 'MetaRAdam', - 'MetaAdamW', 'MetaAdaGrad', + 'MetaAdaMax', + 'MetaAdadelta', 'MetaAdagrad', + 'MetaAdam', + 'MetaAdamW', + 'MetaAdamax', + 'MetaOptimizer', + 'MetaRAdam', 'MetaRMSProp', 'MetaRMSprop', - 'FuncOptimizer', + 'MetaSGD', + 'Optimizer', + 'RAdam', + 'RMSProp', + 'RMSprop', + 'accelerated_op_available', + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', 'apply_updates', + 'chain', + 'clip_grad_norm', 'extract_state_dict', - 'recover_state_dict', - 'stop_gradient', 'module_clone', 'module_detach_', + 'nan_to_num', + 'radam', + 'recover_state_dict', + 'register_hook', + 'rmsprop', + 'sgd', + 'stop_gradient', ] diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index 103b6fc0..90452046 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -16,12 +16,15 @@ from __future__ import annotations -from typing import Iterable +from typing import TYPE_CHECKING, Iterable import torch from torchopt.accelerated_op.adam_op import AdamOp -from torchopt.typing import Device + + +if TYPE_CHECKING: + from torchopt.typing import Device def is_available(devices: Device | Iterable[Device] | None = None) -> bool: @@ -42,6 +45,6 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool: return False updates = torch.tensor(1.0, device=device) op(updates, updates, updates, 1) - return True except Exception: # noqa: BLE001 # pylint: disable=broad-except return False + return True diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index bc999766..d7f9796d 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -18,7 +18,11 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import torch def forward_( diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index 3cfb5b8b..5767c5d7 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -41,4 +41,13 @@ from 
torchopt.alias.sgd import sgd -__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd'] +__all__ = [ + 'adadelta', + 'adagrad', + 'adam', + 'adamax', + 'adamw', + 'radam', + 'rmsprop', + 'sgd', +] diff --git a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py index fb0b551a..910cb13e 100644 --- a/torchopt/alias/adadelta.py +++ b/torchopt/alias/adadelta.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_adadelta -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adadelta'] diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index 9419e908..0ae0eb8e 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -33,6 +33,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -40,7 +42,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adam'] diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py index f80c0c2f..3da16713 100644 --- a/torchopt/alias/adamax.py +++ b/torchopt/alias/adamax.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_adamax -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['adamax'] diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index 38d4d5ac..2dc72ef1 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -33,7 +33,7 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable from torchopt.alias.utils import ( _get_use_chain_flat, @@ -42,7 +42,10 @@ ) from torchopt.combine import chain from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule __all__ = ['adamw'] diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py index 56d3d3d5..9e2880ee 100644 --- a/torchopt/alias/radam.py +++ b/torchopt/alias/radam.py @@ -16,6 +16,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt.alias.utils import ( _get_use_chain_flat, flip_sign_and_add_weight_decay, @@ -23,7 +25,10 @@ ) from torchopt.combine import chain from torchopt.transform import scale_by_radam -from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +if TYPE_CHECKING: + from torchopt.typing import GradientTransformation, ScalarOrSchedule __all__ = ['radam'] diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 49f8784d..0f41e822 100644 --- 
a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -16,14 +16,18 @@ from __future__ import annotations import threading - -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform import scale, scale_by_schedule from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates __all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] @@ -68,7 +72,7 @@ def _flip_sign_and_add_weight_decay_flat( ) -def _flip_sign_and_add_weight_decay( +def _flip_sign_and_add_weight_decay( # noqa: C901 weight_decay: float = 0.0, maximize: bool = False, *, diff --git a/torchopt/base.py b/torchopt/base.py index 572708e2..81892e17 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -44,10 +44,10 @@ __all__ = [ + 'ChainedGradientTransformation', 'EmptyState', - 'UninitializedState', 'GradientTransformation', - 'ChainedGradientTransformation', + 'UninitializedState', 'identity', ] diff --git a/torchopt/clip.py b/torchopt/clip.py index 55ae83fc..d64afc58 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -19,11 +19,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['clip_grad_norm'] diff --git a/torchopt/combine.py b/torchopt/combine.py index 158ec982..15345286 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -33,9 +33,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from torchopt import pytree from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['chain', 'chain_flat'] diff --git a/torchopt/diff/implicit/__init__.py b/torchopt/diff/implicit/__init__.py index 21737015..4cff14c6 100644 --- a/torchopt/diff/implicit/__init__.py +++ b/torchopt/diff/implicit/__init__.py @@ -19,4 +19,4 @@ from torchopt.diff.implicit.nn import ImplicitMetaGradientModule -__all__ = ['custom_root', 'ImplicitMetaGradientModule'] +__all__ = ['ImplicitMetaGradientModule', 'custom_root'] diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index d3efda2c..11ba0153 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -37,20 +37,23 @@ import functools import inspect -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple import functorch import torch from torch.autograd import Function from torchopt import linear_solve, pytree -from torchopt.typing import ( - ListOfOptionalTensors, - ListOfTensors, - TensorOrTensors, - TupleOfOptionalTensors, - TupleOfTensors, -) + + +if TYPE_CHECKING: + from torchopt.typing import ( + ListOfOptionalTensors, + ListOfTensors, + TensorOrTensors, + TupleOfOptionalTensors, + TupleOfTensors, + ) __all__ = ['custom_root'] @@ -253,7 +256,7 @@ def _merge_tensor_and_others( # pylint: disable-next=too-many-arguments,too-many-statements -def _custom_root( 
+def _custom_root( # noqa: C901 solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], optimality_fn: Callable[..., TensorOrTensors], solve: Callable[..., TensorOrTensors], @@ -271,7 +274,7 @@ def _custom_root( fn = getattr(reference_signature, 'subfn', reference_signature) reference_signature = inspect.signature(fn) - def make_custom_vjp_solver_fn( + def make_custom_vjp_solver_fn( # noqa: C901 solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], args_signs: tuple[tuple[int, int, type[tuple | list] | None], ...], diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index 8719f675..6b214cb8 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -22,15 +22,19 @@ import functools import inspect import itertools -from typing import Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable import functorch -import torch from torchopt.diff.implicit.decorator import custom_root from torchopt.nn.module import MetaGradientModule from torchopt.nn.stateless import reparametrize, swap_state -from torchopt.typing import LinearSolver, TupleOfTensors + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import LinearSolver, TupleOfTensors __all__ = ['ImplicitMetaGradientModule'] diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py index f00e097a..4369f4e5 100644 --- a/torchopt/diff/zero_order/__init__.py +++ b/torchopt/diff/zero_order/__init__.py @@ -25,7 +25,7 @@ from torchopt.diff.zero_order.nn import ZeroOrderGradientModule -__all__ = ['zero_order', 'ZeroOrderGradientModule'] +__all__ = ['ZeroOrderGradientModule', 'zero_order'] class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index b1126636..e498b43c 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -17,6 +17,7 @@ from __future__ import annotations import functools +import itertools from typing import Any, Callable, Literal, Sequence from typing_extensions import TypeAlias # Python 3.10+ @@ -43,7 +44,7 @@ def sample( return self.sample_fn(sample_shape) -def _zero_order_naive( # pylint: disable=too-many-statements +def _zero_order_naive( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -51,7 +52,7 @@ def _zero_order_naive( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -81,7 +82,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors) @@ -122,9 +123,9 @@ def add_perturbation( for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - flat_noisy_params = [ - 
add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] - ] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params, @@ -149,7 +150,7 @@ def add_perturbation( return apply -def _zero_order_forward( # pylint: disable=too-many-statements +def _zero_order_forward( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -157,7 +158,7 @@ def _zero_order_forward( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -187,7 +188,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors, output) @@ -226,9 +227,9 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] - ] + flat_noisy_params = list( + itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params, @@ -254,7 +255,7 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: return apply -def _zero_order_antithetic( # pylint: disable=too-many-statements +def _zero_order_antithetic( # noqa: C901 # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, argnums: tuple[int, ...], @@ -262,7 +263,7 @@ def _zero_order_antithetic( # pylint: disable=too-many-statements sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) - def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements + def apply(*args: Any) -> torch.Tensor: # noqa: C901 # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] @@ -292,7 +293,7 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: output = fn(*origin_args) if not isinstance(output, torch.Tensor): - raise RuntimeError('`output` must be a tensor.') + raise TypeError('`output` must be a tensor.') if output.ndim != 0: raise RuntimeError('`output` must be a scalar tensor.') ctx.save_for_backward(*flat_diff_params, *tensors) diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index 7ac12bb4..eeddabeb 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -20,14 +20,17 @@ import abc import functools -from typing import Any, Sequence 
+from typing import TYPE_CHECKING, Any, Sequence import torch import torch.nn as nn from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order from torchopt.nn.stateless import reparametrize -from torchopt.typing import Numeric, TupleOfTensors + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, TupleOfTensors __all__ = ['ZeroOrderGradientModule'] diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 117af9ab..97be682f 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -42,15 +42,15 @@ __all__ = [ 'TensorDimensionPartitioner', - 'dim_partitioner', 'batch_partitioner', + 'dim_partitioner', 'mean_reducer', - 'sum_reducer', - 'remote_async_call', - 'remote_sync_call', 'parallelize', 'parallelize_async', 'parallelize_sync', + 'remote_async_call', + 'remote_sync_call', + 'sum_reducer', ] @@ -107,7 +107,7 @@ def __init__( self.workers = workers # pylint: disable-next=too-many-branches,too-many-locals - def __call__( + def __call__( # noqa: C901 self, *args: Any, **kwargs: Any, @@ -310,7 +310,7 @@ def remote_async_call( elif callable(partitioner): partitions = partitioner(*args, **kwargs) # type: ignore[assignment] else: - raise ValueError(f'Invalid partitioner: {partitioner!r}.') + raise TypeError(f'Invalid partitioner: {partitioner!r}.') futures = [] for rank, worker_args, worker_kwargs in partitions: diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index f7da4f46..71afdb86 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -17,15 +17,18 @@ from __future__ import annotations from threading import Lock +from typing import TYPE_CHECKING import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context -from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors +if TYPE_CHECKING: + from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors -__all__ = ['is_available', 'context'] + +__all__ = ['context', 'is_available'] LOCK = Lock() @@ -121,7 +124,7 @@ def grad( for p in inputs: try: grads.append(all_local_grads[p]) - except KeyError as ex: + except KeyError as ex: # noqa: PERF203 if not allow_unused: raise RuntimeError( 'One of the differentiated Tensors appears to not have been used in the ' @@ -131,4 +134,4 @@ def grad( return tuple(grads) - __all__ += ['DistAutogradContext', 'get_gradients', 'backward', 'grad'] + __all__ += ['DistAutogradContext', 'backward', 'get_gradients', 'grad'] diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index a61280c5..610e52a0 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -26,19 +26,19 @@ __all__ = [ - 'get_world_info', - 'get_world_rank', - 'get_rank', - 'get_world_size', + 'auto_init_rpc', + 'barrier', 'get_local_rank', 'get_local_world_size', + 'get_rank', 'get_worker_id', - 'barrier', - 'auto_init_rpc', - 'on_rank', + 'get_world_info', + 'get_world_rank', + 'get_world_size', 'not_on_rank', - 'rank_zero_only', + 'on_rank', 'rank_non_zero_only', + 'rank_zero_only', ] diff --git a/torchopt/hook.py b/torchopt/hook.py index b51e29eb..c11b92f6 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -16,16 +16,19 @@ from __future__ import annotations -from typing import Callable - -import torch +from typing import TYPE_CHECKING, Callable from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates -__all__ = ['zero_nan_hook', 
'nan_to_num_hook', 'register_hook'] +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['nan_to_num_hook', 'register_hook', 'zero_nan_hook'] def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index a82ff877..1096a5af 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -36,14 +36,17 @@ from __future__ import annotations from functools import partial -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree from torchopt.linalg.utils import cat_shapes, normalize_matvec from torchopt.pytree import tree_vdot_real -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree __all__ = ['cg'] diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index b049a5ad..5fc8d478 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -19,13 +19,16 @@ from __future__ import annotations import functools -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree from torchopt.linalg.utils import normalize_matvec -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree __all__ = ['ns', 'ns_inv'] diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index a5ac765d..bbcc80aa 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -17,12 +17,15 @@ from __future__ import annotations import itertools -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree def cat_shapes(tree: TensorTree) -> tuple[int, ...]: diff --git a/torchopt/linear_solve/__init__.py b/torchopt/linear_solve/__init__.py index 2d61eb6d..43ca1da0 100644 --- a/torchopt/linear_solve/__init__.py +++ b/torchopt/linear_solve/__init__.py @@ -36,4 +36,4 @@ from torchopt.linear_solve.normal_cg import solve_normal_cg -__all__ = ['solve_cg', 'solve_normal_cg', 'solve_inv'] +__all__ = ['solve_cg', 'solve_inv', 'solve_normal_cg'] diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index f4127639..23814cc2 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -36,11 +36,14 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_cg'] diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index f37be8c5..4dbe1542 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -36,13 +36,16 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import torch from torchopt import linalg, pytree from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_inv'] diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 405ab43c..a5af49b2 100644 --- 
a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -36,11 +36,14 @@ from __future__ import annotations import functools -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec -from torchopt.typing import LinearSolver, TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_normal_cg'] diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 5e4bf7bd..9d1b8779 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -33,12 +33,15 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import functorch from torchopt import pytree -from torchopt.typing import TensorTree + + +if TYPE_CHECKING: + from torchopt.typing import TensorTree def make_rmatvec( diff --git a/torchopt/nn/__init__.py b/torchopt/nn/__init__.py index 7665f201..b55e49d7 100644 --- a/torchopt/nn/__init__.py +++ b/torchopt/nn/__init__.py @@ -21,10 +21,10 @@ __all__ = [ - 'MetaGradientModule', 'ImplicitMetaGradientModule', + 'MetaGradientModule', 'ZeroOrderGradientModule', - 'reparametrize', 'reparameterize', + 'reparametrize', 'swap_state', ] diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 419afb6a..8c40f58a 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -17,14 +17,17 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any, Iterator, NamedTuple +from typing import TYPE_CHECKING, Any, Iterator, NamedTuple from typing_extensions import Self # Python 3.11+ import torch import torch.nn as nn from torchopt import pytree -from torchopt.typing import TensorContainer + + +if TYPE_CHECKING: + from torchopt.typing import TensorContainer class MetaInputsContainer(NamedTuple): @@ -61,7 +64,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused """Initialize a new module instance.""" super().__init__() - def __getattr__(self, name: str) -> torch.Tensor | nn.Module: + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: # noqa: C901 """Get an attribute of the module.""" if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] @@ -86,7 +89,7 @@ def __getattr__(self, name: str) -> torch.Tensor | nn.Module: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # pylint: disable-next=too-many-branches,too-many-statements - def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: # noqa: C901 """Set an attribute of the module.""" def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None: diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index d3437d0d..c7f92b86 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -17,13 +17,15 @@ from __future__ import annotations import contextlib -from typing import Generator, Iterable +from typing import TYPE_CHECKING, Generator, Iterable -import torch -import torch.nn as nn +if TYPE_CHECKING: + import torch + import torch.nn as nn -__all__ = ['swap_state', 'reparametrize', 'reparameterize'] + +__all__ = ['reparameterize', 'reparametrize', 'swap_state'] MISSING: torch.Tensor = object() # type: ignore[assignment] diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py 
index a64e00e4..600b69c5 100644 --- a/torchopt/optim/adadelta.py +++ b/torchopt/optim/adadelta.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaDelta', 'Adadelta'] diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 277b7105..06091281 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaGrad', 'Adagrad'] diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index 6ff68a69..555af22e 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['Adam'] diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py index f693723c..e4996e85 100644 --- a/torchopt/optim/adamax.py +++ b/torchopt/optim/adamax.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['AdaMax', 'Adamax'] diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 463f245f..a60061ea 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Callable, Iterable - -import torch +from typing import TYPE_CHECKING, Callable, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['AdamW'] diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 7bb27877..fa287f04 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -16,13 +16,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torchopt.base import GradientTransformation, UninitializedState -from torchopt.typing import OptState, Params from torchopt.update import apply_updates +if TYPE_CHECKING: + from torchopt.typing import OptState, Params + + __all__ = ['FuncOptimizer'] diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py index 49bdf23c..eb386ae3 100644 --- a/torchopt/optim/meta/adadelta.py +++ b/torchopt/optim/meta/adadelta.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import 
MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaDelta', 'MetaAdadelta'] diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 58d913aa..129c1338 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaGrad', 'MetaAdagrad'] diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index bac71790..7a78ea7f 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdam'] diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py index 568a46f7..d6b40427 100644 --- a/torchopt/optim/meta/adamax.py +++ b/torchopt/optim/meta/adamax.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaMax', 'MetaAdamax'] diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index 05387b77..62864582 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Callable - -import torch.nn as nn +from typing import TYPE_CHECKING, Callable from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import OptState, Params, ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['MetaAdamW'] diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py index a32670d0..bb07b5ba 100644 --- a/torchopt/optim/meta/radam.py +++ b/torchopt/optim/meta/radam.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch.nn as nn +from typing import TYPE_CHECKING from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch.nn as nn + + from torchopt.typing import ScalarOrSchedule __all__ = ['MetaRAdam'] diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py index bba8c0d4..20e9dd22 100644 --- a/torchopt/optim/radam.py +++ b/torchopt/optim/radam.py @@ -16,13 +16,16 @@ from __future__ import annotations -from typing import Iterable - -import torch +from typing import TYPE_CHECKING, Iterable from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import ScalarOrSchedule + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import ScalarOrSchedule __all__ = ['RAdam'] diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 
6adea0e8..53abc2d2 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -18,7 +18,7 @@ import functools import operator -from typing import Callable +from typing import TYPE_CHECKING, Callable import optree import optree.typing as typing # pylint: disable=unused-import @@ -26,7 +26,9 @@ import torch.distributed.rpc as rpc from optree import * # pylint: disable=wildcard-import,unused-wildcard-import -from torchopt.typing import Future, RRef, Scalar, T, TensorTree + +if TYPE_CHECKING: + from torchopt.typing import Future, RRef, Scalar, T, TensorTree __all__ = [ diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py index 8e5545a4..d3d3eff5 100644 --- a/torchopt/schedule/__init__.py +++ b/torchopt/schedule/__init__.py @@ -35,4 +35,4 @@ from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule -__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule'] +__all__ = ['exponential_decay', 'linear_schedule', 'polynomial_schedule'] diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 0925e164..c19c54b9 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -31,11 +31,15 @@ # ============================================================================== """Exponential learning rate decay.""" +from __future__ import annotations + import logging import math -from typing import Optional +from typing import TYPE_CHECKING + -from torchopt.typing import Numeric, Scalar, Schedule +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule __all__ = ['exponential_decay'] @@ -48,7 +52,7 @@ def exponential_decay( transition_begin: int = 0, transition_steps: int = 1, staircase: bool = False, - end_value: Optional[float] = None, + end_value: float | None = None, ) -> Schedule: """Construct a schedule with either continuous or discrete exponential decay. 
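For readers following the flake8-type-checking (TCH) and flake8-future-annotations (FA) fixes applied across the modules above, a minimal sketch of the resulting import pattern follows. It is illustrative only and not part of the patch; `scale_tree` is a hypothetical helper, while `torchopt.pytree.tree_map` and `torchopt.typing.TensorTree` are taken from the code being patched.

from __future__ import annotations

from typing import TYPE_CHECKING

from torchopt import pytree

if TYPE_CHECKING:
    # Evaluated only by static type checkers; skipped entirely at runtime.
    from torchopt.typing import TensorTree


def scale_tree(tree: TensorTree, factor: float) -> TensorTree:
    # PEP 563 (enabled by the __future__ import) keeps annotations as strings,
    # so the typing-only import above is never needed at runtime.
    return pytree.tree_map(lambda t: t * factor, tree)

Deferring annotation evaluation this way lets typing-only imports move under `TYPE_CHECKING`, which is what most of the import hunks in this patch do.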
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 2482f769..d2a5160c 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -31,15 +31,20 @@ # ============================================================================== """Polynomial learning rate schedules.""" +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import numpy as np import torch -from torchopt.typing import Numeric, Scalar, Schedule + +if TYPE_CHECKING: + from torchopt.typing import Numeric, Scalar, Schedule -__all__ = ['polynomial_schedule', 'linear_schedule'] +__all__ = ['linear_schedule', 'polynomial_schedule'] def polynomial_schedule( diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index adef5596..fa59a43b 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -46,18 +46,18 @@ __all__ = [ - 'trace', - 'scale', - 'scale_by_schedule', 'add_decayed_weights', 'masked', + 'nan_to_num', + 'scale', + 'scale_by_accelerated_adam', + 'scale_by_adadelta', 'scale_by_adam', 'scale_by_adamax', - 'scale_by_adadelta', 'scale_by_radam', - 'scale_by_accelerated_adam', - 'scale_by_rss', 'scale_by_rms', + 'scale_by_rss', + 'scale_by_schedule', 'scale_by_stddev', - 'nan_to_num', + 'trace', ] diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 950682cf..0cb67837 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -34,17 +34,20 @@ from __future__ import annotations -from typing import Any, Callable, NamedTuple - -import torch +from typing import TYPE_CHECKING, Any, Callable, NamedTuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates -__all__ = ['masked', 'add_decayed_weights'] +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['add_decayed_weights', 'masked'] class MaskedState(NamedTuple): @@ -189,7 +192,7 @@ def _add_decayed_weights_flat( ) -def _add_decayed_weights( +def _add_decayed_weights( # noqa: C901 weight_decay: float = 0.0, mask: OptState | Callable[[Params], OptState] | None = None, *, diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index d3530853..740df1b0 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -16,11 +16,16 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates def nan_to_num( diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index 493b7196..2b492bdf 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -33,12 +33,17 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import OptState, Params, Updates __all__ = ['scale'] diff --git 
a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index f389d293..6d05e5dd 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -19,14 +19,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adadelta'] diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index b08c6a14..d45d1eb2 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -35,7 +35,7 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch @@ -43,10 +43,13 @@ from torchopt.accelerated_op import AdamOp from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates -__all__ = ['scale_by_adam', 'scale_by_accelerated_adam'] +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_accelerated_adam', 'scale_by_adam'] TRIPLE_PYTREE_SPEC = pytree.tree_structure((0, 1, 2), none_is_leaf=True) # type: ignore[arg-type] @@ -277,7 +280,7 @@ def _scale_by_accelerated_adam_flat( # pylint: disable-next=too-many-arguments -def _scale_by_accelerated_adam( +def _scale_by_accelerated_adam( # noqa: C901 b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8, diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index f11ed311..cfacbf35 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -19,14 +19,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_adamax'] diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py index fad32b13..95f26149 100644 --- a/torchopt/transform/scale_by_radam.py +++ b/torchopt/transform/scale_by_radam.py @@ -20,14 +20,17 @@ from __future__ import annotations import math -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_radam'] @@ -89,7 +92,7 @@ def _scale_by_radam_flat( ) -def _scale_by_radam( +def _scale_by_radam( # noqa: C901 b1: float = 0.9, b2: float = 0.999, eps: float = 1e-6, diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 4ee67ed0..f2141388 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple 
+from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rms'] diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 9bc97206..642b2e5c 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rss'] diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 48f3f271..499e2adb 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -33,14 +33,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ -from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates + + +if TYPE_CHECKING: + from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates __all__ = ['scale_by_schedule'] diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index 6b99f31a..5a3e6655 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -35,14 +35,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_, update_moment -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_stddev'] diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 9bf37e2f..219cbbec 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -35,14 +35,17 @@ from __future__ import annotations -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple import torch from torchopt import pytree from torchopt.base import GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Updates + + +if TYPE_CHECKING: + from torchopt.typing import OptState, Params, Updates __all__ = ['trace'] @@ -101,7 +104,7 @@ def _trace_flat( ) -def _trace( +def _trace( # noqa: C901 momentum: float = 0.9, dampening: float = 0.0, nesterov: bool = False, @@ -136,7 +139,7 @@ def init_fn(params: Params) -> OptState: first_call = True - def update_fn( + def update_fn( # noqa: C901 updates: Updates, state: OptState, *, diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index ec4e51c1..9b38d561 100644 --- a/torchopt/transform/utils.py +++ 
b/torchopt/transform/utils.py @@ -34,15 +34,18 @@ from __future__ import annotations from collections import deque -from typing import Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence import torch from torchopt import pytree -from torchopt.typing import TensorTree, Updates -__all__ = ['tree_map_flat', 'tree_map_flat_', 'inc_count', 'update_moment'] +if TYPE_CHECKING: + from torchopt.typing import TensorTree, Updates + + +__all__ = ['inc_count', 'tree_map_flat', 'tree_map_flat_', 'update_moment'] INT64_MAX = torch.iinfo(torch.int64).max @@ -161,7 +164,7 @@ def _update_moment_flat( # pylint: disable-next=too-many-arguments -def _update_moment( +def _update_moment( # noqa: C901 updates: Updates, moments: TensorTree, decay: float, diff --git a/torchopt/typing.py b/torchopt/typing.py index 60d11e0e..fcd888fb 100644 --- a/torchopt/typing.py +++ b/torchopt/typing.py @@ -14,6 +14,8 @@ # ============================================================================== """Typing utilities.""" +from __future__ import annotations + import abc from typing import ( Callable, @@ -45,39 +47,39 @@ __all__ = [ - 'GradientTransformation', 'ChainedGradientTransformation', + 'Device', + 'Distribution', 'EmptyState', - 'UninitializedState', - 'Params', - 'Updates', + 'Future', + 'GradientTransformation', + 'LinearSolver', + 'ListOfOptionalTensors', + 'ListOfTensors', + 'ModuleTensorContainers', + 'Numeric', 'OptState', + 'OptionalTensor', + 'OptionalTensorOrOptionalTensors', + 'OptionalTensorTree', + 'Params', + 'PyTree', + 'Samplable', + 'SampleFunc', 'Scalar', - 'Numeric', - 'Schedule', 'ScalarOrSchedule', - 'PyTree', - 'Tensor', - 'OptionalTensor', - 'ListOfTensors', - 'TupleOfTensors', + 'Schedule', + 'SequenceOfOptionalTensors', 'SequenceOfTensors', + 'Size', + 'Tensor', + 'TensorContainer', 'TensorOrTensors', 'TensorTree', - 'ListOfOptionalTensors', 'TupleOfOptionalTensors', - 'SequenceOfOptionalTensors', - 'OptionalTensorOrOptionalTensors', - 'OptionalTensorTree', - 'TensorContainer', - 'ModuleTensorContainers', - 'Future', - 'LinearSolver', - 'Device', - 'Size', - 'Distribution', - 'SampleFunc', - 'Samplable', + 'TupleOfTensors', + 'UninitializedState', + 'Updates', ] T = TypeVar('T') @@ -138,7 +140,7 @@ class Samplable(Protocol): # pylint: disable=too-few-public-methods def sample( self, sample_shape: Size = Size(), # noqa: B008 # pylint: disable=unused-argument - ) -> Union[Tensor, Sequence[Numeric]]: + ) -> Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" raise NotImplementedError diff --git a/torchopt/update.py b/torchopt/update.py index 8636d7a4..3f2d71fe 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -33,10 +33,15 @@ from __future__ import annotations -import torch +from typing import TYPE_CHECKING from torchopt import pytree -from torchopt.typing import Params, Updates + + +if TYPE_CHECKING: + import torch + + from torchopt.typing import Params, Updates __all__ = ['apply_updates'] diff --git a/torchopt/utils.py b/torchopt/utils.py index c067d570..5f9202a3 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -34,11 +34,11 @@ __all__ = [ 'ModuleState', - 'stop_gradient', 'extract_state_dict', - 'recover_state_dict', 'module_clone', 'module_detach_', + 'recover_state_dict', + 'stop_gradient', ] @@ -115,7 +115,7 @@ def extract_state_dict( # pylint: 
disable-next=too-many-arguments,too-many-branches,too-many-locals -def extract_state_dict( +def extract_state_dict( # noqa: C901 target: nn.Module | MetaOptimizer, *, by: CopyMode = 'reference', @@ -272,7 +272,7 @@ def get_variable(t: torch.Tensor | None) -> torch.Tensor | None: return pytree.tree_map(get_variable, state) # type: ignore[arg-type,return-value] - raise RuntimeError(f'Unexpected class of {target}') + raise TypeError(f'Unexpected class of {target}') def extract_module_containers( @@ -346,7 +346,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: state = cast(Sequence[OptState], state) target.load_state_dict(state) else: - raise RuntimeError(f'Unexpected class of {target}') + raise TypeError(f'Unexpected class of {target}') @overload @@ -383,7 +383,7 @@ def module_clone( # pylint: disable-next=too-many-locals -def module_clone( +def module_clone( # noqa: C901 target: nn.Module | MetaOptimizer | TensorTree, *, by: CopyMode = 'reference', diff --git a/torchopt/visual.py b/torchopt/visual.py index d7885889..7638d7ec 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -19,16 +19,19 @@ from __future__ import annotations -from typing import Any, Generator, Iterable, Mapping, cast +from typing import TYPE_CHECKING, Any, Generator, Iterable, Mapping, cast import torch from graphviz import Digraph from torchopt import pytree -from torchopt.typing import TensorTree from torchopt.utils import ModuleState +if TYPE_CHECKING: + from torchopt.typing import TensorTree + + __all__ = ['make_dot', 'resize_graph'] @@ -69,7 +72,7 @@ def truncate(s: str) -> str: # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals -def make_dot( +def make_dot( # noqa: C901 var: TensorTree, params: ( Mapping[str, torch.Tensor] @@ -153,7 +156,7 @@ def get_var_name_with_flag(var: torch.Tensor) -> str | None: return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}' return None - def add_nodes(fn: Any) -> None: # pylint: disable=too-many-branches + def add_nodes(fn: Any) -> None: # noqa: C901 # pylint: disable=too-many-branches assert not isinstance(fn, torch.Tensor) if fn in seen: return