diff --git a/CHANGELOG.md b/CHANGELOG.md index 74d23144..3a3ad174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Use postponed evaluation of annotations and update docstring style by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/torchopt/pull/135). - Rewrite setup CUDA Toolkit logic by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/torchopt/pull/133). ### Fixed diff --git a/README.md b/README.md index c1fb97ba..321f39e3 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ ![CodeCov](https://img.shields.io/codecov/c/gh/metaopt/torchopt) ![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs) ![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads) - ![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github) ![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=) diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index c7e04e95..b2866407 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -285,6 +285,115 @@ Chain .. autofunction:: chain +Distributed Utilities +===================== + +.. currentmodule:: torchopt.distributed + +Initialization and Synchronization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + auto_init_rpc + barrier + +.. autofunction:: auto_init_rpc +.. autofunction:: barrier + +Process group information +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + get_world_info + get_world_rank + get_rank + get_world_size + get_local_rank + get_local_world_size + get_worker_id + +.. autofunction:: get_world_info +.. autofunction:: get_world_rank +.. autofunction:: get_rank +.. autofunction:: get_world_size +.. autofunction:: get_local_rank +.. 
autofunction:: get_local_world_size +.. autofunction:: get_worker_id + +Worker selection +~~~~~~~~~~~~~~~~ + +.. autosummary:: + + on_rank + not_on_rank + rank_zero_only + rank_non_zero_only + +.. autofunction:: on_rank +.. autofunction:: not_on_rank +.. autofunction:: rank_zero_only +.. autofunction:: rank_non_zero_only + +Remote Procedure Call (RPC) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + remote_async_call + remote_sync_call + +.. autofunction:: remote_async_call +.. autofunction:: remote_sync_call + +Predefined partitioners and reducers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + dim_partitioner + batch_partitioner + mean_reducer + sum_reducer + +.. autofunction:: dim_partitioner +.. autofunction:: batch_partitioner +.. autofunction:: mean_reducer +.. autofunction:: sum_reducer + +Function parallelization wrappers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + + parallelize + parallelize_async + parallelize_sync + +.. autofunction:: parallelize +.. autofunction:: parallelize_async +.. autofunction:: parallelize_sync + +Distributed Autograd +~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchopt.distributed.autograd + +.. autosummary:: + + context + get_gradients + backward + grad + +.. autofunction:: context +.. autofunction:: get_gradients +.. autofunction:: backward +.. autofunction:: grad + + General Utilities ================= diff --git a/docs/source/distributed/distributed.rst b/docs/source/distributed/distributed.rst index f85eec3f..b6f00951 100644 --- a/docs/source/distributed/distributed.rst +++ b/docs/source/distributed/distributed.rst @@ -142,7 +142,6 @@ Initialization and Synchronization .. autosummary:: - torchopt.distributed.auto_init_rpc torchopt.distributed.barrier @@ -197,7 +196,6 @@ Process group information .. autosummary:: - torchopt.distributed.get_world_info torchopt.distributed.get_world_rank torchopt.distributed.get_rank @@ -228,7 +226,6 @@ Worker selection .. 
autosummary:: - torchopt.distributed.on_rank torchopt.distributed.not_on_rank torchopt.distributed.rank_zero_only @@ -275,7 +272,6 @@ Remote Procedure Call (RPC) .. autosummary:: - torchopt.distributed.remote_async_call torchopt.distributed.remote_sync_call @@ -354,7 +350,6 @@ Predefined partitioners and reducers .. autosummary:: - torchopt.distributed.dim_partitioner torchopt.distributed.batch_partitioner torchopt.distributed.mean_reducer @@ -439,7 +434,6 @@ Function parallelization wrappers .. autosummary:: - torchopt.distributed.parallelize torchopt.distributed.parallelize_async torchopt.distributed.parallelize_sync @@ -490,7 +484,6 @@ Distributed Autograd .. autosummary:: - torchopt.distributed.autograd.context torchopt.distributed.autograd.get_gradients torchopt.distributed.autograd.backward diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 8f9d6895..aac17046 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -171,3 +171,4 @@ issubclass abc ABCMeta subclasscheck +ctx diff --git a/tests/helpers.py b/tests/helpers.py index 4bba706e..23e178f0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -13,11 +13,13 @@ # limitations under the License. 
# ============================================================================== +from __future__ import annotations + import copy import itertools import os import random -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable import numpy as np import pytest @@ -137,7 +139,7 @@ def get_model(): @torch.no_grad() def get_models( device: torch.types.Device = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: +) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]: seed_everything(seed=42) model_base = get_model().to(dtype=dtype) @@ -166,12 +168,12 @@ def get_models( @torch.no_grad() def assert_model_all_close( - model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]], + model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]], model_ref: nn.Module, model_base: nn.Module, dtype: torch.dtype = torch.float32, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: if isinstance(model, tuple): @@ -194,8 +196,8 @@ def assert_all_close( actual: torch.Tensor, expected: torch.Tensor, base: torch.Tensor = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: if base is not None: @@ -223,9 +225,9 @@ def assert_all_close( def assert_pytree_all_close( actual: TensorTree, expected: TensorTree, - base: Optional[TensorTree] = None, - rtol: Optional[float] = None, - atol: Optional[float] = None, + base: TensorTree | None = None, + rtol: float | None = None, + atol: float | None = None, equal_nan: bool = False, ) -> None: actual_leaves, actual_treespec = pytree.tree_flatten(actual) diff --git a/tests/test_alias.py b/tests/test_alias.py index c613d7d5..b609cf58 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -13,7 +13,9 @@ # limitations 
under the License. # ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import pytest @@ -107,7 +109,7 @@ def test_sgd( def test_adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, @@ -177,7 +179,7 @@ def test_maml_adam( outer_lr: float, inner_lr: float, inner_update: int, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, @@ -263,7 +265,7 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op): def test_adamw( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, @@ -333,8 +335,8 @@ def test_adamw( def test_adam_accelerated_cuda( dtype: torch.dtype, lr: float, - optimizers: Tuple[Callable, torch.optim.Optimizer], - betas: Tuple[float, float], + optimizers: tuple[Callable, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, diff --git a/tests/test_implicit.py b/tests/test_implicit.py index ce0ee23b..9e3722d3 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -13,10 +13,11 @@ # limitations under the License. 
# ============================================================================== +from __future__ import annotations + import copy from collections import OrderedDict from types import FunctionType -from typing import Tuple import functorch import jax @@ -55,7 +56,7 @@ def forward(self, x): return self.fc(x) -def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]: +def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]: helpers.seed_everything(seed=42) def func(params, x): @@ -73,7 +74,7 @@ def func(params, x): @torch.no_grad() def get_model_torch( device: torch.types.Device = None, dtype: torch.dtype = torch.float32 -) -> Tuple[nn.Module, data.DataLoader]: +) -> tuple[nn.Module, data.DataLoader]: helpers.seed_everything(seed=42) model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype) diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py index 2c0966cc..61f8a7ad 100644 --- a/tests/test_meta_optim.py +++ b/tests/test_meta_optim.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Tuple +from __future__ import annotations import torch import torch.nn.functional as F @@ -40,7 +40,7 @@ def test_maml_meta_adam( outer_lr: float, inner_lr: float, inner_update: int, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, eps_root: float, weight_decay: float, diff --git a/tests/test_optim.py b/tests/test_optim.py index c43bc438..b2be7500 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -13,7 +13,9 @@ # limitations under the License. 
# ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import pytest @@ -96,7 +98,7 @@ def test_SGD( def test_Adam( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -154,7 +156,7 @@ def test_Adam( def test_AdamW( dtype: torch.dtype, lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -216,8 +218,8 @@ def test_AdamW( def test_Adam_accelerated_cuda( dtype: torch.dtype, lr: float, - optimizers: Tuple[torchopt.Optimizer, torch.optim.Optimizer], - betas: Tuple[float, float], + optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer], + betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, @@ -339,7 +341,7 @@ def test_RMSProp( def test_FuncOptimizer( dtype: torch.dtype, lr: float, - optimizers: Tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, torch.optim.Optimizer], inplace: bool, weight_decay: float, ) -> None: diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 9590acf8..ae714875 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -13,7 +13,9 @@ # limitations under the License. 
# ============================================================================== -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch import numpy as np @@ -62,7 +64,7 @@ def test_lr_linear_schedule( dtype: torch.dtype, lr: float, total_iters: int, - optimizers: Tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, torch.optim.Optimizer], inplace: bool, weight_decay: float, use_chain_flat: bool, diff --git a/tests/test_transform.py b/tests/test_transform.py index 4dfd034d..9598386d 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -13,13 +13,8 @@ # limitations under the License. # ============================================================================== -from typing import Tuple - -import functorch import torch -import torch.nn.functional as F -import helpers import torchopt diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi index bc3e8ebc..7ecfe7c2 100644 --- a/torchopt/_C/adam_op.pyi +++ b/torchopt/_C/adam_op.pyi @@ -15,7 +15,7 @@ # pylint: disable=all -from typing import Tuple +from __future__ import annotations import torch @@ -28,7 +28,7 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ... def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ... def forward_updates( @@ -42,10 +42,10 @@ def forward_updates( ) -> torch.Tensor: ... def backward_mu( dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... def backward_nu( dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... 
def backward_updates( dupdates: torch.Tensor, updates: torch.Tensor, @@ -55,4 +55,4 @@ def backward_updates( b2: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> tuple[torch.Tensor, torch.Tensor]: ... diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index 003a8a9f..ede60009 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -14,7 +14,9 @@ # ============================================================================== """The accelerated Ops.""" -from typing import Iterable, Optional, Union +from __future__ import annotations + +from typing import Iterable import torch @@ -22,7 +24,7 @@ from torchopt.typing import Device -def is_available(devices: Optional[Union[Device, Iterable[Device]]] = None) -> bool: +def is_available(devices: Device | Iterable[Device] | None = None) -> bool: """Check the availability of accelerated optimizer.""" op = AdamOp() diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py index 9f801b8d..ab5ea195 100644 --- a/torchopt/accelerated_op/_src/adam_op.py +++ b/torchopt/accelerated_op/_src/adam_op.py @@ -16,7 +16,7 @@ # pylint: disable=invalid-name,too-many-arguments,unused-argument -from typing import Tuple +from __future__ import annotations import torch @@ -30,7 +30,7 @@ def forward_( eps: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Adam forward inplace.""" mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1) nu = nu.mul_(b2).addcmul_(updates, updates, value=1.0 - b2) @@ -80,7 +80,7 @@ def backward_mu( updates: torch.Tensor, mu: torch.Tensor, b1: float, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward mu.""" dupdates = dmu.mul(1.0 - b1) dmu = dmu.mul(b1) @@ -92,7 +92,7 @@ def backward_nu( updates: torch.Tensor, nu: torch.Tensor, b2: 
float, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward nu.""" dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2)) dnu = dnu.mul(b2) @@ -108,7 +108,7 @@ def backward_updates( b2: float, eps_root: float, count: int, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Adam backward updates.""" one_minus_pow_b1 = 1.0 - pow(b1, count) inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count) + eps_root) diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py index 6b93bf18..232513d6 100644 --- a/torchopt/accelerated_op/adam_op.py +++ b/torchopt/accelerated_op/adam_op.py @@ -16,8 +16,10 @@ # pylint: disable=c-extension-no-member,invalid-name +from __future__ import annotations + import contextlib -from typing import Any, Optional, Tuple +from typing import Any import torch @@ -132,9 +134,9 @@ def __call__( self, mu: torch.Tensor, nu: torch.Tensor, - updates: Optional[torch.Tensor], + updates: torch.Tensor | None, count: int, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Apply the Adam operator.""" if updates is None: return mu, nu, None diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index a7f90a79..08654577 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -31,7 +31,7 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the Adam optimizer.""" -from typing import Tuple +from __future__ import annotations from torchopt.alias.utils import ( _get_use_chain_flat, @@ -49,7 +49,7 @@ # pylint: disable-next=too-many-arguments def adam( lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -68,26 +68,25 @@ def adam( - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 
Args: - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. 
(default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index 9aecc8ee..21ef84ef 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -31,7 +31,9 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the AdamW optimizer.""" -from typing import Any, Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable from torchopt.alias.utils import ( _get_use_chain_flat, @@ -40,7 +42,7 @@ ) from torchopt.combine import chain from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam -from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule +from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule __all__ = ['adamw'] @@ -49,12 +51,12 @@ # pylint: disable-next=too-many-arguments,too-many-locals def adamw( lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, @@ -70,35 +72,34 @@ def adamw( - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 Args: - lr: (default: 
:const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is multiplied - with the learning rate. This is consistent with other frameworks such as PyTorch, but - different from (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to avoid - dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with other + frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight + decay is only multiplied with the "schedule multiplier", but not the base learning rate. 
+ (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and - :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + :data:`False` for those you want to skip. Note that the Adam gradient transformations + are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. 
diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py index 18a5c5e8..f0eb92cd 100644 --- a/torchopt/alias/rmsprop.py +++ b/torchopt/alias/rmsprop.py @@ -69,28 +69,25 @@ def rmsprop( - Graves, 2013: https://arxiv.org/abs/1308.0850 Args: - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. 
(default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude of + previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 61b3d6e4..7d86b538 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -64,21 +64,19 @@ def sgd( - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf Args: - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is not + used when it is set to :const:`0.0`. 
(default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 869aad87..b5088164 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -13,8 +13,9 @@ # limitations under the License. r"""Utilities for the aliases of preset :class:`GradientTransformation`\s for optimizers.""" +from __future__ import annotations + import threading -from typing import Optional, Tuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity @@ -93,9 +94,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' 
@@ -126,9 +127,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g): @@ -151,9 +152,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' diff --git a/torchopt/base.py b/torchopt/base.py index bb37b147..b250c387 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -31,9 +31,11 @@ # ============================================================================== """The base classes for gradient transformation.""" +from __future__ import annotations + import itertools from abc import abstractmethod -from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, Callable, NamedTuple from typing_extensions import Protocol # Python 3.8+ @@ -67,12 +69,11 @@ class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods """ @abstractmethod - def __call__(self, params: 'Params') -> 'OptState': + def __call__(self, params: Params) -> OptState: """Initialize the gradient transformation state. Args: - params: - The initial value of the parameters. + params (tree of Tensor): The initial value of the parameters. Returns: The initial state of the gradient transformation. 
@@ -93,21 +94,21 @@ class TransformUpdateFn(Protocol): # pylint: disable=too-few-public-methods @abstractmethod def __call__( self, - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple['Updates', 'OptState']: + ) -> tuple[Updates, OptState]: """Transform the updates and state. Args: - updates: A tree of candidate updates. - state: The state of the gradient transformation. - params: (optional) - The current value of the parameters. - inplace: (optional) - If :data:`True`, modify updates and state using inplace operations. + updates (tree of Tensor): A tree of candidate updates. + state (tree of Tensor): The state of the gradient transformation. + params (tree of Tensor or None, optional): The current value of the parameters. + (default: :data:`None`) + inplace (bool, optional): If :data:`True`, modify updates and state using inplace + operations. (default: :data:`True`) Returns: The transformed ``updates``, and the updated ``state``. @@ -134,9 +135,9 @@ class GradientTransformation(NamedTuple): optimizer state. update: A pure function which takes as input a pytree of updates (with the same tree structure - as the original params ``pytree`` passed to :attr:`init`), the previous optimizer state - (which may have been initialized using the :attr:`init` function), and optionally the - ``inplace`` flag. The :attr:`update` function then returns the computed gradient + as the original params ``pytree`` passed to ``init``), the previous optimizer state + (which may have been initialized using the ``init`` function), and optionally the + ``inplace`` flag. The ``update`` function then returns the computed gradient updates, and a updates optimizer state. If the ``inplace`` flag is :data:`True`, the output results are the same instance as the input. 
""" @@ -145,7 +146,7 @@ class GradientTransformation(NamedTuple): update: TransformUpdateFn # pylint: disable-next=redefined-builtin - def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation': + def chain(self, next: GradientTransformation) -> ChainedGradientTransformation: """Chain two gradient transformations together.""" return ChainedGradientTransformation(self, next) @@ -157,9 +158,9 @@ class ChainedGradientTransformation(GradientTransformation): gradient transformations. """ - transformations: Tuple[GradientTransformation, ...] + transformations: tuple[GradientTransformation, ...] - def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation': + def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTransformation: """Create a new chained gradient transformation.""" transformations = tuple( itertools.chain.from_iterable( @@ -175,16 +176,16 @@ def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTr init_fns, update_fns = tuple(zip(*transformations)) - def init_fn(params: 'Params') -> 'OptState': + def init_fn(params: Params) -> OptState: return tuple(fn(params) for fn in init_fns) def update_fn( - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple['Updates', 'OptState']: + ) -> tuple[Updates, OptState]: if len(update_fns) != len(state): raise ValueError( 'The number of updates and states has to be the same in chain! 
Make sure you' @@ -219,15 +220,15 @@ def __hash__(self) -> int: """Return the hash of the chained gradient transformation.""" return hash(self.transformations) - def __getstate__(self) -> Tuple[GradientTransformation, ...]: + def __getstate__(self) -> tuple[GradientTransformation, ...]: """Return the state of the chained gradient transformation for serialization.""" return self.transformations - def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None: + def __setstate__(self, state: tuple[GradientTransformation, ...]) -> None: """Set the state of the chained gradient transformation from serialization.""" self.transformations = state - def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]: + def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...]]]: """Serialize the chained gradient transformation.""" return ChainedGradientTransformation, (self.transformations,) @@ -240,18 +241,18 @@ def __new__(cls): return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) @staticmethod - def init_fn(params: 'Params') -> 'OptState': # pylint: disable=unused-argument + def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument """Return empty state.""" return EmptyState() @staticmethod def update_fn( - updates: 'Updates', - state: 'OptState', + updates: Updates, + state: OptState, *, - params: Optional['Params'] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, # pylint: disable=unused-argument - ) -> Tuple['Updates', 'OptState']: + ) -> tuple[Updates, OptState]: """Return updates unchanged.""" return updates, state diff --git a/torchopt/clip.py b/torchopt/clip.py index 2469d17a..b2aafb48 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -17,7 +17,7 @@ # ============================================================================== """Utilities for gradient clipping.""" -from typing import Optional, 
Tuple, Union +from __future__ import annotations import torch @@ -33,18 +33,19 @@ def clip_grad_norm( - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, + max_norm: float | int, + norm_type: float | int = 2.0, error_if_nonfinite: bool = False, ) -> GradientTransformation: """Clip gradient norm of an iterable of parameters. Args: max_norm (float or int): The maximum absolute value for each element in the update. - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - error_if_nonfinite (bool): if :data:`True`, an error is thrown if the total norm of the - gradients from :attr:`updates` is ``nan``, ``inf``, or ``-inf``. + norm_type (float or int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity + norm. (default: :const:`2.0`) + error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm + of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``. + (default: :data:`False`) Returns: An ``(init_fn, update_fn)`` tuple. 
@@ -57,9 +58,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: available_updates = pytree.tree_leaves(updates) if len(available_updates) == 0: return updates, state diff --git a/torchopt/combine.py b/torchopt/combine.py index 82297426..0f1ed8ec 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -31,7 +31,7 @@ # ============================================================================== """Utilities to define a chained transformation.""" -from typing import Optional, Tuple +from __future__ import annotations from torchopt import pytree from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity @@ -49,8 +49,8 @@ def chain(*transformations: GradientTransformation) -> GradientTransformation: :func:`update_fn` which chains the update transformations feeding the appropriate state to each. Args: - *transformations: - A sequence of chainable ``(init_fn, update_fn)`` tuples. + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. Returns: A single ``(init_fn, update_fn)`` tuple. @@ -66,8 +66,8 @@ def chain_flat(*transformations: GradientTransformation) -> GradientTransformati """Wrap around the inner transformations that manipulate the flattened tree structure (:class:``list``). Args: - *transformations: - A sequence of chainable ``(init_fn, update_fn)`` tuples. + *transformations (iterable of GradientTransformation): A sequence of chainable + ``(init_fn, update_fn)`` tuples. Returns: A single ``(init_fn, update_fn)`` tuple. 
@@ -86,9 +86,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True) if params is not None: flat_params = pytree.tree_leaves(params, none_is_leaf=True) diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 377bc1f4..a5908963 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -16,9 +16,11 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Sequence, Tuple import functorch import torch @@ -47,7 +49,7 @@ def __init__( optimality_fn: Callable[..., TensorOrTensors], solution: TensorOrTensors, output_is_tensor: bool, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], *args: Any, ) -> None: self.optimality_fn = optimality_fn @@ -88,7 +90,7 @@ def _root_vjp( args: Args, grad_outputs: TupleOfTensors, output_is_tensor: bool, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), ) -> TupleOfOptionalTensors: if output_is_tensor: @@ -145,14 +147,14 @@ def matvec(u: TupleOfTensors) -> TupleOfTensors: return tuple(true_output) -def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]: +def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: tuple[Any, ...]) -> tuple[Args, KwArgs]: nargs = len(flat_args) - len(kwarg_keys) args, kwarg_vals = flat_args[:nargs], flat_args[nargs:] kwargs = dict(zip(kwarg_keys, kwarg_vals)) return args, kwargs -def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> Tuple[Args, KwArgs]: +def _signature_bind(signature: 
inspect.Signature, *args: Any, **kwargs: Any) -> tuple[Args, KwArgs]: bound = signature.bind(*args, **kwargs) bound.apply_defaults() return bound.args, bound.kwargs @@ -160,7 +162,7 @@ def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> def _signature_bind_and_match( signature: inspect.Signature, *args: Any, **kwargs: Any -) -> Tuple[Args, KwArgs, Callable[[Args], Tuple[Args, KwArgs]]]: +) -> tuple[Args, KwArgs, Callable[[Args], tuple[Args, KwArgs]]]: # We want to bind *args and **kwargs based on the provided signature, but also to associate the # resulting positional arguments back. To achieve this, we lift arguments to a triple: # @@ -193,13 +195,13 @@ def map_args_back(out_args): def _split_tensor_and_others( - mixed_tuple: Tuple[Any, ...], -) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], TupleOfTensors, Tuple[Any, ...]]: - flattened: List[Any] + mixed_tuple: tuple[Any, ...], +) -> tuple[pytree.PyTreeSpec, tuple[bool, ...], TupleOfTensors, tuple[Any, ...]]: + flattened: list[Any] flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type] tensors: ListOfTensors = [] - non_tensors: List[Any] = [] - is_tensor_mask: List[bool] = [] + non_tensors: list[Any] = [] + is_tensor_mask: list[bool] = [] for item in flattened: is_tensor = isinstance(item, torch.Tensor) is_tensor_mask.append(is_tensor) @@ -212,10 +214,10 @@ def _split_tensor_and_others( def _merge_tensor_and_others( treespec: pytree.PyTreeSpec, - is_tensor_mask: Tuple[bool, ...], + is_tensor_mask: tuple[bool, ...], tensors: TupleOfTensors, - non_tensors: Tuple[Any, ...], -) -> Tuple[Any, ...]: + non_tensors: tuple[Any, ...], +) -> tuple[Any, ...]: tensor_counter = 0 non_tensor_counter = 0 results = [] @@ -231,13 +233,13 @@ def _merge_tensor_and_others( # pylint: disable-next=too-many-arguments,too-many-statements def _custom_root( - solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + solver_fn: Callable[..., 
TensorOrTensors | tuple[TensorOrTensors, Any]], optimality_fn: Callable[..., TensorOrTensors], solve: Callable[..., TensorOrTensors], - argnums: Tuple[int, ...], + argnums: tuple[int, ...], has_aux: bool, - reference_signature: Optional[Union[inspect.Signature, Callable]] = None, -) -> Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]: + reference_signature: inspect.Signature | Callable | None = None, +) -> Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]: solver_fn_signature = inspect.signature(solver_fn) if reference_signature is None: @@ -249,16 +251,16 @@ def _custom_root( reference_signature = inspect.signature(fn) def make_custom_vjp_solver_fn( - solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], kwarg_keys: Sequence[str], - args_signs: Tuple[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]], ...], - ) -> Type[Function]: + args_signs: tuple[tuple[int, int, type[tuple] | type[list] | None], ...], + ) -> type[Function]: # pylint: disable-next=missing-class-docstring,abstract-method class ImplicitMetaGradient(Function): @staticmethod def forward( # type: ignore[override] # pylint: disable=arguments-differ ctx: Any, *flat_args: Any - ) -> Tuple[Any, ...]: + ) -> tuple[Any, ...]: output, aux, output_is_tensor = None, None, False args = [] @@ -361,12 +363,12 @@ def backward( # pylint: disable=too-many-locals @functools.wraps(solver_fn) def wrapped_solver_fn( *args: Any, **kwargs: Any - ) -> Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]: + ) -> TensorOrTensors | tuple[TensorOrTensors, Any]: args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs) keys, vals = list(kwargs.keys()), list(kwargs.values()) - args_signs: List[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]]] = [] - flat_args: List[Any] = [] + args_signs: list[tuple[int, int, type[tuple] | type[list] | None]] = [] + flat_args: list[Any] = [] 
args_offset = 0 for idx, arg in enumerate(args): if idx in argnums: @@ -410,12 +412,12 @@ def wrapped_solver_fn( def custom_root( optimality_fn: Callable[..., TensorOrTensors], - argnums: Union[int, Tuple[int, ...]], + argnums: int | tuple[int, ...], has_aux: bool = False, solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(), ) -> Callable[ - [Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]], - Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]], + [Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]], + Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]], ]: """Return a decorator for adding implicit differentiation to a root solver. @@ -442,18 +444,17 @@ def solver_fn(params, arg1, arg2, ...): **In best practice, the ``optimality_fn`` should have the same signature as ``solver_fn``.** Args: - optimality_fn: (callable) - An equation function, ``optimality_fn(params, *args)``. The invariant is - ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``. - argnums: (int or tuple of ints) - Specifies arguments to compute gradients with respect to. The ``argnums`` can be an - integer or a tuple of integers, which respect to the zero-based indices of the arguments - of the ``solver_fn(params, *args)`` function. The argument ``params`` is included - for the counting, while it is indexed as ``argnums=0``. - has_aux: (default: :data:`False`) - Whether the decorated solver function returns auxiliary data. - solve: (callable, optional, default: :func:`linear_solve.solve_normal_cg`) - a linear solver of the form ``solve(matvec, b)``. + optimality_fn (callable): An equation function, ``optimality_fn(params, *args)``. The + invariant is ``optimality_fn(solution, *args) == 0`` at the solution / root of + ``solution``. + argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. 
The + ``argnums`` can be an integer or a tuple of integers, which respect to the zero-based + indices of the arguments of the ``solver_fn(params, *args)`` function. The argument + ``params`` is included for the counting, while it is indexed as ``argnums=0``. + has_aux (bool, optional): Whether the decorated solver function returns auxiliary data. + (default: :data:`False`) + solve (callable, optional): A linear solver of the form ``solve(matvec, b)``. + (default: :func:`linear_solve.solve_normal_cg`) Returns: A solver function decorator, i.e., ``custom_root(optimality_fn)(solver_fn)``. diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index f9bff4de..bbae37c9 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -16,10 +16,12 @@ # pylint: disable=redefined-builtin +from __future__ import annotations + import abc import functools import itertools -from typing import Any, Iterable, Optional, Tuple, Type +from typing import Any, Iterable import functorch import torch @@ -38,7 +40,7 @@ def _stateless_objective_fn( __flat_meta_params: TupleOfTensors, __params_names: Iterable[str], __meta_params_names: Iterable[str], - self: 'ImplicitMetaGradientModule', + self: ImplicitMetaGradientModule, *input, **kwargs, ) -> torch.Tensor: @@ -57,7 +59,7 @@ def _stateless_optimality_fn( __flat_meta_params: TupleOfTensors, __params_names: Iterable[str], __meta_params_names: Iterable[str], - self: 'ImplicitMetaGradientModule', + self: ImplicitMetaGradientModule, *input, **kwargs, ) -> TupleOfTensors: @@ -72,8 +74,8 @@ def _stateless_optimality_fn( def make_optimality_from_objective( - cls: Type['ImplicitMetaGradientModule'], -) -> Type['ImplicitMetaGradientModule']: + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: """Derives the optimality function of the objective function.""" if ( getattr(cls, 'objective', ImplicitMetaGradientModule.objective) @@ -81,7 +83,7 @@ def 
make_optimality_from_objective( ): raise TypeError('The objective function is not defined.') - def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors: + def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTensors: params_names, flat_params = tuple(zip(*self.named_parameters())) meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) @@ -102,8 +104,8 @@ def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfT def enable_implicit_gradients( - cls: Type['ImplicitMetaGradientModule'], -) -> Type['ImplicitMetaGradientModule']: + cls: type[ImplicitMetaGradientModule], +) -> type[ImplicitMetaGradientModule]: """Enable implicit gradients for the :func:`solve` method.""" cls_solve = cls.solve if getattr(cls_solve, '__implicit_gradients_enabled__', False): @@ -122,17 +124,17 @@ def stateless_solver_fn( __params_names: Iterable[str], __meta_params_names: Iterable[str], # pylint: enable=unused-argument - self: 'ImplicitMetaGradientModule', + self: ImplicitMetaGradientModule, *input, **kwargs, - ) -> Tuple[TupleOfTensors, Any]: + ) -> tuple[TupleOfTensors, Any]: """Solve the optimization problem.""" output = cls_solve(self, *input, **kwargs) flat_optimal_params = tuple(p.detach_() for p in self.parameters()) return flat_optimal_params, output @functools.wraps(cls_solve) - def wrapped(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> Any: + def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any: """Solve the optimization problem.""" params_names, flat_params = tuple(zip(*self.named_parameters())) meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) @@ -159,9 +161,9 @@ class ImplicitMetaGradientModule(MetaGradientModule): _custom_optimality: bool _custom_objective: bool - linear_solve: Optional[LinearSolver] + linear_solve: LinearSolver | None - def __init_subclass__(cls, linear_solve: Optional[LinearSolver] = None) -> None: + 
def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: """Validate and initialize the subclass.""" super().__init_subclass__() cls.linear_solve = linear_solve diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index 80664d8b..43522028 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -14,8 +14,10 @@ # ============================================================================== """Zero-Order Gradient Estimation.""" +from __future__ import annotations + import functools -from typing import Any, Callable, List, Sequence, Tuple, Union +from typing import Any, Callable, Sequence from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -33,9 +35,7 @@ def __init__(self, sample_fn: SampleFunc) -> None: """Wrap a sample function to make it a :class:`Samplable` object.""" self.sample_fn = sample_fn - def sample( - self, sample_shape: torch.Size = torch.Size() - ) -> Union[torch.Tensor, Sequence[Numeric]]: + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" return self.sample_fn(sample_shape) @@ -44,14 +44,14 @@ def sample( def _zero_order_naive( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: 
disable=missing-class-docstring,abstract-method @@ -59,7 +59,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -107,7 +107,7 @@ def backward( # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] def add_perturbation(tensor, noises): return tensor.add(noises, alpha=sigma) @@ -119,7 +119,7 @@ def add_perturbation(tensor, noises): flat_noisy_params = [ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -145,14 +145,14 @@ def add_perturbation(tensor, noises): def _zero_order_forward( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -160,7 +160,7 @@ class ZeroOrder(Function): # pylint: 
disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -209,7 +209,7 @@ def backward( # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] def add_perturbation(tensor, noises): return tensor.add(noises, alpha=sigma) @@ -221,7 +221,7 @@ def add_perturbation(tensor, noises): flat_noisy_params = [ add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -248,14 +248,14 @@ def add_perturbation(tensor, noises): def _zero_order_antithetic( # pylint: disable=too-many-statements fn: Callable[..., torch.Tensor], distribution: Samplable, - argnums: Tuple[int, ...], + argnums: tuple[int, ...], num_samples: int, sigma: Numeric, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements diff_params = [args[argnum] for argnum in argnums] - flat_diff_params: List[Any] + flat_diff_params: list[Any] flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type] class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method @@ -263,7 +263,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: 
flat_diff_params = args[:-1] origin_args = list(args[-1][0]) - flat_args: List[Any] + flat_args: list[Any] flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type] ctx.args_treespec = args_treespec @@ -309,7 +309,7 @@ def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals flat_args.append(non_tensors[non_tensors_counter]) non_tensors_counter += 1 - args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] + args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] @@ -318,7 +318,7 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: add_perturbation_fn(t, n, alpha=sigma) for t, n in zip(flat_diff_params, noises) ] - noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment] + noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params ) @@ -349,28 +349,28 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: def zero_order( - distribution: Union[SampleFunc, Samplable], + distribution: SampleFunc | Samplable, method: Method = 'naive', - argnums: Union[int, Tuple[int, ...]] = (0,), + argnums: int | tuple[int, ...] = (0,), num_samples: int = 1, sigma: Numeric = 1.0, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Return a decorator for applying zero-order differentiation. Args: - distribution: (function or Samplable) - A samplable object that has method ``samplable.sample(sample_shape)`` or a function that - takes the shape as input and returns a shaped batch of samples. This is used to sample - perturbations from the given distribution. The distribution should be sphere symmetric. - method: (str) - The algorithm to use. 
The currently supported algorithms are :const:`'naive'`, - :const:`'forward'`, and :const:`'antithetic'`. Defaults to :const:`'naive'`. - argnums: (int or tuple of int, default: :const:`0`) - Specifies arguments to compute gradients with respect to. - num_samples: (int, default :const:`1`) - The number of sample to get the averaged estimated gradient. - sigma: (Numeric) - The standard deviation of the perturbation. Defaults to :const:`1.0`. + distribution (callable or Samplable): A samplable object that has method + ``samplable.sample(sample_shape)`` or a function that takes the shape as input and + returns a shaped batch of samples. This is used to sample perturbations from the given + distribution. The distribution should be sphere symmetric. + method (str, optional): The algorithm to use. The currently supported algorithms are + :const:`'naive'`, :const:`'forward'`, and :const:`'antithetic'`. + (default: :const:`'naive'`) + argnums (int or tuple of int, optional): Specifies arguments to compute gradients with + respect to. (default: :const:`0`) + num_samples (int, optional): The number of samples to get the averaged estimated gradient. + (default: :const:`1`) + sigma (float or Tensor, optional): The standard deviation of the perturbation. + (default: :const:`1.0`) Returns: A function decorator that enables zero-order gradient estimation. 
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index d76ac444..65014fb9 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -16,9 +16,11 @@ # pylint: disable=redefined-builtin +from __future__ import annotations + import abc import functools -from typing import Sequence, Type, Union +from typing import Sequence import torch import torch.nn as nn @@ -32,11 +34,11 @@ def enable_zero_order_gradients( - cls: Type['ZeroOrderGradientModule'], + cls: type[ZeroOrderGradientModule], method: Method = 'naive', num_samples: int = 1, sigma: Numeric = 1.0, -) -> Type['ZeroOrderGradientModule']: +) -> type[ZeroOrderGradientModule]: """Enable zero-order gradient estimation for the :func:`forward` method.""" cls_forward = cls.forward if getattr(cls_forward, '__zero_order_gradients_enabled__', False): @@ -45,7 +47,7 @@ def enable_zero_order_gradients( ) @functools.wraps(cls_forward) - def wrapped(self: 'ZeroOrderGradientModule', *input, **kwargs) -> torch.Tensor: + def wrapped(self: ZeroOrderGradientModule, *input, **kwargs) -> torch.Tensor: """Do the forward pass calculation.""" params_names, flat_params = tuple(zip(*self.named_parameters())) @@ -91,7 +93,7 @@ def forward(self, *args, **kwargs) -> torch.Tensor: @abc.abstractmethod def sample( self, sample_shape: torch.Size = torch.Size() # pylint: disable=unused-argument - ) -> Union[torch.Tensor, Sequence[Numeric]]: + ) -> torch.Tensor | Sequence[Numeric]: # pylint: disable-next=line-too-long """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.""" raise NotImplementedError diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 53f87fba..b46ad67e 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -14,6 +14,8 @@ # ============================================================================== """Distributed APIs.""" +from 
__future__ import annotations + import functools import sys from typing import ( @@ -73,8 +75,8 @@ class TensorDimensionPartitioner: while the non-tensor values will be broadcasted to partitions. Args: - dim: The dimension to partition. - exclusive: Whether to partition the batch exclusively. + dim (int): The dimension to partition. + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where ``batch_size`` is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. @@ -82,11 +84,12 @@ class TensorDimensionPartitioner: partitions, where ``num_workers`` is the number of workers in the world. When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. - keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the - batch dimension. If :data:`False`, use select instead of slicing. This functionality - should be used with ``exclusive=True``. - workers: The workers to partition the batch to. If :data:`None`, the batch will be - partitioned to all workers in the world. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`False`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``. + workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. 
+ (default: :data:`None`) """ def __init__( @@ -95,7 +98,7 @@ def __init__( *, exclusive: bool = False, keepdim: bool = False, - workers: Optional[Sequence[Union[int, str]]] = None, + workers: Sequence[int | str] | None = None, ) -> None: """Initialize the partitioner instance.""" if not keepdim and not exclusive: @@ -111,7 +114,7 @@ def __call__( self, *args: Any, **kwargs: Any, - ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]: + ) -> list[tuple[int, Args | None, KwArgs | None]]: """Partition the batch of inputs along the given dimension.""" if self.workers is None: workers = list(range(get_world_size())) @@ -120,7 +123,7 @@ def __call__( num_workers = len(workers) args_tree = (args, kwargs) - flat_args: List[Any] + flat_args: list[Any] flat_args, treespec = pytree.tree_flatten(args_tree) # type: ignore[arg-type] batch_size = None @@ -137,8 +140,8 @@ def __call__( if batch_size is None: return [(get_world_rank(), args, kwargs.copy())] - dim_slices: List[Union[int, slice]] - batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]] # type: ignore[name-defined] + dim_slices: list[int | slice] + batch_slices: list[tuple[int | slice | Ellipsis.__class__, ...]] # type: ignore[name-defined] if self.exclusive: num_replicas = batch_size if self.keepdim: @@ -172,7 +175,7 @@ def __call__( for dim_slice in dim_slices ] - flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)] + flat_args_replicas: list[list[Any]] = [[] for _ in range(num_replicas)] for arg in flat_args: if isinstance(arg, torch.Tensor): for i, batch_slice in enumerate(batch_slices): @@ -181,7 +184,7 @@ def __call__( for i in range(num_replicas): flat_args_replicas[i].append(arg) - args_replicas: List[Tuple[Args, KwArgs]] = [ + args_replicas: list[tuple[Args, KwArgs]] = [ pytree.tree_unflatten(treespec, args_replica) # type: ignore[misc] for args_replica in flat_args_replicas ] @@ -193,10 +196,10 @@ def __call__( def __reduce__( self, - ) -> Tuple[ - Callable[..., 
'TensorDimensionPartitioner'], - Tuple[int], - Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]], + ) -> tuple[ + Callable[..., TensorDimensionPartitioner], + tuple[int], + dict[str, bool | Sequence[int | str] | None], ]: """Return a tuple that allows the partitioner to be pickled.""" return ( @@ -211,7 +214,7 @@ def dim_partitioner( *, exclusive: bool = False, keepdim: bool = True, - workers: Optional[Sequence[Union[int, str]]] = None, + workers: Sequence[int | str] | None = None, ) -> PartitionFunction: """Partition a batch of inputs along a given dimension. @@ -219,8 +222,8 @@ def dim_partitioner( while the non-tensor values will be broadcasted to partitions. Args: - dim: The dimension to partition. - exclusive: Whether to partition the batch exclusively. + dim (int, optional): The dimension to partition. (default: :const:`0`) + exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`) If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where ``batch_size`` is the size of the batch along the given dimension. Each batch sample will be assigned to a separate RPC call. @@ -228,11 +231,12 @@ def dim_partitioner( partitions, where ``num_workers`` is the number of workers in the world. When ``batch_size > num_workers``, there can be multiple batch samples forward in a single RPC call. - keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the - batch dimension. If :data:`False`, use select instead of slicing. This functionality - should be used with ``exclusive=True``. - workers: The workers to partition the batch to. If :data:`None`, the batch will be - partitioned to all workers in the world. + keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`True`) + If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of + slicing. This functionality should be used with ``exclusive=True``.
+ workers (sequence of int or str, or None, optional): The workers to partition the batch to. + If :data:`None`, the batch will be partitioned to all workers in the world. + (default: :data:`None`) Returns: A partition function. @@ -273,26 +277,26 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: def remote_async_call( func: Callable[..., T], *, - args: Optional[Args] = None, - kwargs: Optional[KwArgs] = None, - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Union[Future[List[T]], Future[U]]: + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Future[list[T]] | Future[U]: """Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker. Args: - func (Callable[..., T]): The function to call. - args (Optional[Args], optional): The arguments to pass to the function. Defaults to - :data:`None`. - kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults - to :data:`None`. - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. + (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. 
+ (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: A :class:`torch.Future` instance for the result. The result is at the current worker. @@ -330,26 +334,26 @@ def remote_async_call( def remote_sync_call( func: Callable[..., T], *, - args: Optional[Args] = None, - kwargs: Optional[KwArgs] = None, - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Union[List[T], U]: + args: Args | None = None, + kwargs: KwArgs | None = None, + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> list[T] | U: """Do an RPC synchronously on remote workers and return the result to the current worker. Args: - func (Callable[..., T]): The function to call. - args (Optional[Args], optional): The arguments to pass to the function. Defaults to - :data:`None`. - kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults - to :data:`None`. - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + func (callable): The function to call. + args (tuple of object or None, optional): The arguments to pass to the function. 
+ (default: :data:`None`) + kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function. + (default: :data:`None`) + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, do not reduce the results. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The result of the RPC call. The result is at the current worker. @@ -365,10 +369,10 @@ def remote_sync_call( def parallelize_async( - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]: + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., Future[list[T]] | Future[U]]]: """Return a decorator for parallelizing a function. This decorator can be used to parallelize a function call across multiple workers. The @@ -376,13 +380,12 @@ def parallelize_async( return a :class:`torch.Future` instance of the result. Args: - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not - specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. 
+ partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, :func:`mean_reducer` is used when ``partitioner`` is not specified; otherwise the results are not reduced. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The decorator function. @@ -392,9 +395,9 @@ def parallelize_async( if reducer is None: reducer = mean_reducer # type: ignore[assignment] - def wrapper(func: Callable[..., T]) -> Callable[..., Union[Future[List[T]], Future[U]]]: + def wrapper(func: Callable[..., T]) -> Callable[..., Future[list[T]] | Future[U]]: @functools.wraps(func) - def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: + def wrapped(*args: Any, **kwargs: Any) -> Future[list[T]] | Future[U]: return remote_async_call( func, args=args, @@ -423,22 +426,21 @@ def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]: def parallelize( - partitioner: Optional[Partitioner] = None, - reducer: Optional[Callable[[Iterable[T]], U]] = None, - timeout: Optional[float] = UNSET_RPC_TIMEOUT, -) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]: + partitioner: Partitioner | None = None, + reducer: Callable[[Iterable[T]], U] | None = None, + timeout: float | None = UNSET_RPC_TIMEOUT, +) -> Callable[[Callable[..., T]], Callable[..., list[T] | U]]: """Return a decorator for parallelizing a function. This decorator can be used to parallelize a function call across multiple workers. Args: - partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple - workers. Defaults to :func:`batch_partitioner`. - reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from - multiple workers.
Defaults to :func:`mean_reducer` if the ``partitioner`` is not - specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`. - timeout (float, optional): The timeout for the RPC call. Defaults to - :data:`rpc.api.UNSET_RPC_TIMEOUT`. + partitioner (int, str, or callable, optional): A partitioner that partitions the arguments + to multiple workers. (default: :func:`batch_partitioner`) + reducer (callable or None, optional): A reducer that reduces the results from multiple + workers. If :data:`None`, :func:`mean_reducer` is used when ``partitioner`` is not specified; otherwise the results are not reduced. (default: :data:`None`) + timeout (float, optional): The timeout for the RPC call. + (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`) Returns: The decorator function. @@ -448,9 +450,9 @@ def parallelize( if reducer is None: reducer = mean_reducer # type: ignore[assignment] - def wrapper(func: Callable[..., T]) -> Callable[..., Union[List[T], U]]: + def wrapper(func: Callable[..., T]) -> Callable[..., list[T] | U]: @functools.wraps(func) - def wrapped(*args: Any, **kwargs: Any) -> Union[List[T], U]: + def wrapped(*args: Any, **kwargs: Any) -> list[T] | U: return remote_sync_call( func, args=args, diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 5fe51278..17fa9463 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -14,14 +14,15 @@ # ============================================================================== """Distributed Autograd.""" +from __future__ import annotations + from threading import Lock -from typing import Optional, overload import torch import torch.distributed.autograd as autograd from torch.distributed.autograd import context -from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors +from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors __all__ = ['is_available', 'context'] @@ -43,22 +44,23 @@ def backward( autograd_ctx_id: int, tensors: TensorOrTensors, retain_graph: bool = False, - inputs:
Optional[TensorOrTensors] = None, + inputs: TensorOrTensors | None = None, ) -> None: """Perform distributed backward pass for local parameters. Compute the sum of gradients of given tensors with respect to graph leaves. Args: - autograd_ctx_id: The autograd context id. - tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be computed. + autograd_ctx_id (int): The autograd context id. + tensors (Tensor or sequence of Tensor): Tensors of which the derivative will be computed. retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to :data:`True` is not needed and often can be worked around in a much more efficient way. - inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient be will - accumulated into ``.grad``. All other Tensors will be ignored. If not provided, the - gradient is accumulated into all the leaf Tensors that were used to compute the - attr::tensors. + (default: :data:`False`) + inputs (Tensor, sequence of Tensor, or None, optional): Inputs w.r.t. which the gradient + will be accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were used to + compute the ``tensors``. (default: :data:`None`) """ if inputs is not None: if isinstance(inputs, torch.Tensor): @@ -85,25 +87,6 @@ def backward( else: p.grad = g - @overload - def grad( - autograd_ctx_id: int, - outputs: TensorOrTensors, - inputs: TensorOrTensors, - retain_graph: bool = False, - ) -> TupleOfTensors: - ... - - @overload - def grad( - autograd_ctx_id: int, - outputs: TensorOrTensors, - inputs: TensorOrTensors, - retain_graph: bool = False, - allow_unused: bool = False, - ) -> TupleOfOptionalTensors: - ...
- def grad( autograd_ctx_id: int, outputs: TensorOrTensors, @@ -114,16 +97,17 @@ def grad( """Compute and return the sum of gradients of outputs with respect to the inputs. Args: - autograd_ctx_id: The autograd context id. - outputs (sequence of Tensor): outputs of the differentiated function. - inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be returned (and not - accumulated into ``.grad``). + autograd_ctx_id (int): The autograd context id. + outputs (Tensor or sequence of Tensor): Outputs of the differentiated function. + inputs (Tensor or sequence of Tensor): Inputs w.r.t. which the gradient will be returned + (and not accumulated into ``.grad``). retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to :data:`True` is not needed and often can be worked around in a much more efficient way. + (default: :data:`False`) allow_unused (bool, optional): If :data:`False`, specifying inputs that were not used when computing outputs (and therefore their grad is always zero) is an error. - Defaults to :data:`False`. 
+ (default: :data:`False`) """ outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs) inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index 45140df1..804d4b9d 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -14,10 +14,12 @@ # ============================================================================== """Utilities for gathering information about the world.""" +from __future__ import annotations + import atexit import functools import os -from typing import Any, Callable, Iterable, NamedTuple, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, NamedTuple, TypeVar import torch.distributed.rpc as rpc from torch.distributed.elastic.multiprocessing.errors import record @@ -127,32 +129,33 @@ def get_local_world_size() -> int: # pylint: disable-next=redefined-builtin,invalid-name -def get_worker_id(id: Optional[Union[str, int]] = None) -> int: +def get_worker_id(id: str | int | None = None) -> int: """Get the worker id from the given id.""" if isinstance(id, int): return id return rpc.get_worker_info(worker_name=id).id -def barrier(worker_names: Optional[Iterable[str]] = None) -> None: +def barrier(worker_names: Iterable[str] | None = None) -> None: r"""Synchronize local and remote RPC processes. This will block until all local and remote RPC processes specified under worker_names reach this method to wait for all outstanding work to complete. Args: - worker_names: The set of workers to synchronize. If :data:`None`, all workers. + worker_names (iterable of str or None, optional): The set of workers to synchronize. + If :data:`None`, all workers. 
(default: :data:`None`) """ worker_names = {} if worker_names is None else set(worker_names) rpc.api._barrier(worker_names) # pylint: disable=protected-access def auto_init_rpc( - worker_init_fn: Optional[Callable[[], None]] = None, + worker_init_fn: Callable[[], None] | None = None, worker_name_format: Callable[..., str] = default_worker_name_format, *, - backend: Optional['rpc.BackendType'] = None, - rpc_backend_options: Optional['rpc.RpcBackendOptions'] = None, + backend: rpc.BackendType | None = None, + rpc_backend_options: rpc.RpcBackendOptions | None = None, ) -> Callable[[F], F]: """Return a decorator to automatically initialize RPC on the decorated function.""" global _WORKER_NAME_FORMAT # pylint: disable=global-statement diff --git a/torchopt/hook.py b/torchopt/hook.py index 949c76e7..f188415c 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -14,7 +14,9 @@ # ============================================================================== """Hook utilities.""" -from typing import Callable, Optional, Tuple +from __future__ import annotations + +from typing import Callable import torch @@ -32,7 +34,7 @@ def zero_nan_hook(g: torch.Tensor) -> torch.Tensor: def nan_to_num_hook( - nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None + nan: float = 0.0, posinf: float | None = None, neginf: float | None = None ) -> Callable[[torch.Tensor], torch.Tensor]: """Return a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.""" @@ -59,9 +61,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, # pylint: disable=unused-argument - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: def f(g): return g.register_hook(hook) diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 94daee53..5456f076 100644 --- a/torchopt/linalg/cg.py +++ 
b/torchopt/linalg/cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + from functools import partial -from typing import Callable, Optional, Union +from typing import Callable import torch @@ -100,14 +102,14 @@ def body_fn(value): def _isolve( _isolve_solve: Callable, - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - x0: Optional[TensorTree] = None, + x0: TensorTree | None = None, *, rtol: float = 1e-5, atol: float = 0.0, - maxiter: Optional[int] = None, - M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, ) -> TensorTree: if x0 is None: x0 = pytree.tree_map(torch.zeros_like, b) @@ -133,14 +135,14 @@ def _isolve( def cg( - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, - x0: Optional[TensorTree] = None, + x0: TensorTree | None = None, *, rtol: float = 1e-5, atol: float = 0.0, - maxiter: Optional[int] = None, - M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None, + maxiter: int | None = None, + M: TensorTree | Callable[[TensorTree], TensorTree] | None = None, ) -> TensorTree: """Use Conjugate Gradient iteration to solve ``Ax = b``. @@ -153,30 +155,30 @@ def cg( solves converge. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - b: (tensor or tree of tensors) - Right hand side of the linear system representing a single vector. Can be stored as an - array or Python container of array(s) with any shape. - x0: (tensor or tree of tensors, optional) - Starting guess for the solution. 
Must have the same structure as ``b``. - rtol: (float, optional, default: :const:`1e-5`) - Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. We do not - implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy - unless you explicitly pass ``atol`` to SciPy's ``cg``. - atol: (float, optional, default: :const:`0.0`) - Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. We do not - implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy - unless you explicitly pass ``atol`` to SciPy's ``cg``. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - M: (tensor or tree of tensors or function) - Pre-conditioner for ``A``. The pre-conditioner should approximate the inverse of ``A``. - Effective preconditioning dramatically improves the rate of convergence, which implies - that fewer iterations are needed to reach a given error tolerance. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + x0 (Tensor, tree of Tensor, or None, optional): Starting guess for the solution. Must have + the same structure as ``b``. If :data:`None`, use zero initialization. + (default: :data:`None`) + rtol (float, optional): Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. 
(default: :const:`1e-5`) + atol (float, optional): Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. + We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from + SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`0.0`) + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + ``10 * size`` will be used, where ``size`` is the size of the flattened input tensor(s). + (default: :data:`None`) + M (Tensor, tree of Tensor, function, or None, optional): Pre-conditioner for ``A``. The + pre-conditioner should approximate the inverse of ``A``. Effective preconditioning + dramatically improves the rate of convergence, which implies that fewer iterations are + needed to reach a given error tolerance. If :data:`None`, no pre-conditioner will be + used. (default: :data:`None`) Returns: the Conjugate Gradient (CG) linear solver diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 04f5dd11..c1975203 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -16,13 +16,15 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional, Union +from typing import Callable import torch from torchopt import pytree -from torchopt.linalg.utils import cat_shapes, normalize_matvec +from torchopt.linalg.utils import normalize_matvec from torchopt.typing import TensorTree @@ -33,7 +35,7 @@ def _ns_solve( A: torch.Tensor, b: torch.Tensor, maxiter: int, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> torch.Tensor: """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: @@ -57,27 +59,26 @@ def _ns_solve( def ns( - A: Union[TensorTree, Callable[[TensorTree], TensorTree]], + A: TensorTree | Callable[[TensorTree], TensorTree], b:
TensorTree, - maxiter: Optional[int] = None, + maxiter: int | None = None, *, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> TensorTree: """Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - b: (tensor or tree of tensors) - Right hand side of the linear system representing a single vector. Can be stored as an - array or Python container of array(s) with any shape. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - alpha: (float, optional) - Decay coefficient. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + b (Tensor or tree of Tensor): Right hand side of the linear system representing a single + vector. Can be stored as a tensor or Python container of tensor(s) with any shape. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) Returns: The Neumann Series (NS) matrix inversion approximation.
@@ -111,7 +112,7 @@ def ns( return inv_A_hat_b -def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None): """Use Neumann Series iteration to solve ``A^{-1}``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') @@ -134,28 +135,27 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None): def ns_inv( A: TensorTree, - maxiter: Optional[int] = None, + maxiter: int | None = None, *, - alpha: Optional[float] = None, + alpha: float | None = None, ) -> TensorTree: """Use Neumann Series iteration to solve ``A^{-1}``. Args: - A: (tensor or tree of tensors or function) - 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when - called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and - must return array(s) with the same structure and shape as its argument. - maxiter: (integer, optional) - Maximum number of iterations. Iteration will stop after maxiter steps even if the - specified tolerance has not been achieved. - alpha: (float, optional) - Decay coefficient. + A (Tensor or tree of Tensor): 2D array or function that calculates the linear map + (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a + hermitian, positive definite matrix, and must return tensor(s) with the same structure + and shape as its argument. + maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after + maxiter steps even if the specified tolerance has not been achieved. If :data:`None`, + :const:`10` will be used. (default: :const:`10`) + alpha (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be + used. (default: :const:`1.0`) Returns: The Neumann Series (NS) matrix inversion approximation.
""" if maxiter is None: - size = sum(cat_shapes(A)) - maxiter = 10 * size # copied from SciPy + maxiter = 10 return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A) diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py index 275232be..f301a624 100644 --- a/torchopt/linalg/utils.py +++ b/torchopt/linalg/utils.py @@ -14,8 +14,10 @@ # ============================================================================== """Utilities for linear algebra.""" +from __future__ import annotations + import itertools -from typing import Callable, Tuple, Union +from typing import Callable import torch @@ -23,14 +25,14 @@ from torchopt.typing import TensorTree -def cat_shapes(tree: TensorTree) -> Tuple[int, ...]: +def cat_shapes(tree: TensorTree) -> tuple[int, ...]: """Concatenate the shapes of the leaves of a tree of tensors.""" leaves = pytree.tree_leaves(tree) return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves)) def normalize_matvec( - matvec: Union[TensorTree, Callable[[TensorTree], TensorTree]] + matvec: TensorTree | Callable[[TensorTree], TensorTree] ) -> Callable[[TensorTree], TensorTree]: """Normalize an argument for computing matrix-vector product.""" if callable(matvec): diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index f75ef9f4..844c9407 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec @@ -47,8 +49,8 @@ def _solve_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, - init: Optional[TensorTree] = None, + ridge: float | None = None, + init: TensorTree | None = None, **kwargs, ) -> TensorTree: """Solve ``A x = b`` using conjugate gradient. 
@@ -56,10 +58,12 @@ def _solve_cg( This assumes that ``A`` is a hermitian, positive definite matrix. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tree of tensors for the right hand side of the equation. - ridge: Optional ridge regularization. - init: Optional initialization to be used by conjugate gradient. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver. Returns: @@ -80,8 +84,10 @@ def solve_cg(**kwargs): This assumes that ``A`` is a hermitian, positive definite matrix. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - init: Optional initialization to be used by conjugate gradient. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. 
diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index c3224a52..399a0ef9 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable import torch @@ -49,7 +51,7 @@ def _solve_inv( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, + ridge: float | None = None, ns: bool = False, **kwargs, ) -> TensorTree: @@ -59,11 +61,13 @@ def _solve_inv( in memory. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tensor for the right hand side of the equation. - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`, - materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation solver :func:`torchopt.linalg.ns`. @@ -94,9 +98,11 @@ def solve_inv(**kwargs): in memory. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``. - ns: Whether to use Neumann Series matrix inversion approximation. 
If :data:`False`, - materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A x + ridge x = b``. (default: :data:`None`) + ns (bool, optional): Whether to use Neumann Series matrix inversion approximation. + If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` + instead. (default: :data:`False`) **kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation solver :func:`torchopt.linalg.ns`. diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 3199a490..8d38f77a 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -33,8 +33,10 @@ # pylint: disable=invalid-name +from __future__ import annotations + import functools -from typing import Callable, Optional +from typing import Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec @@ -47,8 +49,8 @@ def _solve_normal_cg( matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x b: TensorTree, - ridge: Optional[float] = None, - init: Optional[TensorTree] = None, + ridge: float | None = None, + init: TensorTree | None = None, **kwargs, ) -> TensorTree: """Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient. @@ -57,10 +59,12 @@ def _solve_normal_cg( positive definite. Args: - matvec: A function that returns the product between ``A`` and a vector. - b: A tree of tensors for the right hand side of the equation. - ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. - init: Optional initialization to be used by normal conjugate gradient. + matvec (callable): A function that returns the product between ``A`` and a vector. + b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation. 
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. @@ -93,8 +97,10 @@ def solve_normal_cg(**kwargs): positive definite. Args: - ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``. - init: Optional initialization to be used by normal conjugate gradient. + ridge (float or None, optional): Optional ridge regularization. If provided, solves the + equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`) + init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by + conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`) **kwargs: Additional keyword arguments for the conjugate gradient solver :func:`torchopt.linalg.cg`. 
diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py index 9c2f7ced..f4f34e2a 100644 --- a/torchopt/linear_solve/utils.py +++ b/torchopt/linear_solve/utils.py @@ -31,7 +31,9 @@ # ============================================================================== """Utilities for linear algebra solvers.""" -from typing import Callable, Tuple +from __future__ import annotations + +from typing import Callable import functorch @@ -75,7 +77,7 @@ def ridge_matvec(y: TensorTree) -> TensorTree: def materialize_matvec( matvec: Callable[[TensorTree], TensorTree], x: TensorTree -) -> Tuple[ +) -> tuple[ TensorTree, Callable[[TensorTree], TensorTree], Callable[[TensorTree], TensorTree], diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index 3716f674..f8804864 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -14,8 +14,10 @@ # ============================================================================== """Base class for neural network modules that hold meta-parameters and meta-modules.""" +from __future__ import annotations + from collections import OrderedDict -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union +from typing import Any, Iterator, NamedTuple import torch import torch.nn as nn @@ -27,8 +29,8 @@ class MetaInputsContainer(NamedTuple): """Container for parameters and modules in the constructor input arguments.""" - meta_parameters: Set[torch.Tensor] - meta_modules: Set[nn.Module] + meta_parameters: set[torch.Tensor] + meta_modules: set[nn.Module] class MetaGradientModule(nn.Module): # pylint: disable=abstract-method @@ -36,12 +38,12 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method _meta_inputs: MetaInputsContainer _meta_parameters: TensorContainer - _meta_modules: Dict[str, Optional[nn.Module]] + _meta_modules: dict[str, nn.Module | None] - def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': + def __new__(cls, *args, **kwargs) -> MetaGradientModule: 
"""Create a new module instance.""" instance = super().__new__(cls) - flat_args: List[Any] + flat_args: list[Any] flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type] meta_parameters = {x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad} meta_modules = {x for x in flat_args if isinstance(x, nn.Module) and x.training} @@ -51,14 +53,14 @@ def __new__(cls, *args, **kwargs) -> 'MetaGradientModule': instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules) instance._meta_parameters: TensorContainer = OrderedDict() # type: ignore[misc] - instance._meta_modules: Dict[str, Optional[nn.Module]] = OrderedDict() # type: ignore[misc] + instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc] return instance def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument """Initialize a new module instance.""" super().__init__() - def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: + def __getattr__(self, name: str) -> torch.Tensor | nn.Module: """Get an attribute of the module.""" if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] @@ -83,7 +85,7 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") # pylint: disable-next=too-many-branches,too-many-statements - def __setattr__(self, name: str, value: Union[torch.Tensor, nn.Module]) -> None: + def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: """Set an attribute of the module.""" def remove_from(*dicts_or_sets): @@ -186,18 +188,17 @@ def __delattr__(self, name: str) -> None: else: object.__delattr__(self, name) - def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + def register_parameter(self, name: str, param: torch.Tensor | None) -> None: r"""Add a parameter to the module. 
The parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed - from this module using the given name - param (torch.Tensor or None): parameter to be added to the module. If - ``None``, then operations that run on parameters, such as :attr:`cuda`, - are ignored. If ``None``, the parameter is **not** included in the - module's :attr:`state_dict`. + name (str): The name of the parameter. The parameter can be accessed from this module + using the given name. + param (Tensor or None): The parameter to be added to the module. If :data:`None`, then + operations that run on parameters, such as ``cuda``, are ignored. If :data:`None`, + the parameter is **not** included in the module's ``state_dict``. """ if '_parameters' not in self.__dict__: raise AttributeError('cannot assign parameter before Module.__init__() call') @@ -231,18 +232,17 @@ def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: self._parameters[name] = param # type: ignore - def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> None: + def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None: r"""Add a meta-parameter to the module. The meta-parameter can be accessed as an attribute using given name. Args: - name (string): name of the parameter. The parameter can be accessed - from this module using the given name - param (torch.Tensor or None): parameter to be added to the module. If - ``None``, then operations that run on parameters, such as :attr:`cuda`, - are ignored. If ``None``, the parameter is **not** included in the - module's :attr:`state_dict`. + name (str): The name of the meta-parameter. The meta-parameter can be accessed from this + module using the given name. + param (Tensor or None): The meta-parameter to be added to the module. If :data:`None`, + then operations that run on meta-parameters, such as ``cuda``, are ignored. 
If + :data:`None`, the meta-parameter is **not** included in the module's ``state_dict``. """ if '_meta_parameters' not in self.__dict__: raise AttributeError( @@ -273,15 +273,15 @@ def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> N self._meta_parameters[name] = param - def add_module(self, name: str, module: Optional[nn.Module]) -> None: + def add_module(self, name: str, module: nn.Module | None) -> None: r"""Add a child module to the current module. The module can be accessed as an attribute using the given name. Args: - name (string): name of the child module. The child module can be - accessed from this module using the given name - module (Module): child module to be added to the module. + name (str): The name of the child module. The child module can be accessed from this + module using the given name + module (nn.Module or None): The child module to be added to the module. """ if not isinstance(module, nn.Module) and module is not None: raise TypeError(f'{torch.typename(module)} is not a Module subclass') @@ -301,19 +301,19 @@ def add_module(self, name: str, module: Optional[nn.Module]) -> None: self._modules[name] = module - def register_module(self, name: str, module: Optional[nn.Module]) -> None: + def register_module(self, name: str, module: nn.Module | None) -> None: r"""Alias for :func:`add_module`.""" self.add_module(name, module) - def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + def add_meta_module(self, name: str, meta_module: nn.Module | None) -> None: r"""Add a child meta-module to the current module. The meta-module can be accessed as an attribute using the given name. Args: - name (string): name of the child meta-module. The child meta-module can be - accessed from this module using the given name - meta_module (Module): child meta-module to be added to the module. + name (str): The name of the child meta-module. 
The child meta-module can be accessed + from this module using the given name + meta_module (nn.Module or None): The child meta-module to be added to the module. """ if not isinstance(meta_module, nn.Module) and meta_module is not None: raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass') @@ -328,7 +328,7 @@ def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: self._meta_modules[name] = meta_module - def register_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None: + def register_meta_module(self, name: str, meta_module: nn.Module | None) -> None: r"""Alias for :func:`add_meta_module`.""" self.add_meta_module(name, meta_module) @@ -338,9 +338,9 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: This is typically passed to an optimizer. Args: - recurse (bool): if True, then yields parameters of this module and - all submodules. Otherwise, yields only meta-parameters that - are direct members of this module. + recurse (bool, optional): If :data:`True`, then yields parameters of this module and + all submodules. Otherwise, yields only meta-parameters that are direct members of + this module. (default: :data:`True`) Yields: Parameter: module meta-parameter @@ -358,14 +358,15 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]: def named_meta_parameters( self, prefix: str = '', recurse: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: + ) -> Iterator[tuple[str, torch.Tensor]]: r"""Return an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself. Args: - prefix (str): prefix to prepend to all meta-parameter names. - recurse (bool): if True, then yields meta-parameters of this module - and all submodules. Otherwise, yields only meta-parameters that - are direct members of this module. + prefix (str, optional): The prefix to prepend to all meta-parameter names. 
+ (default: :const:`''`) + recurse (bool, optional): If :data:`True`, then yields meta-parameters of this module + and all submodules. Otherwise, yields only meta-parameters that are direct members + of this module. (default: :data:`True`) Yields: (string, Parameter): Tuple containing the name and parameter @@ -398,7 +399,7 @@ def meta_children(self) -> Iterator[nn.Module]: for _, module in self.named_meta_children(): yield module - def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]: + def named_meta_children(self) -> Iterator[tuple[str, nn.Module]]: r"""Return an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself. Yields: @@ -430,15 +431,18 @@ def meta_modules(self) -> Iterator[nn.Module]: yield meta_module def named_meta_modules( - self, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Module]]: + self, memo: set[nn.Module] | None = None, prefix: str = '', remove_duplicate: bool = True + ) -> Iterator[tuple[str, nn.Module]]: r"""Return an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself. Args: - memo: a memo to store the set of meta-modules already added to the result - prefix: a prefix that will be added to the name of the meta-module - remove_duplicate: whether to remove the duplicated meta-module instances in the result - or not + memo (set of nn.Module or None, optional): A memo to store the set of meta-modules + already added to the result. If not provided, a new set will be created. + (default: :data:`None`) + prefix (str, optional): A prefix that will be added to the name of the meta-module. + (default: :const:`''`) + remove_duplicate (bool, optional): Whether to remove the duplicated meta-module + instances in the result or not.
(default: :const:`True`) Yields: (string, Module): Tuple of name and meta-module diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index 2fc0dbb4..9391352f 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -14,8 +14,10 @@ # ============================================================================== """Utility functions for stateless module calls.""" +from __future__ import annotations + import contextlib -from typing import Dict, Generator, Iterable, Tuple, Union +from typing import Generator, Iterable import torch import torch.nn as nn @@ -29,9 +31,9 @@ def swap_state( module: nn.Module, - named_tensors: Union[Dict[str, torch.Tensor], Iterable[Tuple[str, torch.Tensor]]], + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], allow_missing: bool = False, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """Swap the module parameters and/or buffers.""" if not isinstance(named_tensors, dict): named_tensors = dict(named_tensors) @@ -84,7 +86,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: @contextlib.contextmanager def reparametrize( module: nn.Module, - named_tensors: Union[Dict[str, torch.Tensor], Iterable[Tuple[str, torch.Tensor]]], + named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], allow_missing: bool = False, ) -> Generator[nn.Module, None, None]: """Reparameterize the module parameters and/or buffers.""" diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index c56956f8..640eea1d 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -14,7 +14,9 @@ # ============================================================================== """Adam optimizer.""" -from typing import Iterable, Tuple +from __future__ import annotations + +from typing import Iterable import torch @@ -39,7 +41,7 @@ def __init__( self, params: Iterable[torch.Tensor], lr: ScalarOrSchedule, - betas: Tuple[float, float] = (0.9, 0.999), + betas: 
tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -50,25 +52,27 @@ def __init__( r"""Initialize the Adam optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. 
+ (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 19c70678..7db5e750 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -14,13 +14,15 @@ # ============================================================================== """AdamW optimizer.""" -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable, Iterable import torch from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import Params, ScalarOrSchedule +from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['AdamW'] @@ -39,46 +41,48 @@ def __init__( self, params: Iterable[torch.Tensor], lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, maximize: bool = False, use_accelerated_op: bool = False, ) -> None: r"""Initialize the AdamW optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. 
Specifies what tensors should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. 
This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. 
+ (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index e894b93b..aac3a782 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -14,7 +14,9 @@ # ============================================================================== """The base class for optimizers.""" -from typing import Callable, Iterable, List, Optional, Sequence, Tuple +from __future__ import annotations + +from typing import Callable, Iterable, Sequence import torch @@ -37,8 +39,8 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) params (iterable of torch.Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by ``alias.py`` or a customized ``chain`` provided by - ``combine.py``. + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to :class:`torchopt.SGD`. """ @@ -46,9 +48,9 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation) raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_groups: List[TupleOfTensors] = [] - self.param_treespecs: List[pytree.PyTreeSpec] = [] - self.state_groups: List[OptState] = [] + self.param_groups: list[TupleOfTensors] = [] + self.param_treespecs: list[pytree.PyTreeSpec] = [] + self.state_groups: list[OptState] = [] if not isinstance(params, (list, tuple)): params = tuple(params) @@ -60,7 +62,8 @@ def zero_grad(self, set_to_none: bool = False) -> None: The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`. Args: - set_to_none (bool): Instead of setting to zero, set the ``grads`` to :data:`None`. 
+ set_to_none (bool, optional): Instead of setting to zero, set the ``grads`` to + :data:`None`. (default: :data:`False`) """ if set_to_none: @@ -80,7 +83,7 @@ def f(p): pytree.tree_map_(f, self.param_groups) # type: ignore[arg-type] - def state_dict(self) -> Tuple[OptState, ...]: + def state_dict(self) -> tuple[OptState, ...]: """Return the state of the optimizer.""" return tuple(self.state_groups) @@ -88,18 +91,19 @@ def load_state_dict(self, state_dict: Sequence[OptState]) -> None: """Load the optimizer state. Args: - state_dict: Optimizer state. Should be an object returned from a call to - :meth:`state_dict`. + state_dict (sequence of tree of Tensor): Optimizer state. Should be an object returned + from a call to :meth:`state_dict`. """ self.state_groups[:] = list(state_dict) - def step(self, closure: Optional[Callable[[], torch.Tensor]] = None) -> Optional[torch.Tensor]: + def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tensor | None: """Perform a single optimization step. The behavior is similar to :meth:`torch.optim.Optimizer.step`. Args: - closure (callable, optional): A closure that reevaluates the model and returns the loss. + closure (callable or None, optional): A closure that reevaluates the model and returns + the loss. Optional for most optimizers. 
(default: :data:`None`) """ loss = None if closure is not None: @@ -120,7 +124,7 @@ def f(p): return loss def add_param_group(self, params: Params) -> None: - """Add a param group to the optimizer's :attr:`param_groups`.""" + """Add a param group to the optimizer's ``param_groups``.""" flat_params: TupleOfTensors flat_params, params_treespec = pytree.tree_flatten_as_tuple(params) self.param_groups.append(flat_params) diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 7e51a21b..9dce3412 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -14,7 +14,7 @@ # ============================================================================== """Functional optimizer wrappers.""" -from typing import Optional +from __future__ import annotations import torch @@ -41,26 +41,27 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods """ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None: - """Initialize the functional optimizer wrapper. + r"""Initialize the functional optimizer wrapper. Args: impl (GradientTransformation): A low level optimizer function, it could be a optimizer - function provided by `alias.py` or a customized `chain` provided by `combine.py`. - inplace (optional): (default: :data:`False`) - The default value of ``inplace`` for each optimization update. + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. + inplace (bool, optional): The default value of ``inplace`` for each optimization update. 
+ (default: :data:`False`) """ if not isinstance(impl, GradientTransformation): raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.optim_state: Optional[OptState] = UninitializedState() + self.optim_state: OptState | None = UninitializedState() self.inplace: bool = bool(inplace) def step( self, loss: torch.Tensor, params: Params, - inplace: Optional[bool] = None, + inplace: bool | None = None, ) -> Params: r"""Compute the gradients of loss to the network parameters and update network parameters. @@ -69,13 +70,12 @@ def step( gradients and update the network parameters without modifying tensors in-place. Args: - loss: (torch.Tensor) - loss that is used to compute the gradients to network parameters. - params: (tree of torch.Tensor) - An tree of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - inplace (optional): (default: :data:`None`) - Whether to update the parameters in-place. If :data:`None`, use the default value - specified in the constructor. + loss (Tensor): The loss that is used to compute the gradients to network parameters. + params (tree of Tensor): A tree of :class:`torch.Tensor`\s. Specifies what tensors + should be optimized. + inplace (bool or None, optional): Whether to update the parameters in-place. If + :data:`None`, use the default value specified in the constructor. 
+ (default: :data:`None`) """ if isinstance(self.optim_state, UninitializedState): self.optim_state = self.impl.init(params) diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py index 36d54857..bd9804b9 100644 --- a/torchopt/optim/meta/adam.py +++ b/torchopt/optim/meta/adam.py @@ -14,7 +14,7 @@ # ============================================================================== """Differentiable Adam optimizer.""" -from typing import Tuple +from __future__ import annotations import torch.nn as nn @@ -39,7 +39,7 @@ def __init__( self, module: nn.Module, lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, *, @@ -51,28 +51,26 @@ def __init__( """Initialize the meta-Adam optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. 
+ module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. 
+ (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index dc869e30..c8a8ef9c 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -14,13 +14,15 @@ # ============================================================================== """Differentiable AdamW optimizer.""" -from typing import Any, Callable, Optional, Tuple, Union +from __future__ import annotations + +from typing import Callable import torch.nn as nn from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import Params, ScalarOrSchedule +from torchopt.typing import OptState, Params, ScalarOrSchedule __all__ = ['MetaAdamW'] @@ -39,12 +41,12 @@ def __init__( self, module: nn.Module, lr: ScalarOrSchedule = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, *, eps_root: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, moment_requires_grad: bool = False, maximize: bool = False, use_accelerated_op: bool = False, @@ -52,37 +54,35 @@ def __init__( """Initialize the meta-AdamW optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-3`) - This is a fixed global scaling factor. - betas: (default: :const:`(0.9, 0.999)`) - Coefficients used for computing running averages of gradient and its square. - eps: (default: :const:`1e-8`) - A small constant applied to denominator outside of the square root (as in the Adam - paper) to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`1e-2`) - Strength of the weight decay regularization. Note that this weight decay is - multiplied with the learning rate. 
This is consistent with other frameworks such as - PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only - multiplied with the "schedule multiplier", but not the base learning rate. - eps_root: (default: :data:`0.0`) - A small constant applied to denominator inside the square root (as in RMSProp), to - avoid dividing by zero when rescaling. This is needed for example when computing - (meta-)gradients through Adam. - mask: (default: :data:`None`) - A tree with same structure as (or a prefix of) the params PyTree, or a Callable that + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Strength of the weight decay regularization. Note that + this weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + the weight decay is only multiplied with the "schedule multiplier", but not the base + learning rate. (default: :const:`1e-2`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): + A tree with same structure as (or a prefix of) the params pytree, or a function that returns such a pytree given the params/updates. 
The leaves should be booleans, :data:`True` for leaves/subtrees you want to apply the weight decay to, and :data:`False` for those you want to skip. Note that the Adam gradient - transformations are applied to all parameters. - moment_requires_grad: (default: :data:`False`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + transformations are applied to all parameters. (default: :data:`None`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator. + (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py index 8db4f0a7..c5c9ad73 100644 --- a/torchopt/optim/meta/base.py +++ b/torchopt/optim/meta/base.py @@ -14,7 +14,9 @@ # ============================================================================== """The base class for differentiable meta-optimizers.""" -from typing import List, Sequence, Tuple +from __future__ import annotations + +from typing import Sequence import torch import torch.nn as nn @@ -33,14 +35,13 @@ class MetaOptimizer: """The base class for high-level differentiable optimizers.""" def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: - """Initialize the meta-optimizer. + r"""Initialize the meta-optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. 
- impl: (GradientTransformation) - A low level optimizer function, it could be a optimizer function provided by - ``alias.py`` or a customized ``chain`` provided by ``combine.py``. + module (nn.Module): A network whose parameters should be optimized. + impl (GradientTransformation): A low level optimizer function, it could be an optimizer + function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed + transformation. Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to :class:`torchopt.MetaSGD`. @@ -49,8 +50,8 @@ def __init__(self, module: nn.Module, impl: GradientTransformation) -> None: raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation') self.impl: GradientTransformation = impl - self.param_containers_groups: List[ModuleTensorContainers] = [] - self.state_groups: List[OptState] = [] + self.param_containers_groups: list[ModuleTensorContainers] = [] + self.state_groups: list[OptState] = [] self.add_param_group(module) @@ -62,8 +63,8 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals gradients and update the network parameters without modifying tensors in-place. Args: - loss: (torch.Tensor) - The loss that is used to compute the gradients to the network parameters. + loss (torch.Tensor): The loss that is used to compute the gradients to the network + parameters. 
""" # Step parameter only for i, (param_container, state) in enumerate( @@ -94,12 +95,12 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals container.update(new_param) def add_param_group(self, module: nn.Module) -> None: - """Add a param group to the optimizer's :attr:`state_groups`.""" + """Add a param group to the optimizer's ``state_groups``.""" params_container = extract_module_containers(module, with_buffers=False)[0] self.param_containers_groups.append(params_container) self.state_groups.append(UninitializedState()) - def state_dict(self) -> Tuple[OptState, ...]: + def state_dict(self) -> tuple[OptState, ...]: """Extract the references of the optimizer states. Note that the states are references, so any in-place operations will change the states diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py index f4dfdae6..3aff20e1 100644 --- a/torchopt/optim/meta/rmsprop.py +++ b/torchopt/optim/meta/rmsprop.py @@ -50,30 +50,26 @@ def __init__( """Initialize the meta-RMSProp optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. 
When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. (default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py index 5f9177e1..476ed9d6 100644 --- a/torchopt/optim/meta/sgd.py +++ b/torchopt/optim/meta/sgd.py @@ -47,23 +47,20 @@ def __init__( """Initialize the meta-SGD optimizer. Args: - module: (nn.Module) - A network whose parameters should be optimized. - lr: This is a fixed global scaling factor. 
- momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :const:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`True`) - If :data:`True` the momentums will be created with flag ``requires_grad=True``, this - flag is often used in Meta-Learning algorithms. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py index 9101984f..5c4e536f 100644 --- a/torchopt/optim/rmsprop.py +++ b/torchopt/optim/rmsprop.py @@ -52,30 +52,27 @@ def __init__( r"""Initialize the RMSProp optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what Tensors should be optimized. 
- lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - alpha: (default: :const:`0.99`) - Smoothing constant, the decay used to track the magnitude of previous gradients. - eps: (default: :const:`1e-8`) - A small numerical constant to avoid dividing by zero when rescaling. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - centered: (default: :data:`False`) - If :data:`True`, use the variance of the past gradients to rescale the latest - gradients. - initial_scale: (default: :data:`0.0`) - Initialization of accumulators tracking the magnitude of previous updates. PyTorch - uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a - paper, verify the value used by the authors. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + alpha (float, optional): Smoothing constant, the decay used to track the magnitude of + previous gradients. (default: :const:`0.99`) + eps (float, optional): A small numerical constant to avoid dividing by zero when + rescaling. (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + centered (bool, optional): If :data:`True`, use the variance of the past gradients to + rescale the latest gradients. 
(default: :data:`False`) + initial_scale (float, optional): Initialization of accumulators tracking the magnitude + of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When + reproducing results from a paper, verify the value used by the authors. + (default: :data:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py index 223e856e..3da9595a 100644 --- a/torchopt/optim/sgd.py +++ b/torchopt/optim/sgd.py @@ -48,20 +48,21 @@ def __init__( r"""Initialize the SGD optimizer. Args: - params: (iterable of torch.Tensor) - An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: This is a fixed global scaling factor. - momentum: (default: :const:`0.0`) - The decay rate used by the momentum term. The momentum is not used when it is set to - :const:`0.0`. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable): This is a fixed global scaling factor or a learning rate + scheduler. + momentum (float, optional): The decay rate used by the momentum term. The momentum is + not used when it is set to :const:`0.0`. (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + dampening (float, optional): Dampening for momentum. 
(default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/pytree.py b/torchopt/pytree.py index 0abcf4fd..d3b2d181 100644 --- a/torchopt/pytree.py +++ b/torchopt/pytree.py @@ -14,9 +14,11 @@ # ============================================================================== """The PyTree utilities.""" +from __future__ import annotations + import functools import operator -from typing import Callable, List, Optional, Tuple +from typing import Callable import optree import optree.typing as typing # pylint: disable=unused-import @@ -47,19 +49,20 @@ def tree_flatten_as_tuple( tree: PyTree[T], - is_leaf: Optional[Callable[[T], bool]] = None, + is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, namespace: str = '', -) -> Tuple[Tuple[T, ...], PyTreeSpec]: +) -> tuple[tuple[T, ...], PyTreeSpec]: """Flatten a pytree to a tuple of leaves and a PyTreeSpec. Args: - tree: The pytree to flatten. - is_leaf: A function that returns :data:`True` if a given node is a leaf. - none_is_leaf: If :data:`True`, None is considered a leaf rather than a internal node with no - children. - namespace: The namespace of custom tree node types. + tree (pytree): The pytree to flatten. + is_leaf (callable or None, optional): An optionally specified function that returns + :data:`True` if a given node is a leaf. (default: :data:`None`) + none_is_leaf (bool, optional): If :data:`True`, :data:`None` is considered a leaf rather + than an internal node with no children. 
(default: :data:`False`) + namespace (str, optional): The namespace of custom tree node types. (default: :const:`''`) Returns: A tuple of (leaves, treespec). @@ -99,7 +102,7 @@ def tree_add(*trees: PyTree[T]) -> PyTree[T]: def tree_add_scalar_mul( - tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None + tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None ) -> TensorTree: """Compute ``tree_x + alpha * tree_y``.""" if alpha is None: @@ -113,7 +116,7 @@ def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]: def tree_sub_scalar_mul( - tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None + tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None ) -> TensorTree: """Compute ``tree_x - alpha * tree_y``.""" if alpha is None: @@ -190,4 +193,4 @@ def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]: __all__.extend(['tree_as_rref', 'tree_to_here']) -del Callable, List, Optional, Tuple, optree, rpc, Scalar, T, RRef +del Callable, optree, rpc, Scalar, T, RRef diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 8a8e51e8..d54dbf17 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -52,18 +52,17 @@ def polynomial_schedule( """Construct a schedule with polynomial transition from init to end value. Args: - init_value: Initial value for the scalar to be annealed. - end_value: End value of the scalar to be annealed. - power: The power of the polynomial used to transition from ``init`` to ``end``. - transition_steps: - Number of steps over which annealing takes place, the scalar starts changing at - ``transition_begin`` steps and completes the transition by - ``transition_begin + transition_steps`` steps. - If ``transition_steps <= 0``, then the entire annealing process is disabled and the - value is held fixed at ``init_value``. - transition_begin: - Must be *positive*. 
After how many steps to start annealing (before this many steps the - scalar value is held fixed at ``init_value``). + init_value (float or Tensor): Initial value for the scalar to be annealed. + end_value (float or Tensor): End value of the scalar to be annealed. + power (float or Tensor): The power of the polynomial used to transition from ``init`` to + ``end``. + transition_steps (int): Number of steps over which annealing takes place, the scalar starts + changing at ``transition_begin`` steps and completes the transition by + ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the + entire annealing process is disabled and the value is held fixed at ``init_value``. + transition_begin (int, optional): Must be *non-negative*. After how many steps to start + annealing (before this many steps the scalar value is held fixed at ``init_value``). + (default: :const:`0`) Returns: schedule: diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 772e6291..14745766 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -32,7 +32,9 @@ # ============================================================================== """Preset transformations for adding weight decay to updates.""" -from typing import Any, Callable, NamedTuple, Optional, Tuple, Union +from __future__ import annotations + +from typing import Any, Callable, NamedTuple from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity @@ -59,7 +61,7 @@ class MaskedNode(NamedTuple): def masked( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: """Mask updates so only some are transformed, the rest are passed through. @@ -75,11 +77,12 @@ def masked( of :data:`True`. Args: - inner: Inner transformation to mask. 
- mask: A tree with same structure as (or a prefix of) the params tree, or a Callable that - returns such a tree given the params/updates. The leaves should be booleans, :data:`True` - for leaves/subtrees you want to apply the transformation to, and :data:`False` for those - you want to skip. The mask must be static for the gradient transformation to be jit-compilable. + inner (GradientTransformation): Inner transformation to mask. + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) Returns: A :class:`GradientTransformation` wrapping ``inner``. @@ -89,14 +92,14 @@ def masked( def _masked_flat( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: return _masked(inner, mask, already_flattened=True) def _masked( inner: GradientTransformation, - mask: Union[Any, Callable[[Params], Any]], + mask: OptState | Callable[[Params], OptState] | None = None, *, already_flattened: bool = False, ) -> GradientTransformation: @@ -117,9 +120,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: mask_tree = mask(updates) if callable(mask) else mask masked_updates = tree_mask(updates, mask_tree) masked_params = None if params is None else tree_mask(params, mask_tree) @@ -145,16 +148,17 @@ def update_fn( def add_decayed_weights( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: 
"""Add parameter scaled by `weight_decay`. Args: - weight_decay: a scalar weight decay rate. - mask: a tree with same structure as (or a prefix of) the params tree, or a Callable that - returns such a pytree given the params/updates. The leaves should be booleans, - :data:`True` for leaves/subtrees you want to apply the transformation to, and - :data:`False` for those you want to skip. + weight_decay (float, optional): A scalar weight decay rate. (default: :const:`0.0`) + mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a + prefix of) the params tree, or a function that returns such a tree given the + params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want + to apply the transformation to, and :data:`False` for those you want to skip. + (default: :data:`None`) Returns: An (init_fn, update_fn) tuple. @@ -168,7 +172,7 @@ def add_decayed_weights( def _add_decayed_weights_flat( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, ) -> GradientTransformation: return _add_decayed_weights( weight_decay=weight_decay, @@ -179,7 +183,7 @@ def _add_decayed_weights_flat( def _add_decayed_weights( weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[Params], Any]]] = None, + mask: OptState | Callable[[Params], OptState] | None = None, *, already_flattened: bool = False, ) -> GradientTransformation: @@ -204,9 +208,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, + params: Params | None = None, inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: assert params is not None, ( 'Parameters are required for weight decay. ' 'Call `update(updates, state, params=params)` instead.' 
diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index 2c0b9d5e..804f8219 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -14,7 +14,7 @@ # ============================================================================== """Preset transformations that replaces updates with non-finite values to the given numbers.""" -from typing import Optional, Tuple +from __future__ import annotations from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation @@ -23,8 +23,8 @@ def nan_to_num( nan: float = 0.0, - posinf: Optional[float] = None, - neginf: Optional[float] = None, + posinf: float | None = None, + neginf: float | None = None, ) -> GradientTransformation: """Replace updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers. @@ -39,9 +39,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g): diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index 4afac163..639c903e 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -31,7 +31,7 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate.""" -from typing import Optional, Tuple +from __future__ import annotations from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation @@ -49,7 +49,7 @@ def scale(step_size: float) -> GradientTransformation: """Scale updates by some fixed scalar ``step_size``. Args: - step_size: A scalar corresponding to a fixed scaling factor for updates. + step_size (float): A scalar corresponding to a fixed scaling factor for updates. Returns: An ``(init_fn, update_fn)`` tuple. 
@@ -80,9 +80,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g): diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index 039d31fb..36f30be9 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -33,7 +33,9 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -88,17 +90,17 @@ def scale_by_adam( [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. 
(default: :data:`False`) Returns: An (init_fn, update_fn) tuple. @@ -169,9 +171,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: mu = update_moment.impl( # type: ignore[attr-defined] updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened ) @@ -218,17 +220,17 @@ def scale_by_accelerated_adam( [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: - b1: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of grads. - b2: (default: :const:`0.999`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - eps_root: (default: :const:`0.0`) - Term added to the denominator inside the square-root to improve + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. 
@@ -285,9 +287,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined] op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace) @@ -303,9 +305,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined] treespec = pytree.tree_structure(updates, none_is_leaf=True) diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 7a685f6b..7a0c8c20 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -31,7 +31,9 @@ # ============================================================================== """Preset transformations for scaling updates by exponential root mean-squared (RMS).""" -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -61,12 +63,11 @@ def scale_by_rms( [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. 
+ (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) Returns: An (init_fn, update_fn) tuple. @@ -121,9 +122,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: nu = update_moment.impl( # type: ignore[attr-defined] updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened ) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index 5556d111..d6e3b0fa 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -31,7 +31,9 @@ # ============================================================================== """Preset transformation for scaling updates by learning rate schedules.""" -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -54,9 +56,8 @@ def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation: """Scale updates using a custom schedule for the ``step_size``. Args: - step_size_fn: - A function that takes an update count as input and proposes the ``step_size`` to - multiply the updates by. + step_size_fn (callable): A function that takes an update count as input and proposes the + ``step_size`` to multiply the updates by. Returns: An ``(init_fn, update_fn)`` tuple. 
@@ -90,9 +91,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: if inplace: def f(g, c): # pylint: disable=invalid-name diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index c15a0d6c..228ed707 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -33,7 +33,9 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -64,12 +66,11 @@ def scale_by_stddev( [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) Args: - alpha: (default: :const:`0.9`) - Decay rate for the exponentially weighted average of squared grads. - eps: (default: :const:`1e-8`) - Term added to the denominator to improve numerical stability. - initial_scale: (default: :const:`0.0`) - Initial value for second moment + alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`) Returns: An (init_fn, update_fn) tuple. 
@@ -125,9 +126,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: mu = update_moment.impl( # type: ignore[attr-defined] updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened ) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 45e043f0..03d2441d 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -33,7 +33,9 @@ # pylint: disable=invalid-name -from typing import NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple import torch @@ -65,14 +67,12 @@ def trace( Both are frequently found in the optimization literature. Args: - momentum: (default: :const:`0.9`) - The decay rate for the trace of past updates. - dampening: (default: :const:`0.0`) - Dampening for momentum. - nesterov: (default: :data:`False`) - Whether to use Nesterov momentum. - moment_requires_grad: (default: :data:`False`) - If :data:`True`, states will be created with flag `requires_grad = True`. + momentum (float, optional): The decay rate for the trace of past updates. + (default: :const:`0.9`) + dampening (float, optional): Dampening for momentum. (default: :const:`0.0`) + nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) Returns: An (init_fn, update_fn) tuple. 
@@ -139,9 +139,9 @@ def update_fn( updates: Updates, state: OptState, *, - params: Optional[Params] = None, # pylint: disable=unused-argument + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> Tuple[Updates, OptState]: + ) -> tuple[Updates, OptState]: nonlocal first_call if nesterov: diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index a9f02295..77ba58ca 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -31,6 +31,8 @@ # ============================================================================== """Utilities for the preset transformations.""" +from __future__ import annotations + from collections import deque from typing import Any, Callable, Sequence diff --git a/torchopt/update.py b/torchopt/update.py index 3fdd38e1..9485896b 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -48,11 +48,11 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> :func:`tree_map` (e.g. if you want to manipulate updates in custom ways before applying them). Args: - params: A tree of parameters. - updates: - A tree of updates, the tree structure and the shape of the leaf nodes must match that - of ``params``. - inplace: If :data:`True`, will update params in a inplace manner. + params (tree of Tensor): A tree of parameters. + updates (tree of Tensor): A tree of updates, the tree structure and the shape of the leaf + nodes must match that of ``params``. + inplace (bool, optional): If :data:`True`, will update params in a inplace manner. + (default: :data:`True`) Returns: Updated parameters, with same structure, shape and type as ``params``. 
diff --git a/torchopt/utils.py b/torchopt/utils.py index 4deaba8b..12adb214 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -14,21 +14,11 @@ # ============================================================================== """Utilities for TorchOpt.""" +from __future__ import annotations + import copy import itertools -from typing import ( - TYPE_CHECKING, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, NamedTuple, Sequence, cast, overload from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -56,32 +46,30 @@ class ModuleState(NamedTuple): """Container for module state.""" - params: Tuple[Dict[str, torch.Tensor], ...] - buffers: Tuple[Dict[str, torch.Tensor], ...] - visual_contents: Optional[Dict] = None + params: tuple[dict[str, torch.Tensor], ...] + buffers: tuple[dict[str, torch.Tensor], ...] + visual_contents: dict | None = None detach_buffers: bool = False CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone'] -def stop_gradient(target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]) -> None: +def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) -> None: """Stop the gradient for the input object. - Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the computation graph. Note that the :func:`stop_gradient` operation is in-place. 
Args: - target: The target that to be detached from the computation graph, it could be a - :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the - :class:`torchopt.MetaOptimizer`, or just a plain list of tensors. - inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this - function will return a detached copy of the target. The in-place operation is fast and - memory efficient but may raise backpropagation error. + target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The target that to be + detached from the computation graph, it could be a :class:`nn.Module`, + :class:`torchopt.MetaOptimizer`, state of the :class:`torchopt.MetaOptimizer`, or just + a plain list of tensors. """ # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer @@ -108,7 +96,7 @@ def extract_state_dict( target: nn.Module, *, by: CopyMode = 'reference', - device: Optional[Device] = None, + device: Device | None = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', @@ -118,57 +106,62 @@ def extract_state_dict( @overload def extract_state_dict( - target: 'MetaOptimizer', + target: MetaOptimizer, *, by: CopyMode = 'reference', - device: Optional[Device] = None, + device: Device | None = None, with_buffers: bool = True, enable_visual: bool = False, visual_prefix: str = '', -) -> Tuple[OptState, ...]: # pragma: no cover +) -> tuple[OptState, ...]: # pragma: no cover ... # pylint: disable-next=too-many-branches,too-many-locals def extract_state_dict( - target: Union[nn.Module, 'MetaOptimizer'], + target: nn.Module | MetaOptimizer, *, by: CopyMode = 'reference', - device: Optional[Device] = None, + device: Device | None = None, with_buffers: bool = True, detach_buffers: bool = False, enable_visual: bool = False, visual_prefix: str = '', -) -> Union[ModuleState, Tuple[OptState, ...]]: +) -> ModuleState | tuple[OptState, ...]: """Extract target state. 
- Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the + Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the backpropagated gradient will flow over the tensor and continue flow to the tensors that is - connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the + connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the computation graph. Note that the extracted state is a reference, which means any in-place operator will affect the target that the state is extracted from. Args: - target: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`. - by: The extract policy of tensors in the target. + target (nn.Module or MetaOptimizer): It could be a :class:`nn.Module` or + :class:`torchopt.MetaOptimizer`. + by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`) - :const:`'reference'`: The extracted tensors will be references to the original tensors. - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This - makes the copied tensors have :attr:`grad_fn` to be a ``<CloneBackward>`` function - points to the original tensors. + makes the copied tensors have ``grad_fn`` to be a ``<CloneBackward>`` function points + to the original tensors. - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will detach from the original computation graph. - device: If specified, move the extracted state to the specified device. - with_buffers: Extract buffer together with parameters, this argument is only used if the - input target is :class:`nn.Module`. - detach_buffers: Whether to detach the reference to the buffers, this argument is only used - if the input target is :class:`nn.Module` and ``by='reference'``. - enable_visual: Add additional annotations, which could be used in computation graph - visualization.
Currently, this flag only has effect on :class:`nn.Module` but we will - support :class:`torchopt.MetaOptimizer` later. - visual_prefix: Prefix for the visualization annotations. + device (Device or None, optional): If specified, move the extracted state to the specified + device. (default: :const:`None`) + with_buffers (bool, optional): Extract buffer together with parameters, this argument is + only used if the input target is :class:`nn.Module`. (default: :const:`True`) + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + enable_visual (bool, optional): Add additional annotations, which could be used in + computation graph visualization. Currently, this flag only has effect on + :class:`nn.Module` but we will support :class:`torchopt.MetaOptimizer` later. + (default: :const:`False`) + visual_prefix (str, optional): Prefix for the visualization annotations. + (default: :const:`''`) Returns: State extracted of the input object. 
@@ -228,9 +221,9 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: else: visual_contents = None - params: List[Dict[str, torch.Tensor]] = [] - buffers: List[Dict[str, torch.Tensor]] = [] - memo: Set[nn.Module] = set() + params: list[dict[str, torch.Tensor]] = [] + buffers: list[dict[str, torch.Tensor]] = [] + memo: set[nn.Module] = set() def update_params(container): if len(container) > 0: @@ -287,12 +280,12 @@ def get_variable(t): def extract_module_containers( module: nn.Module, with_buffers: bool = True -) -> Tuple[ModuleTensorContainers, ModuleTensorContainers]: +) -> tuple[ModuleTensorContainers, ModuleTensorContainers]: """Extract the references to the containers of parameters and buffers from a module.""" if isinstance(module, nn.Module): - params: List[TensorContainer] = [] - buffers: List[TensorContainer] = [] - memo: Set[nn.Module] = set() + params: list[TensorContainer] = [] + buffers: list[TensorContainer] = [] + memo: set[nn.Module] = set() def update_container(container, items): if len(items) > 0: @@ -316,8 +309,8 @@ def update_container(container, items): def recover_state_dict( - target: Union[nn.Module, 'MetaOptimizer'], - state: Union[ModuleState, Sequence[OptState]], + target: nn.Module | MetaOptimizer, + state: ModuleState | Sequence[OptState], ) -> None: """Recover state. @@ -327,8 +320,8 @@ def recover_state_dict( modified. Args: - target: Target that need to recover. - state: The recovering state. + target (nn.Module or MetaOptimizer): Target that need to recover. + state (ModuleState or sequence of tree of Tensor): The recovering state. 
""" # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer @@ -344,10 +337,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad) return t.clone().detach_().requires_grad_(t.requires_grad) - buffers = cast( - Tuple[Dict[str, torch.Tensor], ...], - pytree.tree_map(clone_detach_, buffers), # type: ignore[arg-type] - ) + buffers = pytree.tree_map(clone_detach_, buffers) # type: ignore[assignment,arg-type] for tgt, src in itertools.chain( zip(params_containers, params), @@ -367,19 +357,19 @@ def module_clone( *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, + device: Device | None = None, ) -> nn.Module: # pragma: no cover ... @overload def module_clone( - target: 'MetaOptimizer', + target: MetaOptimizer, *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, -) -> 'MetaOptimizer': # pragma: no cover + device: Device | None = None, +) -> MetaOptimizer: # pragma: no cover ... @@ -389,34 +379,36 @@ def module_clone( *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, + device: Device | None = None, ) -> TensorTree: # pragma: no cover ... # pylint: disable-next=too-many-locals def module_clone( - target: Union[nn.Module, 'MetaOptimizer', TensorTree], + target: nn.Module | MetaOptimizer | TensorTree, *, by: CopyMode = 'reference', detach_buffers: bool = False, - device: Optional[Device] = None, -) -> Union[nn.Module, 'MetaOptimizer', TensorTree]: + device: Device | None = None, +) -> nn.Module | MetaOptimizer | TensorTree: """Clone a module. Args: - target: The target to be cloned. - by: The extract policy of tensors in the target. + target (nn.Module, MetaOptimizer, or tree of Tensor): The target to be cloned. + by (str, optional): The extract policy of tensors in the target. 
(default: :const:`'reference'`) - :const:`'reference'`: The extracted tensors will be references to the original tensors. - :const:`'copy'`: The extracted tensors will be clones of the original tensors. This - makes the copied tensors have :attr:`grad_fn` to be a ``<CloneBackward>`` function - points to the original tensors. + makes the copied tensors have ``grad_fn`` to be a ``<CloneBackward>`` function points + to the original tensors. - :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original tensors. The deep-copied tensors will detach from the original computation graph. - detach_buffers: Whether to detach the reference to the buffers, this argument is only used - if the input target is :class:`nn.Module` and ``by='reference'``. - device: If specified, move the cloned module to the specified device. + detach_buffers (bool, optional): Whether to detach the reference to the buffers, this + argument is only used if the input target is :class:`nn.Module` and ``by='reference'``. + (default: :const:`False`) + device (Device or None, optional): If specified, move the cloned module to the specified + device. (default: :const:`None`) Returns: The cloned module. @@ -499,7 +491,7 @@ def module_detach_(target: nn.Module) -> nn.Module: # pragma: no cover @overload -def module_detach_(target: 'MetaOptimizer') -> 'MetaOptimizer': # pragma: no cover +def module_detach_(target: MetaOptimizer) -> MetaOptimizer: # pragma: no cover ... @@ -509,12 +501,13 @@ def module_detach_(target: TensorTree) -> TensorTree: # pragma: no cover def module_detach_( - target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree] -) -> Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]: + target: ModuleState | nn.Module | MetaOptimizer | TensorTree, +) -> ModuleState | nn.Module | MetaOptimizer | TensorTree: """Detach a module from the computation graph. Args: - target: The target to be detached.
+ target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The + target to be detached. Returns: The detached module. diff --git a/torchopt/visual.py b/torchopt/visual.py index e8145240..7afe65a4 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -17,8 +17,10 @@ # ============================================================================== """Computation graph visualization.""" +from __future__ import annotations + from collections import namedtuple -from typing import Generator, Iterable, Mapping, Optional, Union, cast +from typing import Generator, Iterable, Mapping, cast import torch from graphviz import Digraph @@ -71,14 +73,13 @@ def truncate(s): # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals def make_dot( var: TensorOrTensors, - params: Optional[ - Union[ - Mapping[str, torch.Tensor], - ModuleState, - Generator, - Iterable[Union[Mapping[str, torch.Tensor], ModuleState, Generator]], - ] - ] = None, + params: ( + Mapping[str, torch.Tensor] + | ModuleState + | Generator + | Iterable[Mapping[str, torch.Tensor] | ModuleState | Generator] + | None + ) = None, show_attrs: bool = False, show_saved: bool = False, max_attr_chars: int = 50, @@ -89,7 +90,7 @@ def make_dot( and is either blue, orange, or green: - **Blue** - Reachable leaf tensors that requires grad (tensors whose :attr:`grad` fields will be + Reachable leaf tensors that requires grad (tensors whose ``grad`` fields will be populated during :meth:`backward`). - **Orange** Saved tensors of custom autograd functions as well as those saved by built-in backward @@ -100,16 +101,16 @@ def make_dot( If any output is a view, we represent its base tensor with a dark green node. Args: - var: Output tensor. - params: ([dict of (name, tensor) or state_dict]) - Parameters to add names to node that requires grad. 
- show_attrs: Whether to display non-tensor attributes of backward nodes - (Requires PyTorch version >= 1.9) - show_saved: Whether to display saved tensor nodes that are not by custom autograd - functions. Saved tensor nodes for custom functions, if present, are always displayed. - (Requires PyTorch version >= 1.9) - max_attr_chars: If ``show_attrs`` is :data:`True`, sets max number of characters to display - for any given attribute. + var (Tensor or sequence of Tensor): Output tensor. + params (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional): + Parameters to add names to node that requires grad. (default: :data:`None`) + show_attrs (bool, optional): Whether to display non-tensor attributes of backward nodes. + (default: :data:`False`) + show_saved (bool, optional): Whether to display saved tensor nodes that are not by custom + autograd functions. Saved tensor nodes for custom functions, if present, are always + displayed. (default: :data:`False`) + max_attr_chars (int, optional): If ``show_attrs`` is :data:`True`, sets max number of + characters to display for any given attribute. (default: :const:`50`) """ param_map = {}