diff --git a/CHANGELOG.md b/CHANGELOG.md
index 74d23144..3a3ad174 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
+- Use postponed evaluation of annotations and update doctring style by [@XuehaiPan](https://github.com/XuehaiPan) in [#135](https://github.com/metaopt/torchopt/pull/135).
- Rewrite setup CUDA Toolkit logic by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/torchopt/pull/133).
### Fixed
diff --git a/README.md b/README.md
index c1fb97ba..321f39e3 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,6 @@
![CodeCov](https://img.shields.io/codecov/c/gh/metaopt/torchopt)
![Documentation Status](https://img.shields.io/readthedocs/torchopt?logo=readthedocs)
![Downloads](https://static.pepy.tech/personalized-badge/torchopt?period=total&left_color=grey&right_color=blue&left_text=downloads)
- ![GitHub Repo Stars](https://img.shields.io/github/stars/metaopt/torchopt?color=brightgreen&logo=github)
![License](https://img.shields.io/github/license/metaopt/torchopt?label=license&logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCIgd2lkdGg9IjI0IiBoZWlnaHQ9IjI0IiBmaWxsPSIjZmZmZmZmIj48cGF0aCBmaWxsLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xMi43NSAyLjc1YS43NS43NSAwIDAwLTEuNSAwVjQuNUg5LjI3NmExLjc1IDEuNzUgMCAwMC0uOTg1LjMwM0w2LjU5NiA1Ljk1N0EuMjUuMjUgMCAwMTYuNDU1IDZIMi4zNTNhLjc1Ljc1IDAgMTAwIDEuNUgzLjkzTC41NjMgMTUuMThhLjc2Mi43NjIgMCAwMC4yMS44OGMuMDguMDY0LjE2MS4xMjUuMzA5LjIyMS4xODYuMTIxLjQ1Mi4yNzguNzkyLjQzMy42OC4zMTEgMS42NjIuNjIgMi44NzYuNjJhNi45MTkgNi45MTkgMCAwMDIuODc2LS42MmMuMzQtLjE1NS42MDYtLjMxMi43OTItLjQzMy4xNS0uMDk3LjIzLS4xNTguMzEtLjIyM2EuNzUuNzUgMCAwMC4yMDktLjg3OEw1LjU2OSA3LjVoLjg4NmMuMzUxIDAgLjY5NC0uMTA2Ljk4NC0uMzAzbDEuNjk2LTEuMTU0QS4yNS4yNSAwIDAxOS4yNzUgNmgxLjk3NXYxNC41SDYuNzYzYS43NS43NSAwIDAwMCAxLjVoMTAuNDc0YS43NS43NSAwIDAwMC0xLjVIMTIuNzVWNmgxLjk3NGMuMDUgMCAuMS4wMTUuMTQuMDQzbDEuNjk3IDEuMTU0Yy4yOS4xOTcuNjMzLjMwMy45ODQuMzAzaC44ODZsLTMuMzY4IDcuNjhhLjc1Ljc1IDAgMDAuMjMuODk2Yy4wMTIuMDA5IDAgMCAuMDAyIDBhMy4xNTQgMy4xNTQgMCAwMC4zMS4yMDZjLjE4NS4xMTIuNDUuMjU2Ljc5LjRhNy4zNDMgNy4zNDMgMCAwMDIuODU1LjU2OCA3LjM0MyA3LjM0MyAwIDAwMi44NTYtLjU2OWMuMzM4LS4xNDMuNjA0LS4yODcuNzktLjM5OWEzLjUgMy41IDAgMDAuMzEtLjIwNi43NS43NSAwIDAwLjIzLS44OTZMMjAuMDcgNy41aDEuNTc4YS43NS43NSAwIDAwMC0xLjVoLTQuMTAyYS4yNS4yNSAwIDAxLS4xNC0uMDQzbC0xLjY5Ny0xLjE1NGExLjc1IDEuNzUgMCAwMC0uOTg0LS4zMDNIMTIuNzVWMi43NXpNMi4xOTMgMTUuMTk4YTUuNDE4IDUuNDE4IDAgMDAyLjU1Ny42MzUgNS40MTggNS40MTggMCAwMDIuNTU3LS42MzVMNC43NSA5LjM2OGwtMi41NTcgNS44M3ptMTQuNTEtLjAyNGMuMDgyLjA0LjE3NC4wODMuMjc1LjEyNi41My4yMjMgMS4zMDUuNDUgMi4yNzIuNDVhNS44NDYgNS44NDYgMCAwMDIuNTQ3LS41NzZMMTkuMjUgOS4zNjdsLTIuNTQ3IDUuODA3eiI+PC9wYXRoPjwvc3ZnPgo=)
diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst
index c7e04e95..b2866407 100644
--- a/docs/source/api/api.rst
+++ b/docs/source/api/api.rst
@@ -285,6 +285,115 @@ Chain
.. autofunction:: chain
+Distributed Utilities
+=====================
+
+.. currentmodule:: torchopt.distributed
+
+Initialization and Synchronization
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+
+ auto_init_rpc
+ barrier
+
+.. autofunction:: auto_init_rpc
+.. autofunction:: barrier
+
+Process group information
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+
+ get_world_info
+ get_world_rank
+ get_rank
+ get_world_size
+ get_local_rank
+ get_local_world_size
+ get_worker_id
+
+.. autofunction:: get_world_info
+.. autofunction:: get_world_rank
+.. autofunction:: get_rank
+.. autofunction:: get_world_size
+.. autofunction:: get_local_rank
+.. autofunction:: get_local_world_size
+.. autofunction:: get_worker_id
+
+Worker selection
+~~~~~~~~~~~~~~~~
+
+.. autosummary::
+
+ on_rank
+ not_on_rank
+ rank_zero_only
+ rank_non_zero_only
+
+.. autofunction:: on_rank
+.. autofunction:: not_on_rank
+.. autofunction:: rank_zero_only
+.. autofunction:: rank_non_zero_only
+
+Remote Procedure Call (RPC)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+
+ remote_async_call
+ remote_sync_call
+
+.. autofunction:: remote_async_call
+.. autofunction:: remote_sync_call
+
+Predefined partitioners and reducers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+
+ dim_partitioner
+ batch_partitioner
+ mean_reducer
+ sum_reducer
+
+.. autofunction:: dim_partitioner
+.. autofunction:: batch_partitioner
+.. autofunction:: mean_reducer
+.. autofunction:: sum_reducer
+
+Function parallelization wrappers
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+
+ parallelize
+ parallelize_async
+ parallelize_sync
+
+.. autofunction:: parallelize
+.. autofunction:: parallelize_async
+.. autofunction:: parallelize_sync
+
+Distributed Autograd
+~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torchopt.distributed.autograd
+
+.. autosummary::
+
+ context
+ get_gradients
+ backward
+ grad
+
+.. autofunction:: context
+.. autofunction:: get_gradients
+.. autofunction:: backward
+.. autofunction:: grad
+
+
General Utilities
=================
diff --git a/docs/source/distributed/distributed.rst b/docs/source/distributed/distributed.rst
index f85eec3f..b6f00951 100644
--- a/docs/source/distributed/distributed.rst
+++ b/docs/source/distributed/distributed.rst
@@ -142,7 +142,6 @@ Initialization and Synchronization
.. autosummary::
-
torchopt.distributed.auto_init_rpc
torchopt.distributed.barrier
@@ -197,7 +196,6 @@ Process group information
.. autosummary::
-
torchopt.distributed.get_world_info
torchopt.distributed.get_world_rank
torchopt.distributed.get_rank
@@ -228,7 +226,6 @@ Worker selection
.. autosummary::
-
torchopt.distributed.on_rank
torchopt.distributed.not_on_rank
torchopt.distributed.rank_zero_only
@@ -275,7 +272,6 @@ Remote Procedure Call (RPC)
.. autosummary::
-
torchopt.distributed.remote_async_call
torchopt.distributed.remote_sync_call
@@ -354,7 +350,6 @@ Predefined partitioners and reducers
.. autosummary::
-
torchopt.distributed.dim_partitioner
torchopt.distributed.batch_partitioner
torchopt.distributed.mean_reducer
@@ -439,7 +434,6 @@ Function parallelization wrappers
.. autosummary::
-
torchopt.distributed.parallelize
torchopt.distributed.parallelize_async
torchopt.distributed.parallelize_sync
@@ -490,7 +484,6 @@ Distributed Autograd
.. autosummary::
-
torchopt.distributed.autograd.context
torchopt.distributed.autograd.get_gradients
torchopt.distributed.autograd.backward
diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt
index 8f9d6895..aac17046 100644
--- a/docs/source/spelling_wordlist.txt
+++ b/docs/source/spelling_wordlist.txt
@@ -171,3 +171,4 @@ issubclass
abc
ABCMeta
subclasscheck
+ctx
diff --git a/tests/helpers.py b/tests/helpers.py
index 4bba706e..23e178f0 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================
+from __future__ import annotations
+
import copy
import itertools
import os
import random
-from typing import Iterable, Optional, Tuple, Union
+from typing import Iterable
import numpy as np
import pytest
@@ -137,7 +139,7 @@ def get_model():
@torch.no_grad()
def get_models(
device: torch.types.Device = None, dtype: torch.dtype = torch.float32
-) -> Tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
+) -> tuple[nn.Module, nn.Module, nn.Module, data.DataLoader]:
seed_everything(seed=42)
model_base = get_model().to(dtype=dtype)
@@ -166,12 +168,12 @@ def get_models(
@torch.no_grad()
def assert_model_all_close(
- model: Union[nn.Module, Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]],
+ model: nn.Module | tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]],
model_ref: nn.Module,
model_base: nn.Module,
dtype: torch.dtype = torch.float32,
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
+ rtol: float | None = None,
+ atol: float | None = None,
equal_nan: bool = False,
) -> None:
if isinstance(model, tuple):
@@ -194,8 +196,8 @@ def assert_all_close(
actual: torch.Tensor,
expected: torch.Tensor,
base: torch.Tensor = None,
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
+ rtol: float | None = None,
+ atol: float | None = None,
equal_nan: bool = False,
) -> None:
if base is not None:
@@ -223,9 +225,9 @@ def assert_all_close(
def assert_pytree_all_close(
actual: TensorTree,
expected: TensorTree,
- base: Optional[TensorTree] = None,
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
+ base: TensorTree | None = None,
+ rtol: float | None = None,
+ atol: float | None = None,
equal_nan: bool = False,
) -> None:
actual_leaves, actual_treespec = pytree.tree_flatten(actual)
diff --git a/tests/test_alias.py b/tests/test_alias.py
index c613d7d5..b609cf58 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================
-from typing import Callable, Tuple
+from __future__ import annotations
+
+from typing import Callable
import functorch
import pytest
@@ -107,7 +109,7 @@ def test_sgd(
def test_adam(
dtype: torch.dtype,
lr: float,
- betas: Tuple[float, float],
+ betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
@@ -177,7 +179,7 @@ def test_maml_adam(
outer_lr: float,
inner_lr: float,
inner_update: int,
- betas: Tuple[float, float],
+ betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
@@ -263,7 +265,7 @@ def maml_inner_solver_torchopt(params, data, use_accelerated_op):
def test_adamw(
dtype: torch.dtype,
lr: float,
- betas: Tuple[float, float],
+ betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
@@ -333,8 +335,8 @@ def test_adamw(
def test_adam_accelerated_cuda(
dtype: torch.dtype,
lr: float,
- optimizers: Tuple[Callable, torch.optim.Optimizer],
- betas: Tuple[float, float],
+ optimizers: tuple[Callable, torch.optim.Optimizer],
+ betas: tuple[float, float],
eps: float,
inplace: bool,
weight_decay: float,
diff --git a/tests/test_implicit.py b/tests/test_implicit.py
index ce0ee23b..9e3722d3 100644
--- a/tests/test_implicit.py
+++ b/tests/test_implicit.py
@@ -13,10 +13,11 @@
# limitations under the License.
# ==============================================================================
+from __future__ import annotations
+
import copy
from collections import OrderedDict
from types import FunctionType
-from typing import Tuple
import functorch
import jax
@@ -55,7 +56,7 @@ def forward(self, x):
return self.fc(x)
-def get_model_jax(dtype: np.dtype = np.float32) -> Tuple[FunctionType, OrderedDict]:
+def get_model_jax(dtype: np.dtype = np.float32) -> tuple[FunctionType, OrderedDict]:
helpers.seed_everything(seed=42)
def func(params, x):
@@ -73,7 +74,7 @@ def func(params, x):
@torch.no_grad()
def get_model_torch(
device: torch.types.Device = None, dtype: torch.dtype = torch.float32
-) -> Tuple[nn.Module, data.DataLoader]:
+) -> tuple[nn.Module, data.DataLoader]:
helpers.seed_everything(seed=42)
model = FcNet(MODEL_NUM_INPUTS, MODEL_NUM_CLASSES).to(dtype=dtype)
diff --git a/tests/test_meta_optim.py b/tests/test_meta_optim.py
index 2c0966cc..61f8a7ad 100644
--- a/tests/test_meta_optim.py
+++ b/tests/test_meta_optim.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
-from typing import Tuple
+from __future__ import annotations
import torch
import torch.nn.functional as F
@@ -40,7 +40,7 @@ def test_maml_meta_adam(
outer_lr: float,
inner_lr: float,
inner_update: int,
- betas: Tuple[float, float],
+ betas: tuple[float, float],
eps: float,
eps_root: float,
weight_decay: float,
diff --git a/tests/test_optim.py b/tests/test_optim.py
index c43bc438..b2be7500 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================
-from typing import Callable, Tuple
+from __future__ import annotations
+
+from typing import Callable
import functorch
import pytest
@@ -96,7 +98,7 @@ def test_SGD(
def test_Adam(
dtype: torch.dtype,
lr: float,
- betas: Tuple[float, float],
+ betas: tuple[float, float],
eps: float,
weight_decay: float,
maximize: bool,
@@ -154,7 +156,7 @@ def test_Adam(
def test_AdamW(
dtype: torch.dtype,
lr: float,
- betas: Tuple[float, float],
+ betas: tuple[float, float],
eps: float,
weight_decay: float,
maximize: bool,
@@ -216,8 +218,8 @@ def test_AdamW(
def test_Adam_accelerated_cuda(
dtype: torch.dtype,
lr: float,
- optimizers: Tuple[torchopt.Optimizer, torch.optim.Optimizer],
- betas: Tuple[float, float],
+ optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer],
+ betas: tuple[float, float],
eps: float,
weight_decay: float,
maximize: bool,
@@ -339,7 +341,7 @@ def test_RMSProp(
def test_FuncOptimizer(
dtype: torch.dtype,
lr: float,
- optimizers: Tuple[Callable, torch.optim.Optimizer],
+ optimizers: tuple[Callable, torch.optim.Optimizer],
inplace: bool,
weight_decay: float,
) -> None:
diff --git a/tests/test_schedule.py b/tests/test_schedule.py
index 9590acf8..ae714875 100644
--- a/tests/test_schedule.py
+++ b/tests/test_schedule.py
@@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================
-from typing import Callable, Tuple
+from __future__ import annotations
+
+from typing import Callable
import functorch
import numpy as np
@@ -62,7 +64,7 @@ def test_lr_linear_schedule(
dtype: torch.dtype,
lr: float,
total_iters: int,
- optimizers: Tuple[Callable, torch.optim.Optimizer],
+ optimizers: tuple[Callable, torch.optim.Optimizer],
inplace: bool,
weight_decay: float,
use_chain_flat: bool,
diff --git a/tests/test_transform.py b/tests/test_transform.py
index 4dfd034d..9598386d 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -13,13 +13,8 @@
# limitations under the License.
# ==============================================================================
-from typing import Tuple
-
-import functorch
import torch
-import torch.nn.functional as F
-import helpers
import torchopt
diff --git a/torchopt/_C/adam_op.pyi b/torchopt/_C/adam_op.pyi
index bc3e8ebc..7ecfe7c2 100644
--- a/torchopt/_C/adam_op.pyi
+++ b/torchopt/_C/adam_op.pyi
@@ -15,7 +15,7 @@
# pylint: disable=all
-from typing import Tuple
+from __future__ import annotations
import torch
@@ -28,7 +28,7 @@ def forward_(
eps: float,
eps_root: float,
count: int,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...
def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ...
def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ...
def forward_updates(
@@ -42,10 +42,10 @@ def forward_updates(
) -> torch.Tensor: ...
def backward_mu(
dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float
-) -> Tuple[torch.Tensor, torch.Tensor]: ...
+) -> tuple[torch.Tensor, torch.Tensor]: ...
def backward_nu(
dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float
-) -> Tuple[torch.Tensor, torch.Tensor]: ...
+) -> tuple[torch.Tensor, torch.Tensor]: ...
def backward_updates(
dupdates: torch.Tensor,
updates: torch.Tensor,
@@ -55,4 +55,4 @@ def backward_updates(
b2: float,
eps_root: float,
count: int,
-) -> Tuple[torch.Tensor, torch.Tensor]: ...
+) -> tuple[torch.Tensor, torch.Tensor]: ...
diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py
index 003a8a9f..ede60009 100644
--- a/torchopt/accelerated_op/__init__.py
+++ b/torchopt/accelerated_op/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""The accelerated Ops."""
-from typing import Iterable, Optional, Union
+from __future__ import annotations
+
+from typing import Iterable
import torch
@@ -22,7 +24,7 @@
from torchopt.typing import Device
-def is_available(devices: Optional[Union[Device, Iterable[Device]]] = None) -> bool:
+def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
"""Check the availability of accelerated optimizer."""
op = AdamOp()
diff --git a/torchopt/accelerated_op/_src/adam_op.py b/torchopt/accelerated_op/_src/adam_op.py
index 9f801b8d..ab5ea195 100644
--- a/torchopt/accelerated_op/_src/adam_op.py
+++ b/torchopt/accelerated_op/_src/adam_op.py
@@ -16,7 +16,7 @@
# pylint: disable=invalid-name,too-many-arguments,unused-argument
-from typing import Tuple
+from __future__ import annotations
import torch
@@ -30,7 +30,7 @@ def forward_(
eps: float,
eps_root: float,
count: int,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adam forward inplace."""
mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1)
nu = nu.mul_(b2).addcmul_(updates, updates, value=1.0 - b2)
@@ -80,7 +80,7 @@ def backward_mu(
updates: torch.Tensor,
mu: torch.Tensor,
b1: float,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Adam backward mu."""
dupdates = dmu.mul(1.0 - b1)
dmu = dmu.mul(b1)
@@ -92,7 +92,7 @@ def backward_nu(
updates: torch.Tensor,
nu: torch.Tensor,
b2: float,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Adam backward nu."""
dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2))
dnu = dnu.mul(b2)
@@ -108,7 +108,7 @@ def backward_updates(
b2: float,
eps_root: float,
count: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Adam backward updates."""
one_minus_pow_b1 = 1.0 - pow(b1, count)
inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count) + eps_root)
diff --git a/torchopt/accelerated_op/adam_op.py b/torchopt/accelerated_op/adam_op.py
index 6b93bf18..232513d6 100644
--- a/torchopt/accelerated_op/adam_op.py
+++ b/torchopt/accelerated_op/adam_op.py
@@ -16,8 +16,10 @@
# pylint: disable=c-extension-no-member,invalid-name
+from __future__ import annotations
+
import contextlib
-from typing import Any, Optional, Tuple
+from typing import Any
import torch
@@ -132,9 +134,9 @@ def __call__(
self,
mu: torch.Tensor,
nu: torch.Tensor,
- updates: Optional[torch.Tensor],
+ updates: torch.Tensor | None,
count: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply the Adam operator."""
if updates is None:
return mu, nu, None
diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py
index a7f90a79..08654577 100644
--- a/torchopt/alias/adam.py
+++ b/torchopt/alias/adam.py
@@ -31,7 +31,7 @@
# ==============================================================================
"""Preset :class:`GradientTransformation` for the Adam optimizer."""
-from typing import Tuple
+from __future__ import annotations
from torchopt.alias.utils import (
_get_use_chain_flat,
@@ -49,7 +49,7 @@
# pylint: disable-next=too-many-arguments
def adam(
lr: ScalarOrSchedule = 1e-3,
- betas: Tuple[float, float] = (0.9, 0.999),
+ betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
*,
@@ -68,26 +68,25 @@ def adam(
- Kingma et al, 2014: https://arxiv.org/abs/1412.6980
Args:
- lr: (default: :const:`1e-3`)
- This is a fixed global scaling factor.
- betas: (default: :const:`(0.9, 0.999)`)
- Coefficients used for computing running averages of gradient and its square.
- eps: (default: :const:`1e-8`)
- A small constant applied to denominator outside of the square root (as in the Adam
- paper) to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- eps_root: (default: :data:`0.0`)
- A small constant applied to denominator inside the square root (as in RMSProp), to avoid
- dividing by zero when rescaling. This is needed for example when computing
- (meta-)gradients through Adam.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True` the momentums will be created with flag ``requires_grad=True``, this
- flag is often used in Meta-Learning algorithms.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
- use_accelerated_op: (default: :data:`False`)
- If :data:`True` use our implemented fused operator.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
+ scheduler. (default: :const:`1e-3`)
+ betas (tuple of float, optional): Coefficients used for computing running averages of
+ gradient and its square. (default: :const:`(0.9, 0.999)`)
+ eps (float, optional): A small constant applied to denominator outside of the square root
+ (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ eps_root (float, optional): A small constant applied to denominator inside the square root
+ (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example
+ when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with
+ flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms.
+ (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of minimizing.
+ (default: :data:`False`)
+ use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator.
+ (default: :data:`False`)
Returns:
The corresponding :class:`GradientTransformation` instance.
diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py
index 9aecc8ee..21ef84ef 100644
--- a/torchopt/alias/adamw.py
+++ b/torchopt/alias/adamw.py
@@ -31,7 +31,9 @@
# ==============================================================================
"""Preset :class:`GradientTransformation` for the AdamW optimizer."""
-from typing import Any, Callable, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import Callable
from torchopt.alias.utils import (
_get_use_chain_flat,
@@ -40,7 +42,7 @@
)
from torchopt.combine import chain
from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam
-from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule
+from torchopt.typing import GradientTransformation, OptState, Params, ScalarOrSchedule
__all__ = ['adamw']
@@ -49,12 +51,12 @@
# pylint: disable-next=too-many-arguments,too-many-locals
def adamw(
lr: ScalarOrSchedule = 1e-3,
- betas: Tuple[float, float] = (0.9, 0.999),
+ betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
*,
eps_root: float = 0.0,
- mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
+ mask: OptState | Callable[[Params], OptState] | None = None,
moment_requires_grad: bool = False,
maximize: bool = False,
use_accelerated_op: bool = False,
@@ -70,35 +72,34 @@ def adamw(
- Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101
Args:
- lr: (default: :const:`1e-3`)
- This is a fixed global scaling factor.
- betas: (default: :const:`(0.9, 0.999)`)
- Coefficients used for computing running averages of gradient and its square.
- eps: (default: :const:`1e-8`)
- A small constant applied to denominator outside of the square root (as in the Adam
- paper) to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`1e-2`)
- Strength of the weight decay regularization. Note that this weight decay is multiplied
- with the learning rate. This is consistent with other frameworks such as PyTorch, but
- different from (Loshchilov et al, 2019) where the weight decay is only multiplied with
- the "schedule multiplier", but not the base learning rate.
- eps_root: (default: :data:`0.0`)
- A small constant applied to denominator inside the square root (as in RMSProp), to avoid
- dividing by zero when rescaling. This is needed for example when computing
- (meta-)gradients through Adam.
- mask: (default: :data:`None`)
- A tree with same structure as (or a prefix of) the params PyTree, or a Callable that
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
+ scheduler. (default: :const:`1e-3`)
+ betas (tuple of float, optional): Coefficients used for computing running averages of
+ gradient and its square. (default: :const:`(0.9, 0.999)`)
+ eps (float, optional): A small constant applied to denominator outside of the square root
+ (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Strength of the weight decay regularization. Note that this
+ weight decay is multiplied with the learning rate. This is consistent with other
+ frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight
+ decay is only multiplied with the "schedule multiplier", but not the base learning rate.
+ (default: :const:`1e-2`)
+ eps_root (float, optional): A small constant applied to denominator inside the square root
+ (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example
+ when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+ mask (tree of Tensor, callable, or None, optional):
+ A tree with same structure as (or a prefix of) the params pytree, or a function that
returns such a pytree given the params/updates. The leaves should be booleans,
:data:`True` for leaves/subtrees you want to apply the weight decay to, and
- :data:`False` for those you want to skip. Note that the Adam gradient
- transformations are applied to all parameters.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True` the momentums will be created with flag ``requires_grad=True``, this
- flag is often used in Meta-Learning algorithms.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
- use_accelerated_op: (default: :data:`False`)
- If :data:`True` use our implemented fused operator.
+ :data:`False` for those you want to skip. Note that the Adam gradient transformations
+ are applied to all parameters. (default: :data:`None`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with
+ flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms.
+ (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator.
+ (default: :data:`False`)
Returns:
The corresponding :class:`GradientTransformation` instance.
diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py
index 18a5c5e8..f0eb92cd 100644
--- a/torchopt/alias/rmsprop.py
+++ b/torchopt/alias/rmsprop.py
@@ -69,28 +69,25 @@ def rmsprop(
- Graves, 2013: https://arxiv.org/abs/1308.0850
Args:
- lr: (default: :const:`1e-2`)
- This is a fixed global scaling factor.
- alpha: (default: :const:`0.99`)
- Smoothing constant, the decay used to track the magnitude of previous gradients.
- eps: (default: :const:`1e-8`)
- A small numerical constant to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- momentum: (default: :const:`0.0`)
- The decay rate used by the momentum term. The momentum is not used when it is set to
- :const:`0.0`.
- centered: (default: :data:`False`)
- If :data:`True`, use the variance of the past gradients to rescale the latest
- gradients.
- initial_scale: (default: :data:`0.0`)
- Initialization of accumulators tracking the magnitude of previous updates. PyTorch
- uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a
- paper, verify the value used by the authors.
- nesterov: (default: :data:`False`)
- Whether to use Nesterov momentum.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
+ scheduler. (default: :const:`1e-2`)
+ alpha (float, optional): Smoothing constant, the decay used to track the magnitude of
+ previous gradients. (default: :const:`0.99`)
+ eps (float, optional): A small numerical constant to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ momentum (float, optional): The decay rate used by the momentum term. The momentum is not
+ used when it is set to :const:`0.0`. (default: :const:`0.0`)
+ centered (bool, optional): If :data:`True`, use the variance of the past gradients to
+ rescale the latest gradients. (default: :data:`False`)
+ initial_scale (float, optional): Initialization of accumulators tracking the magnitude of
+ previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When
+ reproducing results from a paper, verify the value used by the authors.
+ (default: :data:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
Returns:
The corresponding :class:`GradientTransformation` instance.
diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py
index 61b3d6e4..7d86b538 100644
--- a/torchopt/alias/sgd.py
+++ b/torchopt/alias/sgd.py
@@ -64,21 +64,19 @@ def sgd(
- Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
Args:
- lr: This is a fixed global scaling factor.
- momentum: (default: :const:`0.0`)
- The decay rate used by the momentum term. The momentum is not used when it is set to
- :const:`0.0`.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- dampening: (default: :const:`0.0`)
- Dampening for momentum.
- nesterov: (default: :data:`False`)
- Whether to use Nesterov momentum.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True` the momentums will be created with flag ``requires_grad=True``, this
- flag is often used in Meta-Learning algorithms.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
+ lr (float or callable): This is a fixed global scaling factor or a learning rate
+ scheduler.
+ momentum (float, optional): The decay rate used by the momentum term. The momentum is not
+ used when it is set to :const:`0.0`. (default: :const:`0.0`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ dampening (float, optional): Dampening for momentum. (default: :const:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with
+ flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms.
+ (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
Returns:
The corresponding :class:`GradientTransformation` instance.
diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py
index 869aad87..b5088164 100644
--- a/torchopt/alias/utils.py
+++ b/torchopt/alias/utils.py
@@ -13,8 +13,9 @@
# limitations under the License.
r"""Utilities for the aliases of preset :class:`GradientTransformation`\s for optimizers."""
+from __future__ import annotations
+
import threading
-from typing import Optional, Tuple
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation, identity
@@ -93,9 +94,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
assert params is not None, (
'Parameters are required for weight decay. '
'Call `update(updates, state, params=params)` instead.'
@@ -126,9 +127,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
if inplace:
def f(g):
@@ -151,9 +152,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
assert params is not None, (
'Parameters are required for weight decay. '
'Call `update(updates, state, params=params)` instead.'
diff --git a/torchopt/base.py b/torchopt/base.py
index bb37b147..b250c387 100644
--- a/torchopt/base.py
+++ b/torchopt/base.py
@@ -31,9 +31,11 @@
# ==============================================================================
"""The base classes for gradient transformation."""
+from __future__ import annotations
+
import itertools
from abc import abstractmethod
-from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, NamedTuple
from typing_extensions import Protocol # Python 3.8+
@@ -67,12 +69,11 @@ class TransformInitFn(Protocol): # pylint: disable=too-few-public-methods
"""
@abstractmethod
- def __call__(self, params: 'Params') -> 'OptState':
+ def __call__(self, params: Params) -> OptState:
"""Initialize the gradient transformation state.
Args:
- params:
- The initial value of the parameters.
+ params (tree of Tensor): The initial value of the parameters.
Returns:
The initial state of the gradient transformation.
@@ -93,21 +94,21 @@ class TransformUpdateFn(Protocol): # pylint: disable=too-few-public-methods
@abstractmethod
def __call__(
self,
- updates: 'Updates',
- state: 'OptState',
+ updates: Updates,
+ state: OptState,
*,
- params: Optional['Params'] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple['Updates', 'OptState']:
+ ) -> tuple[Updates, OptState]:
"""Transform the updates and state.
Args:
- updates: A tree of candidate updates.
- state: The state of the gradient transformation.
- params: (optional)
- The current value of the parameters.
- inplace: (optional)
- If :data:`True`, modify updates and state using inplace operations.
+ updates (tree of Tensor): A tree of candidate updates.
+ state (tree of Tensor): The state of the gradient transformation.
+ params (tree of Tensor or None, optional): The current value of the parameters.
+ (default: :data:`None`)
+ inplace (bool, optional): If :data:`True`, modify updates and state using inplace
+ operations. (default: :data:`True`)
Returns:
The transformed ``updates``, and the updated ``state``.
@@ -134,9 +135,9 @@ class GradientTransformation(NamedTuple):
optimizer state.
update:
A pure function which takes as input a pytree of updates (with the same tree structure
- as the original params ``pytree`` passed to :attr:`init`), the previous optimizer state
- (which may have been initialized using the :attr:`init` function), and optionally the
- ``inplace`` flag. The :attr:`update` function then returns the computed gradient
+ as the original params ``pytree`` passed to ``init``), the previous optimizer state
+ (which may have been initialized using the ``init`` function), and optionally the
+ ``inplace`` flag. The ``update`` function then returns the computed gradient
updates, and a updates optimizer state. If the ``inplace`` flag is :data:`True`, the
output results are the same instance as the input.
"""
@@ -145,7 +146,7 @@ class GradientTransformation(NamedTuple):
update: TransformUpdateFn
# pylint: disable-next=redefined-builtin
- def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation':
+ def chain(self, next: GradientTransformation) -> ChainedGradientTransformation:
"""Chain two gradient transformations together."""
return ChainedGradientTransformation(self, next)
@@ -157,9 +158,9 @@ class ChainedGradientTransformation(GradientTransformation):
gradient transformations.
"""
- transformations: Tuple[GradientTransformation, ...]
+ transformations: tuple[GradientTransformation, ...]
- def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation':
+ def __new__(cls, *transformations: GradientTransformation) -> ChainedGradientTransformation:
"""Create a new chained gradient transformation."""
transformations = tuple(
itertools.chain.from_iterable(
@@ -175,16 +176,16 @@ def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTr
init_fns, update_fns = tuple(zip(*transformations))
- def init_fn(params: 'Params') -> 'OptState':
+ def init_fn(params: Params) -> OptState:
return tuple(fn(params) for fn in init_fns)
def update_fn(
- updates: 'Updates',
- state: 'OptState',
+ updates: Updates,
+ state: OptState,
*,
- params: Optional['Params'] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple['Updates', 'OptState']:
+ ) -> tuple[Updates, OptState]:
if len(update_fns) != len(state):
raise ValueError(
'The number of updates and states has to be the same in chain! Make sure you'
@@ -219,15 +220,15 @@ def __hash__(self) -> int:
"""Return the hash of the chained gradient transformation."""
return hash(self.transformations)
- def __getstate__(self) -> Tuple[GradientTransformation, ...]:
+ def __getstate__(self) -> tuple[GradientTransformation, ...]:
"""Return the state of the chained gradient transformation for serialization."""
return self.transformations
- def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None:
+ def __setstate__(self, state: tuple[GradientTransformation, ...]) -> None:
"""Set the state of the chained gradient transformation from serialization."""
self.transformations = state
- def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]:
+ def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...]]]:
"""Serialize the chained gradient transformation."""
return ChainedGradientTransformation, (self.transformations,)
@@ -240,18 +241,18 @@ def __new__(cls):
return super().__new__(cls, init=cls.init_fn, update=cls.update_fn)
@staticmethod
- def init_fn(params: 'Params') -> 'OptState': # pylint: disable=unused-argument
+ def init_fn(params: Params) -> OptState: # pylint: disable=unused-argument
"""Return empty state."""
return EmptyState()
@staticmethod
def update_fn(
- updates: 'Updates',
- state: 'OptState',
+ updates: Updates,
+ state: OptState,
*,
- params: Optional['Params'] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True, # pylint: disable=unused-argument
- ) -> Tuple['Updates', 'OptState']:
+ ) -> tuple[Updates, OptState]:
"""Return updates unchanged."""
return updates, state
diff --git a/torchopt/clip.py b/torchopt/clip.py
index 2469d17a..b2aafb48 100644
--- a/torchopt/clip.py
+++ b/torchopt/clip.py
@@ -17,7 +17,7 @@
# ==============================================================================
"""Utilities for gradient clipping."""
-from typing import Optional, Tuple, Union
+from __future__ import annotations
import torch
@@ -33,18 +33,19 @@
def clip_grad_norm(
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
+ max_norm: float | int,
+ norm_type: float | int = 2.0,
error_if_nonfinite: bool = False,
) -> GradientTransformation:
"""Clip gradient norm of an iterable of parameters.
Args:
max_norm (float or int): The maximum absolute value for each element in the update.
- norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
- infinity norm.
- error_if_nonfinite (bool): if :data:`True`, an error is thrown if the total norm of the
- gradients from :attr:`updates` is ``nan``, ``inf``, or ``-inf``.
+ norm_type (float or int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity
+ norm. (default: :const:`2.0`)
+ error_if_nonfinite (bool, optional): If :data:`True`, an error is thrown if the total norm
+ of the gradients from ``updates`` is ``nan``, ``inf``, or ``-inf``.
+ (default: :data:`False`)
Returns:
An ``(init_fn, update_fn)`` tuple.
@@ -57,9 +58,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
available_updates = pytree.tree_leaves(updates)
if len(available_updates) == 0:
return updates, state
diff --git a/torchopt/combine.py b/torchopt/combine.py
index 82297426..0f1ed8ec 100644
--- a/torchopt/combine.py
+++ b/torchopt/combine.py
@@ -31,7 +31,7 @@
# ==============================================================================
"""Utilities to define a chained transformation."""
-from typing import Optional, Tuple
+from __future__ import annotations
from torchopt import pytree
from torchopt.base import ChainedGradientTransformation, GradientTransformation, identity
@@ -49,8 +49,8 @@ def chain(*transformations: GradientTransformation) -> GradientTransformation:
:func:`update_fn` which chains the update transformations feeding the appropriate state to each.
Args:
- *transformations:
- A sequence of chainable ``(init_fn, update_fn)`` tuples.
+ *transformations (iterable of GradientTransformation): A sequence of chainable
+ ``(init_fn, update_fn)`` tuples.
Returns:
A single ``(init_fn, update_fn)`` tuple.
@@ -66,8 +66,8 @@ def chain_flat(*transformations: GradientTransformation) -> GradientTransformati
"""Wrap around the inner transformations that manipulate the flattened tree structure (:class:``list``).
Args:
- *transformations:
- A sequence of chainable ``(init_fn, update_fn)`` tuples.
+ *transformations (iterable of GradientTransformation): A sequence of chainable
+ ``(init_fn, update_fn)`` tuples.
Returns:
A single ``(init_fn, update_fn)`` tuple.
@@ -86,9 +86,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True)
if params is not None:
flat_params = pytree.tree_leaves(params, none_is_leaf=True)
diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py
index 377bc1f4..a5908963 100644
--- a/torchopt/diff/implicit/decorator.py
+++ b/torchopt/diff/implicit/decorator.py
@@ -16,9 +16,11 @@
# pylint: disable=invalid-name
+from __future__ import annotations
+
import functools
import inspect
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Sequence, Tuple
import functorch
import torch
@@ -47,7 +49,7 @@ def __init__(
optimality_fn: Callable[..., TensorOrTensors],
solution: TensorOrTensors,
output_is_tensor: bool,
- argnums: Tuple[int, ...],
+ argnums: tuple[int, ...],
*args: Any,
) -> None:
self.optimality_fn = optimality_fn
@@ -88,7 +90,7 @@ def _root_vjp(
args: Args,
grad_outputs: TupleOfTensors,
output_is_tensor: bool,
- argnums: Tuple[int, ...],
+ argnums: tuple[int, ...],
solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(),
) -> TupleOfOptionalTensors:
if output_is_tensor:
@@ -145,14 +147,14 @@ def matvec(u: TupleOfTensors) -> TupleOfTensors:
return tuple(true_output)
-def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Args, KwArgs]:
+def _extract_kwargs(kwarg_keys: Sequence[str], flat_args: tuple[Any, ...]) -> tuple[Args, KwArgs]:
nargs = len(flat_args) - len(kwarg_keys)
args, kwarg_vals = flat_args[:nargs], flat_args[nargs:]
kwargs = dict(zip(kwarg_keys, kwarg_vals))
return args, kwargs
-def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> Tuple[Args, KwArgs]:
+def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) -> tuple[Args, KwArgs]:
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
return bound.args, bound.kwargs
@@ -160,7 +162,7 @@ def _signature_bind(signature: inspect.Signature, *args: Any, **kwargs: Any) ->
def _signature_bind_and_match(
signature: inspect.Signature, *args: Any, **kwargs: Any
-) -> Tuple[Args, KwArgs, Callable[[Args], Tuple[Args, KwArgs]]]:
+) -> tuple[Args, KwArgs, Callable[[Args], tuple[Args, KwArgs]]]:
# We want to bind *args and **kwargs based on the provided signature, but also to associate the
# resulting positional arguments back. To achieve this, we lift arguments to a triple:
#
@@ -193,13 +195,13 @@ def map_args_back(out_args):
def _split_tensor_and_others(
- mixed_tuple: Tuple[Any, ...],
-) -> Tuple[pytree.PyTreeSpec, Tuple[bool, ...], TupleOfTensors, Tuple[Any, ...]]:
- flattened: List[Any]
+ mixed_tuple: tuple[Any, ...],
+) -> tuple[pytree.PyTreeSpec, tuple[bool, ...], TupleOfTensors, tuple[Any, ...]]:
+ flattened: list[Any]
flattened, treespec = pytree.tree_flatten(mixed_tuple, none_is_leaf=True) # type: ignore[arg-type]
tensors: ListOfTensors = []
- non_tensors: List[Any] = []
- is_tensor_mask: List[bool] = []
+ non_tensors: list[Any] = []
+ is_tensor_mask: list[bool] = []
for item in flattened:
is_tensor = isinstance(item, torch.Tensor)
is_tensor_mask.append(is_tensor)
@@ -212,10 +214,10 @@ def _split_tensor_and_others(
def _merge_tensor_and_others(
treespec: pytree.PyTreeSpec,
- is_tensor_mask: Tuple[bool, ...],
+ is_tensor_mask: tuple[bool, ...],
tensors: TupleOfTensors,
- non_tensors: Tuple[Any, ...],
-) -> Tuple[Any, ...]:
+ non_tensors: tuple[Any, ...],
+) -> tuple[Any, ...]:
tensor_counter = 0
non_tensor_counter = 0
results = []
@@ -231,13 +233,13 @@ def _merge_tensor_and_others(
# pylint: disable-next=too-many-arguments,too-many-statements
def _custom_root(
- solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]],
+ solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
optimality_fn: Callable[..., TensorOrTensors],
solve: Callable[..., TensorOrTensors],
- argnums: Tuple[int, ...],
+ argnums: tuple[int, ...],
has_aux: bool,
- reference_signature: Optional[Union[inspect.Signature, Callable]] = None,
-) -> Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]:
+ reference_signature: inspect.Signature | Callable | None = None,
+) -> Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]:
solver_fn_signature = inspect.signature(solver_fn)
if reference_signature is None:
@@ -249,16 +251,16 @@ def _custom_root(
reference_signature = inspect.signature(fn)
def make_custom_vjp_solver_fn(
- solver_fn: Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]],
+ solver_fn: Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
kwarg_keys: Sequence[str],
- args_signs: Tuple[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]], ...],
- ) -> Type[Function]:
+ args_signs: tuple[tuple[int, int, type[tuple] | type[list] | None], ...],
+ ) -> type[Function]:
# pylint: disable-next=missing-class-docstring,abstract-method
class ImplicitMetaGradient(Function):
@staticmethod
def forward( # type: ignore[override] # pylint: disable=arguments-differ
ctx: Any, *flat_args: Any
- ) -> Tuple[Any, ...]:
+ ) -> tuple[Any, ...]:
output, aux, output_is_tensor = None, None, False
args = []
@@ -361,12 +363,12 @@ def backward( # pylint: disable=too-many-locals
@functools.wraps(solver_fn)
def wrapped_solver_fn(
*args: Any, **kwargs: Any
- ) -> Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]:
+ ) -> TensorOrTensors | tuple[TensorOrTensors, Any]:
args, kwargs = _signature_bind(solver_fn_signature, *args, **kwargs)
keys, vals = list(kwargs.keys()), list(kwargs.values())
- args_signs: List[Tuple[int, int, Optional[Union[Type[tuple], Type[list]]]]] = []
- flat_args: List[Any] = []
+ args_signs: list[tuple[int, int, type[tuple] | type[list] | None]] = []
+ flat_args: list[Any] = []
args_offset = 0
for idx, arg in enumerate(args):
if idx in argnums:
@@ -410,12 +412,12 @@ def wrapped_solver_fn(
def custom_root(
optimality_fn: Callable[..., TensorOrTensors],
- argnums: Union[int, Tuple[int, ...]],
+ argnums: int | tuple[int, ...],
has_aux: bool = False,
solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(),
) -> Callable[
- [Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]]],
- Callable[..., Union[TensorOrTensors, Tuple[TensorOrTensors, Any]]],
+ [Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]],
+ Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
]:
"""Return a decorator for adding implicit differentiation to a root solver.
@@ -442,18 +444,17 @@ def solver_fn(params, arg1, arg2, ...):
**In best practice, the ``optimality_fn`` should have the same signature as ``solver_fn``.**
Args:
- optimality_fn: (callable)
- An equation function, ``optimality_fn(params, *args)``. The invariant is
- ``optimality_fn(solution, *args) == 0`` at the solution / root of ``solution``.
- argnums: (int or tuple of ints)
- Specifies arguments to compute gradients with respect to. The ``argnums`` can be an
- integer or a tuple of integers, which respect to the zero-based indices of the arguments
- of the ``solver_fn(params, *args)`` function. The argument ``params`` is included
- for the counting, while it is indexed as ``argnums=0``.
- has_aux: (default: :data:`False`)
- Whether the decorated solver function returns auxiliary data.
- solve: (callable, optional, default: :func:`linear_solve.solve_normal_cg`)
- a linear solver of the form ``solve(matvec, b)``.
+ optimality_fn (callable): An equation function, ``optimality_fn(params, *args)``. The
+ invariant is ``optimality_fn(solution, *args) == 0`` at the solution / root of
+ ``solution``.
+ argnums (int or tuple of int): Specifies arguments to compute gradients with respect to. The
+ ``argnums`` can be an integer or a tuple of integers, which respect to the zero-based
+ indices of the arguments of the ``solver_fn(params, *args)`` function. The argument
+ ``params`` is included for the counting, while it is indexed as ``argnums=0``.
+ has_aux (bool, optional): Whether the decorated solver function returns auxiliary data.
+ (default: :data:`False`)
+ solve (callable, optional): A linear solver of the form ``solve(matvec, b)``.
+ (default: :func:`linear_solve.solve_normal_cg`)
Returns:
A solver function decorator, i.e., ``custom_root(optimality_fn)(solver_fn)``.
diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py
index f9bff4de..bbae37c9 100644
--- a/torchopt/diff/implicit/nn/module.py
+++ b/torchopt/diff/implicit/nn/module.py
@@ -16,10 +16,12 @@
# pylint: disable=redefined-builtin
+from __future__ import annotations
+
import abc
import functools
import itertools
-from typing import Any, Iterable, Optional, Tuple, Type
+from typing import Any, Iterable
import functorch
import torch
@@ -38,7 +40,7 @@ def _stateless_objective_fn(
__flat_meta_params: TupleOfTensors,
__params_names: Iterable[str],
__meta_params_names: Iterable[str],
- self: 'ImplicitMetaGradientModule',
+ self: ImplicitMetaGradientModule,
*input,
**kwargs,
) -> torch.Tensor:
@@ -57,7 +59,7 @@ def _stateless_optimality_fn(
__flat_meta_params: TupleOfTensors,
__params_names: Iterable[str],
__meta_params_names: Iterable[str],
- self: 'ImplicitMetaGradientModule',
+ self: ImplicitMetaGradientModule,
*input,
**kwargs,
) -> TupleOfTensors:
@@ -72,8 +74,8 @@ def _stateless_optimality_fn(
def make_optimality_from_objective(
- cls: Type['ImplicitMetaGradientModule'],
-) -> Type['ImplicitMetaGradientModule']:
+ cls: type[ImplicitMetaGradientModule],
+) -> type[ImplicitMetaGradientModule]:
"""Derives the optimality function of the objective function."""
if (
getattr(cls, 'objective', ImplicitMetaGradientModule.objective)
@@ -81,7 +83,7 @@ def make_optimality_from_objective(
):
raise TypeError('The objective function is not defined.')
- def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfTensors:
+ def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTensors:
params_names, flat_params = tuple(zip(*self.named_parameters()))
meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
@@ -102,8 +104,8 @@ def optimality(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> TupleOfT
def enable_implicit_gradients(
- cls: Type['ImplicitMetaGradientModule'],
-) -> Type['ImplicitMetaGradientModule']:
+ cls: type[ImplicitMetaGradientModule],
+) -> type[ImplicitMetaGradientModule]:
"""Enable implicit gradients for the :func:`solve` method."""
cls_solve = cls.solve
if getattr(cls_solve, '__implicit_gradients_enabled__', False):
@@ -122,17 +124,17 @@ def stateless_solver_fn(
__params_names: Iterable[str],
__meta_params_names: Iterable[str],
# pylint: enable=unused-argument
- self: 'ImplicitMetaGradientModule',
+ self: ImplicitMetaGradientModule,
*input,
**kwargs,
- ) -> Tuple[TupleOfTensors, Any]:
+ ) -> tuple[TupleOfTensors, Any]:
"""Solve the optimization problem."""
output = cls_solve(self, *input, **kwargs)
flat_optimal_params = tuple(p.detach_() for p in self.parameters())
return flat_optimal_params, output
@functools.wraps(cls_solve)
- def wrapped(self: 'ImplicitMetaGradientModule', *input, **kwargs) -> Any:
+ def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any:
"""Solve the optimization problem."""
params_names, flat_params = tuple(zip(*self.named_parameters()))
meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
@@ -159,9 +161,9 @@ class ImplicitMetaGradientModule(MetaGradientModule):
_custom_optimality: bool
_custom_objective: bool
- linear_solve: Optional[LinearSolver]
+ linear_solve: LinearSolver | None
- def __init_subclass__(cls, linear_solve: Optional[LinearSolver] = None) -> None:
+ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None:
"""Validate and initialize the subclass."""
super().__init_subclass__()
cls.linear_solve = linear_solve
diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py
index 80664d8b..43522028 100644
--- a/torchopt/diff/zero_order/decorator.py
+++ b/torchopt/diff/zero_order/decorator.py
@@ -14,8 +14,10 @@
# ==============================================================================
"""Zero-Order Gradient Estimation."""
+from __future__ import annotations
+
import functools
-from typing import Any, Callable, List, Sequence, Tuple, Union
+from typing import Any, Callable, Sequence
from typing_extensions import Literal # Python 3.8+
from typing_extensions import TypeAlias # Python 3.10+
@@ -33,9 +35,7 @@ def __init__(self, sample_fn: SampleFunc) -> None:
"""Wrap a sample function to make it a :class:`Samplable` object."""
self.sample_fn = sample_fn
- def sample(
- self, sample_shape: torch.Size = torch.Size()
- ) -> Union[torch.Tensor, Sequence[Numeric]]:
+ def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor | Sequence[Numeric]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
return self.sample_fn(sample_shape)
@@ -44,14 +44,14 @@ def sample(
def _zero_order_naive( # pylint: disable=too-many-statements
fn: Callable[..., torch.Tensor],
distribution: Samplable,
- argnums: Tuple[int, ...],
+ argnums: tuple[int, ...],
num_samples: int,
sigma: Numeric,
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements
diff_params = [args[argnum] for argnum in argnums]
- flat_diff_params: List[Any]
+ flat_diff_params: list[Any]
flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type]
class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method
@@ -59,7 +59,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m
def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
flat_diff_params = args[:-1]
origin_args = list(args[-1][0])
- flat_args: List[Any]
+ flat_args: list[Any]
flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type]
ctx.args_treespec = args_treespec
@@ -107,7 +107,7 @@ def backward( # pylint: disable=too-many-locals
flat_args.append(non_tensors[non_tensors_counter])
non_tensors_counter += 1
- args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
+ args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
def add_perturbation(tensor, noises):
return tensor.add(noises, alpha=sigma)
@@ -119,7 +119,7 @@ def add_perturbation(tensor, noises):
flat_noisy_params = [
add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)
]
- noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment]
+ noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment]
diff_params_treespec, flat_noisy_params
)
@@ -145,14 +145,14 @@ def add_perturbation(tensor, noises):
def _zero_order_forward( # pylint: disable=too-many-statements
fn: Callable[..., torch.Tensor],
distribution: Samplable,
- argnums: Tuple[int, ...],
+ argnums: tuple[int, ...],
num_samples: int,
sigma: Numeric,
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements
diff_params = [args[argnum] for argnum in argnums]
- flat_diff_params: List[Any]
+ flat_diff_params: list[Any]
flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type]
class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method
@@ -160,7 +160,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m
def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
flat_diff_params = args[:-1]
origin_args = list(args[-1][0])
- flat_args: List[Any]
+ flat_args: list[Any]
flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type]
ctx.args_treespec = args_treespec
@@ -209,7 +209,7 @@ def backward( # pylint: disable=too-many-locals
flat_args.append(non_tensors[non_tensors_counter])
non_tensors_counter += 1
- args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
+ args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
def add_perturbation(tensor, noises):
return tensor.add(noises, alpha=sigma)
@@ -221,7 +221,7 @@ def add_perturbation(tensor, noises):
flat_noisy_params = [
add_perturbation(t, n) for t, n in zip(flat_diff_params, noises)
]
- noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment]
+ noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment]
diff_params_treespec, flat_noisy_params
)
@@ -248,14 +248,14 @@ def add_perturbation(tensor, noises):
def _zero_order_antithetic( # pylint: disable=too-many-statements
fn: Callable[..., torch.Tensor],
distribution: Samplable,
- argnums: Tuple[int, ...],
+ argnums: tuple[int, ...],
num_samples: int,
sigma: Numeric,
) -> Callable[..., torch.Tensor]:
@functools.wraps(fn)
def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements
diff_params = [args[argnum] for argnum in argnums]
- flat_diff_params: List[Any]
+ flat_diff_params: list[Any]
flat_diff_params, diff_params_treespec = pytree.tree_flatten(diff_params) # type: ignore[arg-type]
class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-method
@@ -263,7 +263,7 @@ class ZeroOrder(Function): # pylint: disable=missing-class-docstring,abstract-m
def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
flat_diff_params = args[:-1]
origin_args = list(args[-1][0])
- flat_args: List[Any]
+ flat_args: list[Any]
flat_args, args_treespec = pytree.tree_flatten(origin_args, none_is_leaf=True) # type: ignore[arg-type]
ctx.args_treespec = args_treespec
@@ -309,7 +309,7 @@ def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals
flat_args.append(non_tensors[non_tensors_counter])
non_tensors_counter += 1
- args: List[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
+ args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment]
param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc]
@@ -318,7 +318,7 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor:
add_perturbation_fn(t, n, alpha=sigma)
for t, n in zip(flat_diff_params, noises)
]
- noisy_params: List[Any] = pytree.tree_unflatten( # type: ignore[assignment]
+ noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment]
diff_params_treespec, flat_noisy_params
)
@@ -349,28 +349,28 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor:
def zero_order(
- distribution: Union[SampleFunc, Samplable],
+ distribution: SampleFunc | Samplable,
method: Method = 'naive',
- argnums: Union[int, Tuple[int, ...]] = (0,),
+ argnums: int | tuple[int, ...] = (0,),
num_samples: int = 1,
sigma: Numeric = 1.0,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
"""Return a decorator for applying zero-order differentiation.
Args:
- distribution: (function or Samplable)
- A samplable object that has method ``samplable.sample(sample_shape)`` or a function that
- takes the shape as input and returns a shaped batch of samples. This is used to sample
- perturbations from the given distribution. The distribution should be sphere symmetric.
- method: (str)
- The algorithm to use. The currently supported algorithms are :const:`'naive'`,
- :const:`'forward'`, and :const:`'antithetic'`. Defaults to :const:`'naive'`.
- argnums: (int or tuple of int, default: :const:`0`)
- Specifies arguments to compute gradients with respect to.
- num_samples: (int, default :const:`1`)
- The number of sample to get the averaged estimated gradient.
- sigma: (Numeric)
- The standard deviation of the perturbation. Defaults to :const:`1.0`.
+ distribution (callable or Samplable): A samplable object that has method
+ ``samplable.sample(sample_shape)`` or a function that takes the shape as input and
+ returns a shaped batch of samples. This is used to sample perturbations from the given
+ distribution. The distribution should be sphere symmetric.
+ method (str, optional): The algorithm to use. The currently supported algorithms are
+ :const:`'naive'`, :const:`'forward'`, and :const:`'antithetic'`.
+ (default: :const:`'naive'`)
+ argnums (int or tuple of int, optional): Specifies arguments to compute gradients with
+ respect to. (default: :const:`0`)
+ num_samples (int, optional): The number of sample to get the averaged estimated gradient.
+ (default: :const:`1`)
+ sigma (float or Tensor, optional): The standard deviation of the perturbation.
+ (default: :const:`1.0`)
Returns:
A function decorator that enables zero-order gradient estimation.
diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py
index d76ac444..65014fb9 100644
--- a/torchopt/diff/zero_order/nn/module.py
+++ b/torchopt/diff/zero_order/nn/module.py
@@ -16,9 +16,11 @@
# pylint: disable=redefined-builtin
+from __future__ import annotations
+
import abc
import functools
-from typing import Sequence, Type, Union
+from typing import Sequence
import torch
import torch.nn as nn
@@ -32,11 +34,11 @@
def enable_zero_order_gradients(
- cls: Type['ZeroOrderGradientModule'],
+ cls: type[ZeroOrderGradientModule],
method: Method = 'naive',
num_samples: int = 1,
sigma: Numeric = 1.0,
-) -> Type['ZeroOrderGradientModule']:
+) -> type[ZeroOrderGradientModule]:
"""Enable zero-order gradient estimation for the :func:`forward` method."""
cls_forward = cls.forward
if getattr(cls_forward, '__zero_order_gradients_enabled__', False):
@@ -45,7 +47,7 @@ def enable_zero_order_gradients(
)
@functools.wraps(cls_forward)
- def wrapped(self: 'ZeroOrderGradientModule', *input, **kwargs) -> torch.Tensor:
+ def wrapped(self: ZeroOrderGradientModule, *input, **kwargs) -> torch.Tensor:
"""Do the forward pass calculation."""
params_names, flat_params = tuple(zip(*self.named_parameters()))
@@ -91,7 +93,7 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
@abc.abstractmethod
def sample(
self, sample_shape: torch.Size = torch.Size() # pylint: disable=unused-argument
- ) -> Union[torch.Tensor, Sequence[Numeric]]:
+ ) -> torch.Tensor | Sequence[Numeric]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
raise NotImplementedError
diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py
index 53f87fba..b46ad67e 100644
--- a/torchopt/distributed/api.py
+++ b/torchopt/distributed/api.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Distributed APIs."""
+from __future__ import annotations
+
import functools
import sys
from typing import (
@@ -73,8 +75,8 @@ class TensorDimensionPartitioner:
while the non-tensor values will be broadcasted to partitions.
Args:
- dim: The dimension to partition.
- exclusive: Whether to partition the batch exclusively.
+ dim (int): The dimension to partition.
+ exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`)
If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where
``batch_size`` is the size of the batch along the given dimension. Each batch sample
will be assigned to a separate RPC call.
@@ -82,11 +84,12 @@ class TensorDimensionPartitioner:
partitions, where ``num_workers`` is the number of workers in the world. When
``batch_size > num_workers``, there can be multiple batch samples forward in a single
RPC call.
- keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the
- batch dimension. If :data:`False`, use select instead of slicing. This functionality
- should be used with ``exclusive=True``.
- workers: The workers to partition the batch to. If :data:`None`, the batch will be
- partitioned to all workers in the world.
+ keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`True`)
+ If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of
+ slicing. This functionality should be used with ``exclusive=True``.
+ workers (sequence of int or str, or None, optional): The workers to partition the batch to.
+ If :data:`None`, the batch will be partitioned to all workers in the world.
+ (default: :data:`None`)
"""
def __init__(
@@ -95,7 +98,7 @@ def __init__(
*,
exclusive: bool = False,
keepdim: bool = False,
- workers: Optional[Sequence[Union[int, str]]] = None,
+ workers: Sequence[int | str] | None = None,
) -> None:
"""Initialize the partitioner instance."""
if not keepdim and not exclusive:
@@ -111,7 +114,7 @@ def __call__(
self,
*args: Any,
**kwargs: Any,
- ) -> List[Tuple[int, Optional[Args], Optional[KwArgs]]]:
+ ) -> list[tuple[int, Args | None, KwArgs | None]]:
"""Partition the batch of inputs along the given dimension."""
if self.workers is None:
workers = list(range(get_world_size()))
@@ -120,7 +123,7 @@ def __call__(
num_workers = len(workers)
args_tree = (args, kwargs)
- flat_args: List[Any]
+ flat_args: list[Any]
flat_args, treespec = pytree.tree_flatten(args_tree) # type: ignore[arg-type]
batch_size = None
@@ -137,8 +140,8 @@ def __call__(
if batch_size is None:
return [(get_world_rank(), args, kwargs.copy())]
- dim_slices: List[Union[int, slice]]
- batch_slices: List[Tuple[Union[int, slice, Ellipsis.__class__], ...]] # type: ignore[name-defined]
+ dim_slices: list[int | slice]
+ batch_slices: list[tuple[int | slice | Ellipsis.__class__, ...]] # type: ignore[name-defined]
if self.exclusive:
num_replicas = batch_size
if self.keepdim:
@@ -172,7 +175,7 @@ def __call__(
for dim_slice in dim_slices
]
- flat_args_replicas: List[List[Any]] = [[] for _ in range(num_replicas)]
+ flat_args_replicas: list[list[Any]] = [[] for _ in range(num_replicas)]
for arg in flat_args:
if isinstance(arg, torch.Tensor):
for i, batch_slice in enumerate(batch_slices):
@@ -181,7 +184,7 @@ def __call__(
for i in range(num_replicas):
flat_args_replicas[i].append(arg)
- args_replicas: List[Tuple[Args, KwArgs]] = [
+ args_replicas: list[tuple[Args, KwArgs]] = [
pytree.tree_unflatten(treespec, args_replica) # type: ignore[misc]
for args_replica in flat_args_replicas
]
@@ -193,10 +196,10 @@ def __call__(
def __reduce__(
self,
- ) -> Tuple[
- Callable[..., 'TensorDimensionPartitioner'],
- Tuple[int],
- Dict[str, Union[bool, Optional[Sequence[Union[int, str]]]]],
+ ) -> tuple[
+ Callable[..., TensorDimensionPartitioner],
+ tuple[int],
+ dict[str, bool | Sequence[int | str] | None],
]:
"""Return a tuple that allows the partitioner to be pickled."""
return (
@@ -211,7 +214,7 @@ def dim_partitioner(
*,
exclusive: bool = False,
keepdim: bool = True,
- workers: Optional[Sequence[Union[int, str]]] = None,
+ workers: Sequence[int | str] | None = None,
) -> PartitionFunction:
"""Partition a batch of inputs along a given dimension.
@@ -219,8 +222,8 @@ def dim_partitioner(
while the non-tensor values will be broadcasted to partitions.
Args:
- dim: The dimension to partition.
- exclusive: Whether to partition the batch exclusively.
+ dim (int, optional): The dimension to partition. (default: :const:`0`)
+ exclusive (bool, optional): Whether to partition the batch exclusively. (default: :data:`False`)
If :data:`True`, the batch will be partitioned into ``batch_size`` partitions, where
``batch_size`` is the size of the batch along the given dimension. Each batch sample
will be assigned to a separate RPC call.
@@ -228,11 +231,12 @@ def dim_partitioner(
partitions, where ``num_workers`` is the number of workers in the world. When
``batch_size > num_workers``, there can be multiple batch samples forward in a single
RPC call.
- keepdim: Whether to keep the partitioned dimension. Defaults to :data:`True`, i.e., keep the
- batch dimension. If :data:`False`, use select instead of slicing. This functionality
- should be used with ``exclusive=True``.
- workers: The workers to partition the batch to. If :data:`None`, the batch will be
- partitioned to all workers in the world.
+ keepdim (bool, optional): Whether to keep the partitioned dimension. (default: :data:`False`)
+ If :data:`True`, keep the batch dimension. If :data:`False`, use select instead of
+ slicing. This functionality should be used with ``exclusive=True``.
+ workers (sequence of int or str, or None, optional): The workers to partition the batch to.
+ If :data:`None`, the batch will be partitioned to all workers in the world.
+ (default: :data:`None`)
Returns:
A partition function.
@@ -273,26 +277,26 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
def remote_async_call(
func: Callable[..., T],
*,
- args: Optional[Args] = None,
- kwargs: Optional[KwArgs] = None,
- partitioner: Optional[Partitioner] = None,
- reducer: Optional[Callable[[Iterable[T]], U]] = None,
- timeout: Optional[float] = UNSET_RPC_TIMEOUT,
-) -> Union[Future[List[T]], Future[U]]:
+ args: Args | None = None,
+ kwargs: KwArgs | None = None,
+ partitioner: Partitioner | None = None,
+ reducer: Callable[[Iterable[T]], U] | None = None,
+ timeout: float | None = UNSET_RPC_TIMEOUT,
+) -> Future[list[T]] | Future[U]:
"""Asynchronously do an RPC on remote workers and return the a :class:`torch.Future` instance at the current worker.
Args:
- func (Callable[..., T]): The function to call.
- args (Optional[Args], optional): The arguments to pass to the function. Defaults to
- :data:`None`.
- kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults
- to :data:`None`.
- partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
- workers. Defaults to :func:`batch_partitioner`.
- reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
- multiple workers. Defaults to :data:`None`.
- timeout (float, optional): The timeout for the RPC call. Defaults to
- :data:`rpc.api.UNSET_RPC_TIMEOUT`.
+ func (callable): The function to call.
+ args (tuple of object or None, optional): The arguments to pass to the function.
+ (default: :data:`None`)
+ kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function.
+ (default: :data:`None`)
+ partitioner (int, str, or callable, optional): A partitioner that partitions the arguments
+ to multiple workers. (default: :func:`batch_partitioner`)
+ reducer (callable or None, optional): A reducer that reduces the results from multiple
+ workers. If :data:`None`, do not reduce the results. (default: :data:`None`)
+ timeout (float, optional): The timeout for the RPC call.
+ (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`)
Returns:
A :class:`torch.Future` instance for the result. The result is at the current worker.
@@ -330,26 +334,26 @@ def remote_async_call(
def remote_sync_call(
func: Callable[..., T],
*,
- args: Optional[Args] = None,
- kwargs: Optional[KwArgs] = None,
- partitioner: Optional[Partitioner] = None,
- reducer: Optional[Callable[[Iterable[T]], U]] = None,
- timeout: Optional[float] = UNSET_RPC_TIMEOUT,
-) -> Union[List[T], U]:
+ args: Args | None = None,
+ kwargs: KwArgs | None = None,
+ partitioner: Partitioner | None = None,
+ reducer: Callable[[Iterable[T]], U] | None = None,
+ timeout: float | None = UNSET_RPC_TIMEOUT,
+) -> list[T] | U:
"""Do an RPC synchronously on remote workers and return the result to the current worker.
Args:
- func (Callable[..., T]): The function to call.
- args (Optional[Args], optional): The arguments to pass to the function. Defaults to
- :data:`None`.
- kwargs (Optional[KwArgs], optional): The keyword arguments to pass to the function. Defaults
- to :data:`None`.
- partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
- workers. Defaults to :func:`batch_partitioner`.
- reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
- multiple workers. Defaults to :data:`None`.
- timeout (float, optional): The timeout for the RPC call. Defaults to
- :data:`rpc.api.UNSET_RPC_TIMEOUT`.
+ func (callable): The function to call.
+ args (tuple of object or None, optional): The arguments to pass to the function.
+ (default: :data:`None`)
+ kwargs (dict[str, object] or None, optional): The keyword arguments to pass to the function.
+ (default: :data:`None`)
+ partitioner (int, str, or callable, optional): A partitioner that partitions the arguments
+ to multiple workers. (default: :func:`batch_partitioner`)
+ reducer (callable or None, optional): A reducer that reduces the results from multiple
+ workers. If :data:`None`, do not reduce the results. (default: :data:`None`)
+ timeout (float, optional): The timeout for the RPC call.
+ (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`)
Returns:
The result of the RPC call. The result is at the current worker.
@@ -365,10 +369,10 @@ def remote_sync_call(
def parallelize_async(
- partitioner: Optional[Partitioner] = None,
- reducer: Optional[Callable[[Iterable[T]], U]] = None,
- timeout: Optional[float] = UNSET_RPC_TIMEOUT,
-) -> Callable[[Callable[..., T]], Callable[..., Union[Future[List[T]], Future[U]]]]:
+ partitioner: Partitioner | None = None,
+ reducer: Callable[[Iterable[T]], U] | None = None,
+ timeout: float | None = UNSET_RPC_TIMEOUT,
+) -> Callable[[Callable[..., T]], Callable[..., Future[list[T]] | Future[U]]]:
"""Return a decorator for parallelizing a function.
This decorator can be used to parallelize a function call across multiple workers. The
@@ -376,13 +380,12 @@ def parallelize_async(
return a :class:`torch.Future` instance of the result.
Args:
- partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
- workers. Defaults to :func:`batch_partitioner`.
- reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
- multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not
- specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`.
- timeout (float, optional): The timeout for the RPC call. Defaults to
- :data:`rpc.api.UNSET_RPC_TIMEOUT`.
+ partitioner (int, str, or callable, optional): A partitioner that partitions the arguments
+ to multiple workers. (default: :func:`batch_partitioner`)
+ reducer (callable or None, optional): A reducer that reduces the results from multiple
+ workers. If :data:`None`, do not reduce the results. (default: :data:`None`)
+ timeout (float, optional): The timeout for the RPC call.
+ (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`)
Returns:
The decorator function.
@@ -392,9 +395,9 @@ def parallelize_async(
if reducer is None:
reducer = mean_reducer # type: ignore[assignment]
- def wrapper(func: Callable[..., T]) -> Callable[..., Union[Future[List[T]], Future[U]]]:
+ def wrapper(func: Callable[..., T]) -> Callable[..., Future[list[T]] | Future[U]]:
@functools.wraps(func)
- def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]:
+ def wrapped(*args: Any, **kwargs: Any) -> Future[list[T]] | Future[U]:
return remote_async_call(
func,
args=args,
@@ -423,22 +426,21 @@ def wrapped(*args: Any, **kwargs: Any) -> Union[Future[List[T]], Future[U]]:
def parallelize(
- partitioner: Optional[Partitioner] = None,
- reducer: Optional[Callable[[Iterable[T]], U]] = None,
- timeout: Optional[float] = UNSET_RPC_TIMEOUT,
-) -> Callable[[Callable[..., T]], Callable[..., Union[List[T], U]]]:
+ partitioner: Partitioner | None = None,
+ reducer: Callable[[Iterable[T]], U] | None = None,
+ timeout: float | None = UNSET_RPC_TIMEOUT,
+) -> Callable[[Callable[..., T]], Callable[..., list[T] | U]]:
"""Return a decorator for parallelizing a function.
This decorator can be used to parallelize a function call across multiple workers.
Args:
- partitioner (Partitioner, optional): A partitioner that partitions the arguments to multiple
- workers. Defaults to :func:`batch_partitioner`.
- reducer (Callable[[Iterable[T]], U], optional): A reducer that reduces the results from
- multiple workers. Defaults to :func:`mean_reducer` if the ``partitioner`` is not
- specified, i.e., :func:`batch_partitioner`. Otherwise, it defaults to :data:`None`.
- timeout (float, optional): The timeout for the RPC call. Defaults to
- :data:`rpc.api.UNSET_RPC_TIMEOUT`.
+ partitioner (int, str, or callable, optional): A partitioner that partitions the arguments
+ to multiple workers. (default: :func:`batch_partitioner`)
+ reducer (callable or None, optional): A reducer that reduces the results from multiple
+ workers. If :data:`None`, do not reduce the results. (default: :data:`None`)
+ timeout (float, optional): The timeout for the RPC call.
+ (default: :data:`rpc.api.UNSET_RPC_TIMEOUT`)
Returns:
The decorator function.
@@ -448,9 +450,9 @@ def parallelize(
if reducer is None:
reducer = mean_reducer # type: ignore[assignment]
- def wrapper(func: Callable[..., T]) -> Callable[..., Union[List[T], U]]:
+ def wrapper(func: Callable[..., T]) -> Callable[..., list[T] | U]:
@functools.wraps(func)
- def wrapped(*args: Any, **kwargs: Any) -> Union[List[T], U]:
+ def wrapped(*args: Any, **kwargs: Any) -> list[T] | U:
return remote_sync_call(
func,
args=args,
diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py
index 5fe51278..17fa9463 100644
--- a/torchopt/distributed/autograd.py
+++ b/torchopt/distributed/autograd.py
@@ -14,14 +14,15 @@
# ==============================================================================
"""Distributed Autograd."""
+from __future__ import annotations
+
from threading import Lock
-from typing import Optional, overload
import torch
import torch.distributed.autograd as autograd
from torch.distributed.autograd import context
-from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors, TupleOfTensors
+from torchopt.typing import TensorOrTensors, TupleOfOptionalTensors
__all__ = ['is_available', 'context']
@@ -43,22 +44,23 @@ def backward(
autograd_ctx_id: int,
tensors: TensorOrTensors,
retain_graph: bool = False,
- inputs: Optional[TensorOrTensors] = None,
+ inputs: TensorOrTensors | None = None,
) -> None:
"""Perform distributed backward pass for local parameters.
Compute the sum of gradients of given tensors with respect to graph leaves.
Args:
- autograd_ctx_id: The autograd context id.
- tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be computed.
+ autograd_ctx_id (int): The autograd context id.
+ tensors (Tensor or sequence of Tensor): Tensors of which the derivative will be computed.
retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will
be freed. Note that in nearly all cases setting this option to :data:`True` is not
needed and often can be worked around in a much more efficient way.
- inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient be will
- accumulated into ``.grad``. All other Tensors will be ignored. If not provided, the
- gradient is accumulated into all the leaf Tensors that were used to compute the
- attr::tensors.
+ (default: :data:`False`)
+ inputs (Tensor, sequence of Tensor, or None, optional): Inputs w.r.t. which the gradient
+ be will accumulated into ``.grad``. All other Tensors will be ignored. If not
+ provided, the gradient is accumulated into all the leaf Tensors that were used to
+ compute the ``tensors``. (default: :data:`None`)
"""
if inputs is not None:
if isinstance(inputs, torch.Tensor):
@@ -85,25 +87,6 @@ def backward(
else:
p.grad = g
- @overload
- def grad(
- autograd_ctx_id: int,
- outputs: TensorOrTensors,
- inputs: TensorOrTensors,
- retain_graph: bool = False,
- ) -> TupleOfTensors:
- ...
-
- @overload
- def grad(
- autograd_ctx_id: int,
- outputs: TensorOrTensors,
- inputs: TensorOrTensors,
- retain_graph: bool = False,
- allow_unused: bool = False,
- ) -> TupleOfOptionalTensors:
- ...
-
def grad(
autograd_ctx_id: int,
outputs: TensorOrTensors,
@@ -114,16 +97,17 @@ def grad(
"""Compute and return the sum of gradients of outputs with respect to the inputs.
Args:
- autograd_ctx_id: The autograd context id.
- outputs (sequence of Tensor): outputs of the differentiated function.
- inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be returned (and not
- accumulated into ``.grad``).
+ autograd_ctx_id (int): The autograd context id.
+ outputs (Tensor or sequence of Tensor): Outputs of the differentiated function.
+ inputs (Tensor or sequence of Tensor): Inputs w.r.t. which the gradient will be returned
+ (and not accumulated into ``.grad``).
retain_graph (bool, optional): If :data:`False`, the graph used to compute the grad will
be freed. Note that in nearly all cases setting this option to :data:`True` is not
needed and often can be worked around in a much more efficient way.
+ (default: :data:`False`)
allow_unused (bool, optional): If :data:`False`, specifying inputs that were not used
when computing outputs (and therefore their grad is always zero) is an error.
- Defaults to :data:`False`.
+ (default: :data:`False`)
"""
outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs)
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py
index 45140df1..804d4b9d 100644
--- a/torchopt/distributed/world.py
+++ b/torchopt/distributed/world.py
@@ -14,10 +14,12 @@
# ==============================================================================
"""Utilities for gathering information about the world."""
+from __future__ import annotations
+
import atexit
import functools
import os
-from typing import Any, Callable, Iterable, NamedTuple, Optional, TypeVar, Union
+from typing import Any, Callable, Iterable, NamedTuple, TypeVar
import torch.distributed.rpc as rpc
from torch.distributed.elastic.multiprocessing.errors import record
@@ -127,32 +129,33 @@ def get_local_world_size() -> int:
# pylint: disable-next=redefined-builtin,invalid-name
-def get_worker_id(id: Optional[Union[str, int]] = None) -> int:
+def get_worker_id(id: str | int | None = None) -> int:
"""Get the worker id from the given id."""
if isinstance(id, int):
return id
return rpc.get_worker_info(worker_name=id).id
-def barrier(worker_names: Optional[Iterable[str]] = None) -> None:
+def barrier(worker_names: Iterable[str] | None = None) -> None:
r"""Synchronize local and remote RPC processes.
This will block until all local and remote RPC processes specified under worker_names
reach this method to wait for all outstanding work to complete.
Args:
- worker_names: The set of workers to synchronize. If :data:`None`, all workers.
+ worker_names (iterable of str or None, optional): The set of workers to synchronize.
+ If :data:`None`, all workers. (default: :data:`None`)
"""
worker_names = {} if worker_names is None else set(worker_names)
rpc.api._barrier(worker_names) # pylint: disable=protected-access
def auto_init_rpc(
- worker_init_fn: Optional[Callable[[], None]] = None,
+ worker_init_fn: Callable[[], None] | None = None,
worker_name_format: Callable[..., str] = default_worker_name_format,
*,
- backend: Optional['rpc.BackendType'] = None,
- rpc_backend_options: Optional['rpc.RpcBackendOptions'] = None,
+ backend: rpc.BackendType | None = None,
+ rpc_backend_options: rpc.RpcBackendOptions | None = None,
) -> Callable[[F], F]:
"""Return a decorator to automatically initialize RPC on the decorated function."""
global _WORKER_NAME_FORMAT # pylint: disable=global-statement
diff --git a/torchopt/hook.py b/torchopt/hook.py
index 949c76e7..f188415c 100644
--- a/torchopt/hook.py
+++ b/torchopt/hook.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Hook utilities."""
-from typing import Callable, Optional, Tuple
+from __future__ import annotations
+
+from typing import Callable
import torch
@@ -32,7 +34,7 @@ def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
def nan_to_num_hook(
- nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+ nan: float = 0.0, posinf: float | None = None, neginf: float | None = None
) -> Callable[[torch.Tensor], torch.Tensor]:
"""Return a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers."""
@@ -59,9 +61,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True, # pylint: disable=unused-argument
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
def f(g):
return g.register_hook(hook)
diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py
index 94daee53..5456f076 100644
--- a/torchopt/linalg/cg.py
+++ b/torchopt/linalg/cg.py
@@ -33,8 +33,10 @@
# pylint: disable=invalid-name
+from __future__ import annotations
+
from functools import partial
-from typing import Callable, Optional, Union
+from typing import Callable
import torch
@@ -100,14 +102,14 @@ def body_fn(value):
def _isolve(
_isolve_solve: Callable,
- A: Union[TensorTree, Callable[[TensorTree], TensorTree]],
+ A: TensorTree | Callable[[TensorTree], TensorTree],
b: TensorTree,
- x0: Optional[TensorTree] = None,
+ x0: TensorTree | None = None,
*,
rtol: float = 1e-5,
atol: float = 0.0,
- maxiter: Optional[int] = None,
- M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None,
+ maxiter: int | None = None,
+ M: TensorTree | Callable[[TensorTree], TensorTree] | None = None,
) -> TensorTree:
if x0 is None:
x0 = pytree.tree_map(torch.zeros_like, b)
@@ -133,14 +135,14 @@ def _isolve(
def cg(
- A: Union[TensorTree, Callable[[TensorTree], TensorTree]],
+ A: TensorTree | Callable[[TensorTree], TensorTree],
b: TensorTree,
- x0: Optional[TensorTree] = None,
+ x0: TensorTree | None = None,
*,
rtol: float = 1e-5,
atol: float = 0.0,
- maxiter: Optional[int] = None,
- M: Optional[Union[TensorTree, Callable[[TensorTree], TensorTree]]] = None,
+ maxiter: int | None = None,
+ M: TensorTree | Callable[[TensorTree], TensorTree] | None = None,
) -> TensorTree:
"""Use Conjugate Gradient iteration to solve ``Ax = b``.
@@ -153,30 +155,30 @@ def cg(
solves converge.
Args:
- A: (tensor or tree of tensors or function)
- 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
- called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and
- must return array(s) with the same structure and shape as its argument.
- b: (tensor or tree of tensors)
- Right hand side of the linear system representing a single vector. Can be stored as an
- array or Python container of array(s) with any shape.
- x0: (tensor or tree of tensors, optional)
- Starting guess for the solution. Must have the same structure as ``b``.
- rtol: (float, optional, default: :const:`1e-5`)
- Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``. We do not
- implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy
- unless you explicitly pass ``atol`` to SciPy's ``cg``.
- atol: (float, optional, default: :const:`0.0`)
- Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``. We do not
- implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from SciPy
- unless you explicitly pass ``atol`` to SciPy's ``cg``.
- maxiter: (integer, optional)
- Maximum number of iterations. Iteration will stop after maxiter steps even if the
- specified tolerance has not been achieved.
- M: (tensor or tree of tensors or function)
- Pre-conditioner for ``A``. The pre-conditioner should approximate the inverse of ``A``.
- Effective preconditioning dramatically improves the rate of convergence, which implies
- that fewer iterations are needed to reach a given error tolerance.
+ A (Tensor or tree of Tensor): 2D array or function that calculates the linear map
+ (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
+ hermitian, positive definite matrix, and must return tensor(s) with the same structure
+ and shape as its argument.
+ b (Tensor or tree of Tensor): Right hand side of the linear system representing a single
+ vector. Can be stored as a tensor or Python container of tensor(s) with any shape.
+ x0 (Tensor, tree of Tensor, or None, optional): Starting guess for the solution. Must have
+ the same structure as ``b``. If :data:`None`, use zero initialization.
+ (default: :data:`None`)
+ rtol (float, optional): Tolerances for convergence, ``norm(residual) <= max(rtol*norm(b), atol)``.
+ We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from
+ SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`1e-5`)
+ atol (float, optional): Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
+ We do not implement SciPy's "legacy" behavior, so TorchOpt's tolerance will differ from
+ SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``. (default: :const:`0.0`)
+ maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after
+ maxiter steps even if the specified tolerance has not been achieved. If :data:`None`,
+ ``10 * size`` will be used, where ``size`` is the size of the flattened input tensor(s).
+ (default: :data:`None`)
+ M (Tensor, tree of Tensor, function, or None, optional): Pre-conditioner for ``A``. The
+ pre-conditioner should approximate the inverse of ``A``. Effective preconditioning
+ dramatically improves the rate of convergence, which implies that fewer iterations are
+ needed to reach a given error tolerance. If :data:`None`, no pre-conditioner will be
+ used. (default: :data:`None`)
Returns:
the Conjugate Gradient (CG) linear solver
diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py
index 04f5dd11..c1975203 100644
--- a/torchopt/linalg/ns.py
+++ b/torchopt/linalg/ns.py
@@ -16,13 +16,15 @@
# pylint: disable=invalid-name
+from __future__ import annotations
+
import functools
-from typing import Callable, Optional, Union
+from typing import Callable
import torch
from torchopt import pytree
-from torchopt.linalg.utils import cat_shapes, normalize_matvec
+from torchopt.linalg.utils import normalize_matvec
from torchopt.typing import TensorTree
@@ -33,7 +35,7 @@ def _ns_solve(
A: torch.Tensor,
b: torch.Tensor,
maxiter: int,
- alpha: Optional[float] = None,
+ alpha: float | None = None,
) -> torch.Tensor:
"""Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``."""
if A.ndim != 2 or A.shape[0] != A.shape[1]:
@@ -57,27 +59,26 @@ def _ns_solve(
def ns(
- A: Union[TensorTree, Callable[[TensorTree], TensorTree]],
+ A: TensorTree | Callable[[TensorTree], TensorTree],
b: TensorTree,
- maxiter: Optional[int] = None,
+ maxiter: int | None = None,
*,
- alpha: Optional[float] = None,
+ alpha: float | None = None,
) -> TensorTree:
"""Use Neumann Series Matrix Inversion Approximation to solve ``Ax = b``.
Args:
- A: (tensor or tree of tensors or function)
- 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
- called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and
- must return array(s) with the same structure and shape as its argument.
- b: (tensor or tree of tensors)
- Right hand side of the linear system representing a single vector. Can be stored as an
- array or Python container of array(s) with any shape.
- maxiter: (integer, optional)
- Maximum number of iterations. Iteration will stop after maxiter steps even if the
- specified tolerance has not been achieved.
- alpha: (float, optional)
- Decay coefficient.
+ A (Tensor or tree of Tensor): 2D array or function that calculates the linear map
+ (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
+ hermitian, positive definite matrix, and must return tensor(s) with the same structure
+ and shape as its argument.
+ b (Tensor or tree of Tensor): Right hand side of the linear system representing a single
+ vector. Can be stored as a tensor or Python container of tensor(s) with any shape.
+ maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after
+ maxiter steps even if the specified tolerance has not been achieved. If :data:`None`,
+ :const:`10` will be used. (default: :const:`10`)
+ alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be
+ used. (default: :const:`1.0`)
Returns:
The Neumann Series (NS) matrix inversion approximation.
@@ -111,7 +112,7 @@ def ns(
return inv_A_hat_b
-def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None):
+def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None):
"""Use Neumann Series iteration to solve ``A^{-1}``."""
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}')
@@ -134,28 +135,27 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: Optional[float] = None):
def ns_inv(
A: TensorTree,
- maxiter: Optional[int] = None,
+ maxiter: int | None = None,
*,
- alpha: Optional[float] = None,
+ alpha: float | None = None,
) -> TensorTree:
"""Use Neumann Series iteration to solve ``A^{-1}``.
Args:
- A: (tensor or tree of tensors or function)
- 2D array or function that calculates the linear map (matrix-vector product) ``Ax`` when
- called like ``A(x)``. ``A`` must represent a hermitian, positive definite matrix, and
- must return array(s) with the same structure and shape as its argument.
- maxiter: (integer, optional)
- Maximum number of iterations. Iteration will stop after maxiter steps even if the
- specified tolerance has not been achieved.
- alpha: (float, optional)
- Decay coefficient.
+ A (Tensor or tree of Tensor): 2D array or function that calculates the linear map
+ (matrix-vector product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
+ hermitian, positive definite matrix, and must return tensor(s) with the same structure
+ and shape as its argument.
+ maxiter (int or None, optional): Maximum number of iterations. Iteration will stop after
+ maxiter steps even if the specified tolerance has not been achieved. If :data:`None`,
+ :const:`10` will be used. (default: :const:`10`)
+ alpha: (float or None, optional): Decay coefficient. If :data:`None`, :const:`1.0` will be
+ used. (default: :const:`1.0`)
Returns:
The Neumann Series (NS) matrix inversion approximation.
"""
if maxiter is None:
- size = sum(cat_shapes(A))
- maxiter = 10 * size # copied from SciPy
+ maxiter = 10
return pytree.tree_map(functools.partial(_ns_inv, maxiter=maxiter, alpha=alpha), A)
diff --git a/torchopt/linalg/utils.py b/torchopt/linalg/utils.py
index 275232be..f301a624 100644
--- a/torchopt/linalg/utils.py
+++ b/torchopt/linalg/utils.py
@@ -14,8 +14,10 @@
# ==============================================================================
"""Utilities for linear algebra."""
+from __future__ import annotations
+
import itertools
-from typing import Callable, Tuple, Union
+from typing import Callable
import torch
@@ -23,14 +25,14 @@
from torchopt.typing import TensorTree
-def cat_shapes(tree: TensorTree) -> Tuple[int, ...]:
+def cat_shapes(tree: TensorTree) -> tuple[int, ...]:
"""Concatenate the shapes of the leaves of a tree of tensors."""
leaves = pytree.tree_leaves(tree)
return tuple(itertools.chain.from_iterable(tuple(leaf.shape) for leaf in leaves))
def normalize_matvec(
- matvec: Union[TensorTree, Callable[[TensorTree], TensorTree]]
+ matvec: TensorTree | Callable[[TensorTree], TensorTree]
) -> Callable[[TensorTree], TensorTree]:
"""Normalize an argument for computing matrix-vector product."""
if callable(matvec):
diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py
index f75ef9f4..844c9407 100644
--- a/torchopt/linear_solve/cg.py
+++ b/torchopt/linear_solve/cg.py
@@ -33,8 +33,10 @@
# pylint: disable=invalid-name
+from __future__ import annotations
+
import functools
-from typing import Callable, Optional
+from typing import Callable
from torchopt import linalg
from torchopt.linear_solve.utils import make_ridge_matvec
@@ -47,8 +49,8 @@
def _solve_cg(
matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x
b: TensorTree,
- ridge: Optional[float] = None,
- init: Optional[TensorTree] = None,
+ ridge: float | None = None,
+ init: TensorTree | None = None,
**kwargs,
) -> TensorTree:
"""Solve ``A x = b`` using conjugate gradient.
@@ -56,10 +58,12 @@ def _solve_cg(
This assumes that ``A`` is a hermitian, positive definite matrix.
Args:
- matvec: A function that returns the product between ``A`` and a vector.
- b: A tree of tensors for the right hand side of the equation.
- ridge: Optional ridge regularization.
- init: Optional initialization to be used by conjugate gradient.
+ matvec (callable): A function that returns the product between ``A`` and a vector.
+ b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation.
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the
+ equation for ``A x + ridge x = b``. (default: :data:`None`)
+ init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by
+ conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`)
**kwargs: Additional keyword arguments for the conjugate gradient solver.
Returns:
@@ -80,8 +84,10 @@ def solve_cg(**kwargs):
This assumes that ``A`` is a hermitian, positive definite matrix.
Args:
- ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``.
- init: Optional initialization to be used by conjugate gradient.
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the
+ equation for ``A x + ridge x = b``. (default: :data:`None`)
+ init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by
+ conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`)
**kwargs: Additional keyword arguments for the conjugate gradient solver
:func:`torchopt.linalg.cg`.
diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py
index c3224a52..399a0ef9 100644
--- a/torchopt/linear_solve/inv.py
+++ b/torchopt/linear_solve/inv.py
@@ -33,8 +33,10 @@
# pylint: disable=invalid-name
+from __future__ import annotations
+
import functools
-from typing import Callable, Optional
+from typing import Callable
import torch
@@ -49,7 +51,7 @@
def _solve_inv(
matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x
b: TensorTree,
- ridge: Optional[float] = None,
+ ridge: float | None = None,
ns: bool = False,
**kwargs,
) -> TensorTree:
@@ -59,11 +61,13 @@ def _solve_inv(
in memory.
Args:
- matvec: A function that returns the product between ``A`` and a vector.
- b: A tensor for the right hand side of the equation.
- ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``.
- ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`,
- materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead.
+ matvec (callable): A function that returns the product between ``A`` and a vector.
+ b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation.
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the
+ equation for ``A x + ridge x = b``. (default: :data:`None`)
+ ns (bool, optional): Whether to use Neumann Series matrix inversion approximation.
+ If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve`
+ instead. (default: :data:`False`)
**kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation
solver :func:`torchopt.linalg.ns`.
@@ -94,9 +98,11 @@ def solve_inv(**kwargs):
in memory.
Args:
- ridge: Optional ridge regularization. Solves the equation for ``(A + ridge * I) @ x = b``.
- ns: Whether to use Neumann Series matrix inversion approximation. If :data:`False`,
- materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve` instead.
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the
+ equation for ``A x + ridge x = b``. (default: :data:`None`)
+ ns (bool, optional): Whether to use Neumann Series matrix inversion approximation.
+ If :data:`False`, materialize the matrix ``A`` in memory and use :func:`torch.linalg.solve`
+ instead. (default: :data:`False`)
**kwargs: Additional keyword arguments for the Neumann Series matrix inversion approximation
solver :func:`torchopt.linalg.ns`.
diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py
index 3199a490..8d38f77a 100644
--- a/torchopt/linear_solve/normal_cg.py
+++ b/torchopt/linear_solve/normal_cg.py
@@ -33,8 +33,10 @@
# pylint: disable=invalid-name
+from __future__ import annotations
+
import functools
-from typing import Callable, Optional
+from typing import Callable
from torchopt import linalg
from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec
@@ -47,8 +49,8 @@
def _solve_normal_cg(
matvec: Callable[[TensorTree], TensorTree], # (x) -> A @ x
b: TensorTree,
- ridge: Optional[float] = None,
- init: Optional[TensorTree] = None,
+ ridge: float | None = None,
+ init: TensorTree | None = None,
**kwargs,
) -> TensorTree:
"""Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient.
@@ -57,10 +59,12 @@ def _solve_normal_cg(
positive definite.
Args:
- matvec: A function that returns the product between ``A`` and a vector.
- b: A tree of tensors for the right hand side of the equation.
- ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``.
- init: Optional initialization to be used by normal conjugate gradient.
+ matvec (callable): A function that returns the product between ``A`` and a vector.
+ b (Tensor or tree of Tensor): A tree of tensors for the right hand side of the equation.
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the
+ equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`)
+ init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by
+ conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`)
**kwargs: Additional keyword arguments for the conjugate gradient solver
:func:`torchopt.linalg.cg`.
@@ -93,8 +97,10 @@ def solve_normal_cg(**kwargs):
positive definite.
Args:
- ridge: Optional ridge regularization. Solves the equation for ``(A.T @ A + ridge * I) @ x = A.T @ b``.
- init: Optional initialization to be used by normal conjugate gradient.
+ ridge (float or None, optional): Optional ridge regularization. If provided, solves the
+ equation for ``A^T A x + ridge x = A^T b``. (default: :data:`None`)
+ init (Tensor, tree of Tensor, or None, optional): Optional initialization to be used by
+ conjugate gradient. If :data:`None`, uses zero initialization. (default: :data:`None`)
**kwargs: Additional keyword arguments for the conjugate gradient solver
:func:`torchopt.linalg.cg`.
diff --git a/torchopt/linear_solve/utils.py b/torchopt/linear_solve/utils.py
index 9c2f7ced..f4f34e2a 100644
--- a/torchopt/linear_solve/utils.py
+++ b/torchopt/linear_solve/utils.py
@@ -31,7 +31,9 @@
# ==============================================================================
"""Utilities for linear algebra solvers."""
-from typing import Callable, Tuple
+from __future__ import annotations
+
+from typing import Callable
import functorch
@@ -75,7 +77,7 @@ def ridge_matvec(y: TensorTree) -> TensorTree:
def materialize_matvec(
matvec: Callable[[TensorTree], TensorTree], x: TensorTree
-) -> Tuple[
+) -> tuple[
TensorTree,
Callable[[TensorTree], TensorTree],
Callable[[TensorTree], TensorTree],
diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py
index 3716f674..f8804864 100644
--- a/torchopt/nn/module.py
+++ b/torchopt/nn/module.py
@@ -14,8 +14,10 @@
# ==============================================================================
"""Base class for neural network modules that hold meta-parameters and meta-modules."""
+from __future__ import annotations
+
from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, Iterator, NamedTuple
import torch
import torch.nn as nn
@@ -27,8 +29,8 @@
class MetaInputsContainer(NamedTuple):
"""Container for parameters and modules in the constructor input arguments."""
- meta_parameters: Set[torch.Tensor]
- meta_modules: Set[nn.Module]
+ meta_parameters: set[torch.Tensor]
+ meta_modules: set[nn.Module]
class MetaGradientModule(nn.Module): # pylint: disable=abstract-method
@@ -36,12 +38,12 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method
_meta_inputs: MetaInputsContainer
_meta_parameters: TensorContainer
- _meta_modules: Dict[str, Optional[nn.Module]]
+ _meta_modules: dict[str, nn.Module | None]
- def __new__(cls, *args, **kwargs) -> 'MetaGradientModule':
+ def __new__(cls, *args, **kwargs) -> MetaGradientModule:
"""Create a new module instance."""
instance = super().__new__(cls)
- flat_args: List[Any]
+ flat_args: list[Any]
flat_args = pytree.tree_leaves((args, kwargs)) # type: ignore[arg-type]
meta_parameters = {x for x in flat_args if isinstance(x, torch.Tensor) and x.requires_grad}
meta_modules = {x for x in flat_args if isinstance(x, nn.Module) and x.training}
@@ -51,14 +53,14 @@ def __new__(cls, *args, **kwargs) -> 'MetaGradientModule':
instance._meta_inputs = MetaInputsContainer(meta_parameters, meta_modules)
instance._meta_parameters: TensorContainer = OrderedDict() # type: ignore[misc]
- instance._meta_modules: Dict[str, Optional[nn.Module]] = OrderedDict() # type: ignore[misc]
+ instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc]
return instance
def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
"""Initialize a new module instance."""
super().__init__()
- def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
+ def __getattr__(self, name: str) -> torch.Tensor | nn.Module:
"""Get an attribute of the module."""
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
@@ -83,7 +85,7 @@ def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# pylint: disable-next=too-many-branches,too-many-statements
- def __setattr__(self, name: str, value: Union[torch.Tensor, nn.Module]) -> None:
+ def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None:
"""Set an attribute of the module."""
def remove_from(*dicts_or_sets):
@@ -186,18 +188,17 @@ def __delattr__(self, name: str) -> None:
else:
object.__delattr__(self, name)
- def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None:
+ def register_parameter(self, name: str, param: torch.Tensor | None) -> None:
r"""Add a parameter to the module.
The parameter can be accessed as an attribute using given name.
Args:
- name (string): name of the parameter. The parameter can be accessed
- from this module using the given name
- param (torch.Tensor or None): parameter to be added to the module. If
- ``None``, then operations that run on parameters, such as :attr:`cuda`,
- are ignored. If ``None``, the parameter is **not** included in the
- module's :attr:`state_dict`.
+ name (str): The name of the parameter. The parameter can be accessed from this module
+ using the given name.
+ param (Tensor or None): The parameter to be added to the module. If :data:`None`, then
+ operations that run on parameters, such as ``cuda``, are ignored. If :data:`None`,
+ the parameter is **not** included in the module's ``state_dict``.
"""
if '_parameters' not in self.__dict__:
raise AttributeError('cannot assign parameter before Module.__init__() call')
@@ -231,18 +232,17 @@ def register_parameter(self, name: str, param: Optional[torch.Tensor]) -> None:
self._parameters[name] = param # type: ignore
- def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> None:
+ def register_meta_parameter(self, name: str, param: torch.Tensor | None) -> None:
r"""Add a meta-parameter to the module.
The meta-parameter can be accessed as an attribute using given name.
Args:
- name (string): name of the parameter. The parameter can be accessed
- from this module using the given name
- param (torch.Tensor or None): parameter to be added to the module. If
- ``None``, then operations that run on parameters, such as :attr:`cuda`,
- are ignored. If ``None``, the parameter is **not** included in the
- module's :attr:`state_dict`.
+ name (str): The name of the meta-parameter. The meta-parameter can be accessed from this
+ module using the given name.
+ param (Tensor or None): The meta-parameter to be added to the module. If :data:`None`,
+ then operations that run on meta-parameters, such as ``cuda``, are ignored. If
+ :data:`None`, the meta-parameter is **not** included in the module's ``state_dict``.
"""
if '_meta_parameters' not in self.__dict__:
raise AttributeError(
@@ -273,15 +273,15 @@ def register_meta_parameter(self, name: str, param: Optional[torch.Tensor]) -> N
self._meta_parameters[name] = param
- def add_module(self, name: str, module: Optional[nn.Module]) -> None:
+ def add_module(self, name: str, module: nn.Module | None) -> None:
r"""Add a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
- name (string): name of the child module. The child module can be
- accessed from this module using the given name
- module (Module): child module to be added to the module.
+ name (str): The name of the child module. The child module can be accessed from this
+ module using the given name
+ module (nn.Module or None): The child module to be added to the module.
"""
if not isinstance(module, nn.Module) and module is not None:
raise TypeError(f'{torch.typename(module)} is not a Module subclass')
@@ -301,19 +301,19 @@ def add_module(self, name: str, module: Optional[nn.Module]) -> None:
self._modules[name] = module
- def register_module(self, name: str, module: Optional[nn.Module]) -> None:
+ def register_module(self, name: str, module: nn.Module | None) -> None:
r"""Alias for :func:`add_module`."""
self.add_module(name, module)
- def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None:
+ def add_meta_module(self, name: str, meta_module: nn.Module | None) -> None:
r"""Add a child meta-module to the current module.
The meta-module can be accessed as an attribute using the given name.
Args:
- name (string): name of the child meta-module. The child meta-module can be
- accessed from this module using the given name
- meta_module (Module): child meta-module to be added to the module.
+ name (str): The name of the child meta-module. The child meta-module can be accessed
+ from this module using the given name
+ meta_module (nn.Module or None): The child meta-module to be added to the module.
"""
if not isinstance(meta_module, nn.Module) and meta_module is not None:
raise TypeError(f'{torch.typename(meta_module)} is not a Module subclass')
@@ -328,7 +328,7 @@ def add_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None:
self._meta_modules[name] = meta_module
- def register_meta_module(self, name: str, meta_module: Optional[nn.Module]) -> None:
+ def register_meta_module(self, name: str, meta_module: nn.Module | None) -> None:
r"""Alias for :func:`add_meta_module`."""
self.add_meta_module(name, meta_module)
@@ -338,9 +338,9 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]:
This is typically passed to an optimizer.
Args:
- recurse (bool): if True, then yields parameters of this module and
- all submodules. Otherwise, yields only meta-parameters that
- are direct members of this module.
+ recurse (bool, optional): If :data:`True`, then yields parameters of this module and
+ all submodules. Otherwise, yields only meta-parameters that are direct members of
+ this module. (default: :data:`True`)
Yields:
Parameter: module meta-parameter
@@ -358,14 +358,15 @@ def meta_parameters(self, recurse: bool = True) -> Iterator[torch.Tensor]:
def named_meta_parameters(
self, prefix: str = '', recurse: bool = True
- ) -> Iterator[Tuple[str, torch.Tensor]]:
+ ) -> Iterator[tuple[str, torch.Tensor]]:
r"""Return an iterator over module meta-parameters, yielding both the name of the meta-parameter as well as the meta-parameter itself.
Args:
- prefix (str): prefix to prepend to all meta-parameter names.
- recurse (bool): if True, then yields meta-parameters of this module
- and all submodules. Otherwise, yields only meta-parameters that
- are direct members of this module.
+ prefix (str, optional): The prefix to prepend to all meta-parameter names.
+ (default: :const:`''`)
+ recurse (bool, optional): if :data:`True`, then yields meta-parameters of this module
+ and all submodules. Otherwise, yields only meta-parameters that are direct members
+ of this module. (default: :data:`True`)
Yields:
(string, Parameter): Tuple containing the name and parameter
@@ -398,7 +399,7 @@ def meta_children(self) -> Iterator[nn.Module]:
for _, module in self.named_meta_children():
yield module
- def named_meta_children(self) -> Iterator[Tuple[str, nn.Module]]:
+ def named_meta_children(self) -> Iterator[tuple[str, nn.Module]]:
r"""Return an iterator over immediate children meta-modules, yielding both the name of the meta-module as well as the meta-module itself.
Yields:
@@ -430,15 +431,18 @@ def meta_modules(self) -> Iterator[nn.Module]:
yield meta_module
def named_meta_modules(
- self, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True
- ) -> Iterator[Tuple[str, nn.Module]]:
+ self, memo: set[nn.Module] | None = None, prefix: str = '', remove_duplicate: bool = True
+ ) -> Iterator[tuple[str, nn.Module]]:
r"""Return an iterator over all meta-modules in the network, yielding both the name of the meta-module as well as the meta-module itself.
Args:
- memo: a memo to store the set of meta-modules already added to the result
- prefix: a prefix that will be added to the name of the meta-module
- remove_duplicate: whether to remove the duplicated meta-module instances in the result
- or not
+ memo (set of nn.Module or None, optional): A memory to store the set of meta-modules
+ already added to the result. If not provided, a new set will be created.
+ (default: :const:`None`)
+ prefix (str, optional): A prefix that will be added to the name of the meta-module.
+ (default: :const:`''`)
+ remove_duplicate (bool, optional): whether to remove the duplicated meta-module
+ instances in the result or not. (default: :const:`True`)
Yields:
(string, Module): Tuple of name and meta-module
diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py
index 2fc0dbb4..9391352f 100644
--- a/torchopt/nn/stateless.py
+++ b/torchopt/nn/stateless.py
@@ -14,8 +14,10 @@
# ==============================================================================
"""Utility functions for stateless module calls."""
+from __future__ import annotations
+
import contextlib
-from typing import Dict, Generator, Iterable, Tuple, Union
+from typing import Generator, Iterable
import torch
import torch.nn as nn
@@ -29,9 +31,9 @@
def swap_state(
module: nn.Module,
- named_tensors: Union[Dict[str, torch.Tensor], Iterable[Tuple[str, torch.Tensor]]],
+ named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
allow_missing: bool = False,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
"""Swap the module parameters and/or buffers."""
if not isinstance(named_tensors, dict):
named_tensors = dict(named_tensors)
@@ -84,7 +86,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor:
@contextlib.contextmanager
def reparametrize(
module: nn.Module,
- named_tensors: Union[Dict[str, torch.Tensor], Iterable[Tuple[str, torch.Tensor]]],
+ named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
allow_missing: bool = False,
) -> Generator[nn.Module, None, None]:
"""Reparameterize the module parameters and/or buffers."""
diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py
index c56956f8..640eea1d 100644
--- a/torchopt/optim/adam.py
+++ b/torchopt/optim/adam.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Adam optimizer."""
-from typing import Iterable, Tuple
+from __future__ import annotations
+
+from typing import Iterable
import torch
@@ -39,7 +41,7 @@ def __init__(
self,
params: Iterable[torch.Tensor],
lr: ScalarOrSchedule,
- betas: Tuple[float, float] = (0.9, 0.999),
+ betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
*,
@@ -50,25 +52,27 @@ def __init__(
r"""Initialize the Adam optimizer.
Args:
- params: (iterable of torch.Tensor)
- An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized.
- lr: (default: :const:`1e-3`)
- This is a fixed global scaling factor.
- betas: (default: :const:`(0.9, 0.999)`)
- Coefficients used for computing running averages of gradient and its square.
- eps: (default: :const:`1e-8`)
- A small constant applied to denominator outside of the square root (as in the Adam
- paper) to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- eps_root: (default: :data:`0.0`)
- A small constant applied to denominator inside the square root (as in RMSProp), to
- avoid dividing by zero when rescaling. This is needed for example when computing
- (meta-)gradients through Adam.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
- use_accelerated_op: (default: :data:`False`)
- If :data:`True` use our implemented fused operator.
+ params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
+ tensors should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-3`)
+ betas (tuple of float, optional): Coefficients used for computing running averages of
+ gradient and its square. (default: :const:`(0.9, 0.999)`)
+ eps (float, optional): A small constant applied to denominator outside of the square
+ root (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ eps_root (float, optional): A small constant applied to denominator inside the square
+ root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for
+ example when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created
+ with flag ``requires_grad=True``, this flag is often used in Meta-Learning
+ algorithms. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator.
+ (default: :data:`False`)
"""
super().__init__(
params,
diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py
index 19c70678..7db5e750 100644
--- a/torchopt/optim/adamw.py
+++ b/torchopt/optim/adamw.py
@@ -14,13 +14,15 @@
# ==============================================================================
"""AdamW optimizer."""
-from typing import Any, Callable, Iterable, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import Callable, Iterable
import torch
from torchopt import alias
from torchopt.optim.base import Optimizer
-from torchopt.typing import Params, ScalarOrSchedule
+from torchopt.typing import OptState, Params, ScalarOrSchedule
__all__ = ['AdamW']
@@ -39,46 +41,48 @@ def __init__(
self,
params: Iterable[torch.Tensor],
lr: ScalarOrSchedule = 1e-3,
- betas: Tuple[float, float] = (0.9, 0.999),
+ betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
*,
eps_root: float = 0.0,
- mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
+ mask: OptState | Callable[[Params], OptState] | None = None,
maximize: bool = False,
use_accelerated_op: bool = False,
) -> None:
r"""Initialize the AdamW optimizer.
Args:
- params: (iterable of torch.Tensor)
- An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized.
- lr: (default: :const:`1e-3`)
- This is a fixed global scaling factor.
- betas: (default: :const:`(0.9, 0.999)`)
- Coefficients used for computing running averages of gradient and its square.
- eps: (default: :const:`1e-8`)
- A small constant applied to denominator outside of the square root (as in the Adam
- paper) to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`1e-2`)
- Strength of the weight decay regularization. Note that this weight decay is
- multiplied with the learning rate. This is consistent with other frameworks such as
- PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only
- multiplied with the "schedule multiplier", but not the base learning rate.
- eps_root: (default: :data:`0.0`)
- A small constant applied to denominator inside the square root (as in RMSProp), to
- avoid dividing by zero when rescaling. This is needed for example when computing
- (meta-)gradients through Adam.
- mask: (default: :data:`None`)
- A tree with same structure as (or a prefix of) the params PyTree, or a Callable that
+ params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
+ tensors should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-3`)
+ betas (tuple of float, optional): Coefficients used for computing running averages of
+ gradient and its square. (default: :const:`(0.9, 0.999)`)
+ eps (float, optional): A small constant applied to denominator outside of the square
+ root (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Strength of the weight decay regularization. Note that
+ this weight decay is multiplied with the learning rate. This is consistent with
+ other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where
+ the weight decay is only multiplied with the "schedule multiplier", but not the base
+ learning rate. (default: :const:`1e-2`)
+ eps_root (float, optional): A small constant applied to denominator inside the square
+ root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for
+ example when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+ mask (tree of Tensor, callable, or None, optional):
+ A tree with same structure as (or a prefix of) the params pytree, or a function that
returns such a pytree given the params/updates. The leaves should be booleans,
:data:`True` for leaves/subtrees you want to apply the weight decay to, and
:data:`False` for those you want to skip. Note that the Adam gradient
- transformations are applied to all parameters.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
- use_accelerated_op: (default: :data:`False`)
- If :data:`True` use our implemented fused operator.
+ transformations are applied to all parameters. (default: :data:`None`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created
+ with flag ``requires_grad=True``, this flag is often used in Meta-Learning
+ algorithms. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator.
+ (default: :data:`False`)
"""
super().__init__(
params,
diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py
index e894b93b..aac3a782 100644
--- a/torchopt/optim/base.py
+++ b/torchopt/optim/base.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""The base class for optimizers."""
-from typing import Callable, Iterable, List, Optional, Sequence, Tuple
+from __future__ import annotations
+
+from typing import Callable, Iterable, Sequence
import torch
@@ -37,8 +39,8 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation)
params (iterable of torch.Tensor): An iterable of :class:`torch.Tensor`\s. Specifies
what tensors should be optimized.
impl (GradientTransformation): A low level optimizer function, it could be a optimizer
- function provided by ``alias.py`` or a customized ``chain`` provided by
- ``combine.py``.
+ function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed
+ transformation.
Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to
:class:`torchopt.SGD`.
"""
@@ -46,9 +48,9 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation)
raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')
self.impl: GradientTransformation = impl
- self.param_groups: List[TupleOfTensors] = []
- self.param_treespecs: List[pytree.PyTreeSpec] = []
- self.state_groups: List[OptState] = []
+ self.param_groups: list[TupleOfTensors] = []
+ self.param_treespecs: list[pytree.PyTreeSpec] = []
+ self.state_groups: list[OptState] = []
if not isinstance(params, (list, tuple)):
params = tuple(params)
@@ -60,7 +62,8 @@ def zero_grad(self, set_to_none: bool = False) -> None:
The behavior is similar to :meth:`torch.optim.Optimizer.zero_grad`.
Args:
- set_to_none (bool): Instead of setting to zero, set the ``grads`` to :data:`None`.
+ set_to_none (bool, optional): Instead of setting to zero, set the ``grads`` to
+ :data:`None`. (default: :data:`False`)
"""
if set_to_none:
@@ -80,7 +83,7 @@ def f(p):
pytree.tree_map_(f, self.param_groups) # type: ignore[arg-type]
- def state_dict(self) -> Tuple[OptState, ...]:
+ def state_dict(self) -> tuple[OptState, ...]:
"""Return the state of the optimizer."""
return tuple(self.state_groups)
@@ -88,18 +91,19 @@ def load_state_dict(self, state_dict: Sequence[OptState]) -> None:
"""Load the optimizer state.
Args:
- state_dict: Optimizer state. Should be an object returned from a call to
- :meth:`state_dict`.
+ state_dict (sequence of tree of Tensor): Optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
"""
self.state_groups[:] = list(state_dict)
- def step(self, closure: Optional[Callable[[], torch.Tensor]] = None) -> Optional[torch.Tensor]:
+ def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tensor | None:
"""Perform a single optimization step.
The behavior is similar to :meth:`torch.optim.Optimizer.step`.
Args:
- closure (callable, optional): A closure that reevaluates the model and returns the loss.
+ closure (callable or None, optional): A closure that reevaluates the model and returns
+ the loss. Optional for most optimizers. (default: :data:`None`)
"""
loss = None
if closure is not None:
@@ -120,7 +124,7 @@ def f(p):
return loss
def add_param_group(self, params: Params) -> None:
- """Add a param group to the optimizer's :attr:`param_groups`."""
+ """Add a param group to the optimizer's ``param_groups``."""
flat_params: TupleOfTensors
flat_params, params_treespec = pytree.tree_flatten_as_tuple(params)
self.param_groups.append(flat_params)
diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index 7e51a21b..9dce3412 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Functional optimizer wrappers."""
-from typing import Optional
+from __future__ import annotations
import torch
@@ -41,26 +41,27 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods
"""
def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None:
- """Initialize the functional optimizer wrapper.
+ r"""Initialize the functional optimizer wrapper.
Args:
impl (GradientTransformation): A low level optimizer function, it could be a optimizer
- function provided by `alias.py` or a customized `chain` provided by `combine.py`.
- inplace (optional): (default: :data:`False`)
- The default value of ``inplace`` for each optimization update.
+ function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed
+ transformation.
+ inplace (bool, optional): The default value of ``inplace`` for each optimization update.
+ (default: :data:`False`)
"""
if not isinstance(impl, GradientTransformation):
raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')
self.impl: GradientTransformation = impl
- self.optim_state: Optional[OptState] = UninitializedState()
+ self.optim_state: OptState | None = UninitializedState()
self.inplace: bool = bool(inplace)
def step(
self,
loss: torch.Tensor,
params: Params,
- inplace: Optional[bool] = None,
+ inplace: bool | None = None,
) -> Params:
r"""Compute the gradients of loss to the network parameters and update network parameters.
@@ -69,13 +70,12 @@ def step(
gradients and update the network parameters without modifying tensors in-place.
Args:
- loss: (torch.Tensor)
- loss that is used to compute the gradients to network parameters.
- params: (tree of torch.Tensor)
- An tree of :class:`torch.Tensor`\s. Specifies what tensors should be optimized.
- inplace (optional): (default: :data:`None`)
- Whether to update the parameters in-place. If :data:`None`, use the default value
- specified in the constructor.
+ loss (Tensor): The loss that is used to compute the gradients to network parameters.
+ params (tree of Tensor): An tree of :class:`torch.Tensor`\s. Specifies what tensors
+ should be optimized.
+ inplace (bool or None, optional): Whether to update the parameters in-place. If
+ :data:`None`, use the default value specified in the constructor.
+ (default: :data:`None`)
"""
if isinstance(self.optim_state, UninitializedState):
self.optim_state = self.impl.init(params)
diff --git a/torchopt/optim/meta/adam.py b/torchopt/optim/meta/adam.py
index 36d54857..bd9804b9 100644
--- a/torchopt/optim/meta/adam.py
+++ b/torchopt/optim/meta/adam.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Differentiable Adam optimizer."""
-from typing import Tuple
+from __future__ import annotations
import torch.nn as nn
@@ -39,7 +39,7 @@ def __init__(
self,
module: nn.Module,
lr: ScalarOrSchedule = 1e-3,
- betas: Tuple[float, float] = (0.9, 0.999),
+ betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
*,
@@ -51,28 +51,26 @@ def __init__(
"""Initialize the meta-Adam optimizer.
Args:
- module: (nn.Module)
- A network whose parameters should be optimized.
- lr: (default: :const:`1e-3`)
- This is a fixed global scaling factor.
- betas: (default: :const:`(0.9, 0.999)`)
- Coefficients used for computing running averages of gradient and its square.
- eps: (default: :const:`1e-8`)
- A small constant applied to denominator outside of the square root (as in the Adam
- paper) to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- eps_root: (default: :data:`0.0`)
- A small constant applied to denominator inside the square root (as in RMSProp), to
- avoid dividing by zero when rescaling. This is needed for example when computing
- (meta-)gradients through Adam.
- moment_requires_grad: (default: :data:`True`)
- If :data:`True` the momentums will be created with flag ``requires_grad=True``, this
- flag is often used in Meta-Learning algorithms.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
- use_accelerated_op: (default: :data:`False`)
- If :data:`True` use our implemented fused operator.
+ module (nn.Module): A network whose parameters should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-3`)
+ betas (tuple of float, optional): Coefficients used for computing running averages of
+ gradient and its square. (default: :const:`(0.9, 0.999)`)
+ eps (float, optional): A small constant applied to denominator outside of the square
+ root (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ eps_root (float, optional): A small constant applied to denominator inside the square
+ root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for
+ example when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created
+ with flag ``requires_grad=True``, this flag is often used in Meta-Learning
+ algorithms. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator.
+ (default: :data:`False`)
"""
super().__init__(
module,
diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py
index dc869e30..c8a8ef9c 100644
--- a/torchopt/optim/meta/adamw.py
+++ b/torchopt/optim/meta/adamw.py
@@ -14,13 +14,15 @@
# ==============================================================================
"""Differentiable AdamW optimizer."""
-from typing import Any, Callable, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import Callable
import torch.nn as nn
from torchopt import alias
from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import Params, ScalarOrSchedule
+from torchopt.typing import OptState, Params, ScalarOrSchedule
__all__ = ['MetaAdamW']
@@ -39,12 +41,12 @@ def __init__(
self,
module: nn.Module,
lr: ScalarOrSchedule = 1e-3,
- betas: Tuple[float, float] = (0.9, 0.999),
+ betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
*,
eps_root: float = 0.0,
- mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
+ mask: OptState | Callable[[Params], OptState] | None = None,
moment_requires_grad: bool = False,
maximize: bool = False,
use_accelerated_op: bool = False,
@@ -52,37 +54,35 @@ def __init__(
"""Initialize the meta-AdamW optimizer.
Args:
- module: (nn.Module)
- A network whose parameters should be optimized.
- lr: (default: :const:`1e-3`)
- This is a fixed global scaling factor.
- betas: (default: :const:`(0.9, 0.999)`)
- Coefficients used for computing running averages of gradient and its square.
- eps: (default: :const:`1e-8`)
- A small constant applied to denominator outside of the square root (as in the Adam
- paper) to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`1e-2`)
- Strength of the weight decay regularization. Note that this weight decay is
- multiplied with the learning rate. This is consistent with other frameworks such as
- PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only
- multiplied with the "schedule multiplier", but not the base learning rate.
- eps_root: (default: :data:`0.0`)
- A small constant applied to denominator inside the square root (as in RMSProp), to
- avoid dividing by zero when rescaling. This is needed for example when computing
- (meta-)gradients through Adam.
- mask: (default: :data:`None`)
- A tree with same structure as (or a prefix of) the params PyTree, or a Callable that
+ module (nn.Module): A network whose parameters should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-3`)
+ betas (tuple of float, optional): Coefficients used for computing running averages of
+ gradient and its square. (default: :const:`(0.9, 0.999)`)
+ eps (float, optional): A small constant applied to denominator outside of the square
+ root (as in the Adam paper) to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`)
+ weight_decay (float, optional): Strength of the weight decay regularization. Note that
+ this weight decay is multiplied with the learning rate. This is consistent with
+ other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where
+ the weight decay is only multiplied with the "schedule multiplier", but not the base
+ learning rate. (default: :const:`1e-2`)
+ eps_root (float, optional): A small constant applied to denominator inside the square
+ root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for
+ example when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+ mask (tree of Tensor, callable, or None, optional):
+ A tree with same structure as (or a prefix of) the params pytree, or a function that
returns such a pytree given the params/updates. The leaves should be booleans,
:data:`True` for leaves/subtrees you want to apply the weight decay to, and
:data:`False` for those you want to skip. Note that the Adam gradient
- transformations are applied to all parameters.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True` the momentums will be created with flag ``requires_grad=True``, this
- flag is often used in Meta-Learning algorithms.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
- use_accelerated_op: (default: :data:`False`)
- If :data:`True` use our implemented fused operator.
+ transformations are applied to all parameters. (default: :data:`None`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created
+ with flag ``requires_grad=True``, this flag is often used in Meta-Learning
+ algorithms. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
+ use_accelerated_op (bool, optional): If :data:`True` use our implemented fused operator.
+ (default: :data:`False`)
"""
super().__init__(
module,
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index 8db4f0a7..c5c9ad73 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""The base class for differentiable meta-optimizers."""
-from typing import List, Sequence, Tuple
+from __future__ import annotations
+
+from typing import Sequence
import torch
import torch.nn as nn
@@ -33,14 +35,13 @@ class MetaOptimizer:
"""The base class for high-level differentiable optimizers."""
def __init__(self, module: nn.Module, impl: GradientTransformation) -> None:
- """Initialize the meta-optimizer.
+ r"""Initialize the meta-optimizer.
Args:
- module: (nn.Module)
- A network whose parameters should be optimized.
- impl: (GradientTransformation)
- A low level optimizer function, it could be a optimizer function provided by
- ``alias.py`` or a customized ``chain`` provided by ``combine.py``.
+ module (nn.Module): A network whose parameters should be optimized.
+ impl (GradientTransformation): A low level optimizer function, it could be a optimizer
+ function provided in :mod:`torchopt.alias` or a customized :func:`torchopt.chain`\ed
+ transformation.
Note that using ``MetaOptimizer(sgd(moment_requires_grad=True))`` or
``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to
:class:`torchopt.MetaSGD`.
@@ -49,8 +50,8 @@ def __init__(self, module: nn.Module, impl: GradientTransformation) -> None:
raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')
self.impl: GradientTransformation = impl
- self.param_containers_groups: List[ModuleTensorContainers] = []
- self.state_groups: List[OptState] = []
+ self.param_containers_groups: list[ModuleTensorContainers] = []
+ self.state_groups: list[OptState] = []
self.add_param_group(module)
@@ -62,8 +63,8 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
gradients and update the network parameters without modifying tensors in-place.
Args:
- loss: (torch.Tensor)
- The loss that is used to compute the gradients to the network parameters.
+ loss (torch.Tensor): The loss that is used to compute the gradients to the network
+ parameters.
"""
# Step parameter only
for i, (param_container, state) in enumerate(
@@ -94,12 +95,12 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
container.update(new_param)
def add_param_group(self, module: nn.Module) -> None:
- """Add a param group to the optimizer's :attr:`state_groups`."""
+ """Add a param group to the optimizer's ``state_groups``."""
params_container = extract_module_containers(module, with_buffers=False)[0]
self.param_containers_groups.append(params_container)
self.state_groups.append(UninitializedState())
- def state_dict(self) -> Tuple[OptState, ...]:
+ def state_dict(self) -> tuple[OptState, ...]:
"""Extract the references of the optimizer states.
Note that the states are references, so any in-place operations will change the states
diff --git a/torchopt/optim/meta/rmsprop.py b/torchopt/optim/meta/rmsprop.py
index f4dfdae6..3aff20e1 100644
--- a/torchopt/optim/meta/rmsprop.py
+++ b/torchopt/optim/meta/rmsprop.py
@@ -50,30 +50,26 @@ def __init__(
"""Initialize the meta-RMSProp optimizer.
Args:
- module: (nn.Module)
- A network whose parameters should be optimized.
- lr: (default: :const:`1e-2`)
- This is a fixed global scaling factor.
- alpha: (default: :const:`0.99`)
- Smoothing constant, the decay used to track the magnitude of previous gradients.
- eps: (default: :const:`1e-8`)
- A small numerical constant to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- momentum: (default: :const:`0.0`)
- The decay rate used by the momentum term. The momentum is not used when it is set to
- :const:`0.0`.
- centered: (default: :data:`False`)
- If :data:`True`, use the variance of the past gradients to rescale the latest
- gradients.
- initial_scale: (default: :data:`0.0`)
- Initialization of accumulators tracking the magnitude of previous updates. PyTorch
- uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a
- paper, verify the value used by the authors.
- nesterov: (default: :data:`False`)
- Whether to use Nesterov momentum.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
+ module (nn.Module): A network whose parameters should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-2`)
+ alpha (float, optional): Smoothing constant, the decay used to track the magnitude of
+ previous gradients. (default: :const:`0.99`)
+ eps (float, optional): A small numerical constant to avoid dividing by zero when
+ rescaling. (default: :const:`1e-8`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ momentum (float, optional): The decay rate used by the momentum term. The momentum is
+ not used when it is set to :const:`0.0`. (default: :const:`0.0`)
+ centered (bool, optional): If :data:`True`, use the variance of the past gradients to
+ rescale the latest gradients. (default: :data:`False`)
+ initial_scale (float, optional): Initialization of accumulators tracking the magnitude
+ of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When
+ reproducing results from a paper, verify the value used by the authors.
+ (default: :data:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
"""
super().__init__(
module,
diff --git a/torchopt/optim/meta/sgd.py b/torchopt/optim/meta/sgd.py
index 5f9177e1..476ed9d6 100644
--- a/torchopt/optim/meta/sgd.py
+++ b/torchopt/optim/meta/sgd.py
@@ -47,23 +47,20 @@ def __init__(
"""Initialize the meta-SGD optimizer.
Args:
- module: (nn.Module)
- A network whose parameters should be optimized.
- lr: This is a fixed global scaling factor.
- momentum: (default: :const:`0.0`)
- The decay rate used by the momentum term. The momentum is not used when it is set to
- :const:`0.0`.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- dampening: (default: :const:`0.0`)
- Dampening for momentum.
- nesterov: (default: :const:`False`)
- Whether to use Nesterov momentum.
- moment_requires_grad: (default: :data:`True`)
- If :data:`True` the momentums will be created with flag ``requires_grad=True``, this
- flag is often used in Meta-Learning algorithms.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
+ module (nn.Module): A network whose parameters should be optimized.
+ lr (float or callable): This is a fixed global scaling factor or a learning rate
+ scheduler.
+ momentum (float, optional): The decay rate used by the momentum term. The momentum is
+ not used when it is set to :const:`0.0`. (default: :const:`0.0`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ dampening (float, optional): Dampening for momentum. (default: :const:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created
+ with flag ``requires_grad=True``, this flag is often used in Meta-Learning
+ algorithms. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
"""
super().__init__(
module,
diff --git a/torchopt/optim/rmsprop.py b/torchopt/optim/rmsprop.py
index 9101984f..5c4e536f 100644
--- a/torchopt/optim/rmsprop.py
+++ b/torchopt/optim/rmsprop.py
@@ -52,30 +52,27 @@ def __init__(
r"""Initialize the RMSProp optimizer.
Args:
- params: (iterable of torch.Tensor)
- An iterable of :class:`torch.Tensor`\s. Specifies what Tensors should be optimized.
- lr: (default: :const:`1e-2`)
- This is a fixed global scaling factor.
- alpha: (default: :const:`0.99`)
- Smoothing constant, the decay used to track the magnitude of previous gradients.
- eps: (default: :const:`1e-8`)
- A small numerical constant to avoid dividing by zero when rescaling.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- momentum: (default: :const:`0.0`)
- The decay rate used by the momentum term. The momentum is not used when it is set to
- :const:`0.0`.
- centered: (default: :data:`False`)
- If :data:`True`, use the variance of the past gradients to rescale the latest
- gradients.
- initial_scale: (default: :data:`0.0`)
- Initialization of accumulators tracking the magnitude of previous updates. PyTorch
- uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When reproducing results from a
- paper, verify the value used by the authors.
- nesterov: (default: :data:`False`)
- Whether to use Nesterov momentum.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
+ params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
+ tensors should be optimized.
+ lr (float or callable, optional): This is a fixed global scaling factor or a learning
+ rate scheduler. (default: :const:`1e-2`)
+ alpha (float, optional): Smoothing constant, the decay used to track the magnitude of
+ previous gradients. (default: :const:`0.99`)
+ eps (float, optional): A small numerical constant to avoid dividing by zero when
+ rescaling. (default: :const:`1e-8`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ momentum (float, optional): The decay rate used by the momentum term. The momentum is
+ not used when it is set to :const:`0.0`. (default: :const:`0.0`)
+ centered (bool, optional): If :data:`True`, use the variance of the past gradients to
+ rescale the latest gradients. (default: :data:`False`)
+ initial_scale (float, optional): Initialization of accumulators tracking the magnitude
+ of previous updates. PyTorch uses :data:`0.0`, TensorFlow 1.x uses :data:`1.0`. When
+ reproducing results from a paper, verify the value used by the authors.
+ (default: :data:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
"""
super().__init__(
params,
diff --git a/torchopt/optim/sgd.py b/torchopt/optim/sgd.py
index 223e856e..3da9595a 100644
--- a/torchopt/optim/sgd.py
+++ b/torchopt/optim/sgd.py
@@ -48,20 +48,21 @@ def __init__(
r"""Initialize the SGD optimizer.
Args:
- params: (iterable of torch.Tensor)
- An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized.
- lr: This is a fixed global scaling factor.
- momentum: (default: :const:`0.0`)
- The decay rate used by the momentum term. The momentum is not used when it is set to
- :const:`0.0`.
- weight_decay: (default: :const:`0.0`)
- Weight decay, add L2 penalty to parameters.
- dampening: (default: :const:`0.0`)
- Dampening for momentum.
- nesterov: (default: :data:`False`)
- Whether to use Nesterov momentum.
- maximize: (default: :data:`False`)
- Maximize the params based on the objective, instead of minimizing.
+ params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
+ tensors should be optimized.
+ lr (float or callable): This is a fixed global scaling factor or a learning rate
+ scheduler.
+ momentum (float, optional): The decay rate used by the momentum term. The momentum is
+ not used when it is set to :const:`0.0`. (default: :const:`0.0`)
+ weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`)
+ dampening (float, optional): Dampening for momentum. (default: :const:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ moment_requires_grad (bool, optional): If :data:`True` the momentums will be created
+ with flag ``requires_grad=True``, this flag is often used in Meta-Learning
+ algorithms. (default: :data:`False`)
+ maximize (bool, optional): Maximize the params based on the objective, instead of
+ minimizing. (default: :data:`False`)
"""
super().__init__(
params,
diff --git a/torchopt/pytree.py b/torchopt/pytree.py
index 0abcf4fd..d3b2d181 100644
--- a/torchopt/pytree.py
+++ b/torchopt/pytree.py
@@ -14,9 +14,11 @@
# ==============================================================================
"""The PyTree utilities."""
+from __future__ import annotations
+
import functools
import operator
-from typing import Callable, List, Optional, Tuple
+from typing import Callable
import optree
import optree.typing as typing # pylint: disable=unused-import
@@ -47,19 +49,20 @@
def tree_flatten_as_tuple(
tree: PyTree[T],
- is_leaf: Optional[Callable[[T], bool]] = None,
+ is_leaf: Callable[[T], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = '',
-) -> Tuple[Tuple[T, ...], PyTreeSpec]:
+) -> tuple[tuple[T, ...], PyTreeSpec]:
"""Flatten a pytree to a tuple of leaves and a PyTreeSpec.
Args:
- tree: The pytree to flatten.
- is_leaf: A function that returns :data:`True` if a given node is a leaf.
- none_is_leaf: If :data:`True`, None is considered a leaf rather than a internal node with no
- children.
- namespace: The namespace of custom tree node types.
+ tree (pytree): The pytree to flatten.
+ is_leaf (callable or None, optional): An optionally specified function that returns
+ :data:`True` if a given node is a leaf. (default: :data:`None`)
+ none_is_leaf (bool, optional): If :data:`True`, :data:`None` is considered a leaf rather
+ than a internal node with no children. (default: :data:`False`)
+ namespace (str, optional): The namespace of custom tree node types. (default: :const:`''`)
Returns:
A tuple of (leaves, treespec).
@@ -99,7 +102,7 @@ def tree_add(*trees: PyTree[T]) -> PyTree[T]:
def tree_add_scalar_mul(
- tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None
+ tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None
) -> TensorTree:
"""Compute ``tree_x + alpha * tree_y``."""
if alpha is None:
@@ -113,7 +116,7 @@ def tree_sub(minuend_tree: PyTree[T], subtrahend_tree: PyTree[T]) -> PyTree[T]:
def tree_sub_scalar_mul(
- tree_x: TensorTree, tree_y: TensorTree, alpha: Optional[Scalar] = None
+ tree_x: TensorTree, tree_y: TensorTree, alpha: Scalar | None = None
) -> TensorTree:
"""Compute ``tree_x - alpha * tree_y``."""
if alpha is None:
@@ -190,4 +193,4 @@ def tree_local_value(rref_tree: PyTree[RRef[T]]) -> PyTree[T]:
__all__.extend(['tree_as_rref', 'tree_to_here'])
-del Callable, List, Optional, Tuple, optree, rpc, Scalar, T, RRef
+del Callable, optree, rpc, Scalar, T, RRef
diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py
index 8a8e51e8..d54dbf17 100644
--- a/torchopt/schedule/polynomial.py
+++ b/torchopt/schedule/polynomial.py
@@ -52,18 +52,17 @@ def polynomial_schedule(
"""Construct a schedule with polynomial transition from init to end value.
Args:
- init_value: Initial value for the scalar to be annealed.
- end_value: End value of the scalar to be annealed.
- power: The power of the polynomial used to transition from ``init`` to ``end``.
- transition_steps:
- Number of steps over which annealing takes place, the scalar starts changing at
- ``transition_begin`` steps and completes the transition by
- ``transition_begin + transition_steps`` steps.
- If ``transition_steps <= 0``, then the entire annealing process is disabled and the
- value is held fixed at ``init_value``.
- transition_begin:
- Must be *positive*. After how many steps to start annealing (before this many steps the
- scalar value is held fixed at ``init_value``).
+ init_value (float or Tensor): Initial value for the scalar to be annealed.
+ end_value (float or Tensor): End value of the scalar to be annealed.
+ power (float or Tensor): The power of the polynomial used to transition from ``init`` to
+ ``end``.
+ transition_steps (int): Number of steps over which annealing takes place, the scalar starts
+ changing at ``transition_begin`` steps and completes the transition by
+ ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the
+ entire annealing process is disabled and the value is held fixed at ``init_value``.
+ transition_begin (int, optional): Must be *positive*. After how many steps to start
+ annealing (before this many steps the scalar value is held fixed at ``init_value``).
+ (default: :const:`0`)
Returns:
schedule:
diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py
index 772e6291..14745766 100644
--- a/torchopt/transform/add_decayed_weights.py
+++ b/torchopt/transform/add_decayed_weights.py
@@ -32,7 +32,9 @@
# ==============================================================================
"""Preset transformations for adding weight decay to updates."""
-from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import Any, Callable, NamedTuple
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation, identity
@@ -59,7 +61,7 @@ class MaskedNode(NamedTuple):
def masked(
inner: GradientTransformation,
- mask: Union[Any, Callable[[Params], Any]],
+ mask: OptState | Callable[[Params], OptState] | None = None,
) -> GradientTransformation:
"""Mask updates so only some are transformed, the rest are passed through.
@@ -75,11 +77,12 @@ def masked(
of :data:`True`.
Args:
- inner: Inner transformation to mask.
- mask: A tree with same structure as (or a prefix of) the params tree, or a Callable that
- returns such a tree given the params/updates. The leaves should be booleans, :data:`True`
- for leaves/subtrees you want to apply the transformation to, and :data:`False` for those
- you want to skip. The mask must be static for the gradient transformation to be jit-compilable.
+ inner (GradientTransformation): Inner transformation to mask.
+ mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a
+ prefix of) the params tree, or a function that returns such a tree given the
+ params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want
+ to apply the transformation to, and :data:`False` for those you want to skip.
+ (default: :data:`None`)
Returns:
A :class:`GradientTransformation` wrapping ``inner``.
@@ -89,14 +92,14 @@ def masked(
def _masked_flat(
inner: GradientTransformation,
- mask: Union[Any, Callable[[Params], Any]],
+ mask: OptState | Callable[[Params], OptState] | None = None,
) -> GradientTransformation:
return _masked(inner, mask, already_flattened=True)
def _masked(
inner: GradientTransformation,
- mask: Union[Any, Callable[[Params], Any]],
+ mask: OptState | Callable[[Params], OptState] | None = None,
*,
already_flattened: bool = False,
) -> GradientTransformation:
@@ -117,9 +120,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
mask_tree = mask(updates) if callable(mask) else mask
masked_updates = tree_mask(updates, mask_tree)
masked_params = None if params is None else tree_mask(params, mask_tree)
@@ -145,16 +148,17 @@ def update_fn(
def add_decayed_weights(
weight_decay: float = 0.0,
- mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
+ mask: OptState | Callable[[Params], OptState] | None = None,
) -> GradientTransformation:
"""Add parameter scaled by `weight_decay`.
Args:
- weight_decay: a scalar weight decay rate.
- mask: a tree with same structure as (or a prefix of) the params tree, or a Callable that
- returns such a pytree given the params/updates. The leaves should be booleans,
- :data:`True` for leaves/subtrees you want to apply the transformation to, and
- :data:`False` for those you want to skip.
+ weight_decay (float, optional): A scalar weight decay rate. (default: :const:`0.0`)
+ mask (tree of Tensor, callable, or None, optional): A tree with same structure as (or a
+ prefix of) the params tree, or a function that returns such a tree given the
+ params/updates. The leaves should be booleans, :data:`True` for leaves/subtrees you want
+ to apply the transformation to, and :data:`False` for those you want to skip.
+ (default: :data:`None`)
Returns:
An (init_fn, update_fn) tuple.
@@ -168,7 +172,7 @@ def add_decayed_weights(
def _add_decayed_weights_flat(
weight_decay: float = 0.0,
- mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
+ mask: OptState | Callable[[Params], OptState] | None = None,
) -> GradientTransformation:
return _add_decayed_weights(
weight_decay=weight_decay,
@@ -179,7 +183,7 @@ def _add_decayed_weights_flat(
def _add_decayed_weights(
weight_decay: float = 0.0,
- mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
+ mask: OptState | Callable[[Params], OptState] | None = None,
*,
already_flattened: bool = False,
) -> GradientTransformation:
@@ -204,9 +208,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None,
+ params: Params | None = None,
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
assert params is not None, (
'Parameters are required for weight decay. '
'Call `update(updates, state, params=params)` instead.'
diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py
index 2c0b9d5e..804f8219 100644
--- a/torchopt/transform/nan_to_num.py
+++ b/torchopt/transform/nan_to_num.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Preset transformations that replaces updates with non-finite values to the given numbers."""
-from typing import Optional, Tuple
+from __future__ import annotations
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation
@@ -23,8 +23,8 @@
def nan_to_num(
nan: float = 0.0,
- posinf: Optional[float] = None,
- neginf: Optional[float] = None,
+ posinf: float | None = None,
+ neginf: float | None = None,
) -> GradientTransformation:
"""Replace updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers.
@@ -39,9 +39,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
if inplace:
def f(g):
diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py
index 4afac163..639c903e 100644
--- a/torchopt/transform/scale.py
+++ b/torchopt/transform/scale.py
@@ -31,7 +31,7 @@
# ==============================================================================
"""Preset transformation for scaling updates by learning rate."""
-from typing import Optional, Tuple
+from __future__ import annotations
from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation
@@ -49,7 +49,7 @@ def scale(step_size: float) -> GradientTransformation:
"""Scale updates by some fixed scalar ``step_size``.
Args:
- step_size: A scalar corresponding to a fixed scaling factor for updates.
+ step_size (float): A scalar corresponding to a fixed scaling factor for updates.
Returns:
An ``(init_fn, update_fn)`` tuple.
@@ -80,9 +80,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
if inplace:
def f(g):
diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py
index 039d31fb..36f30be9 100644
--- a/torchopt/transform/scale_by_adam.py
+++ b/torchopt/transform/scale_by_adam.py
@@ -33,7 +33,9 @@
# pylint: disable=invalid-name
-from typing import NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple
import torch
@@ -88,17 +90,17 @@ def scale_by_adam(
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
Args:
- b1: (default: :const:`0.9`)
- Decay rate for the exponentially weighted average of grads.
- b2: (default: :const:`0.999`)
- Decay rate for the exponentially weighted average of squared grads.
- eps: (default: :const:`1e-8`)
- Term added to the denominator to improve numerical stability.
- eps_root: (default: :const:`0.0`)
- Term added to the denominator inside the square-root to improve
+ b1 (float, optional): Decay rate for the exponentially weighted average of grads.
+ (default: :const:`0.9`)
+ b2 (float, optional): Decay rate for the exponentially weighted average of squared grads.
+ (default: :const:`0.999`)
+ eps (float, optional): Term added to the denominator to improve numerical stability.
+ (default: :const:`1e-8`)
+ eps_root (float, optional): Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True`, states will be created with flag `requires_grad = True`.
+ (default: :const:`0.0`)
+ moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
+ ``requires_grad = True``. (default: :data:`False`)
Returns:
An (init_fn, update_fn) tuple.
@@ -169,9 +171,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
mu = update_moment.impl( # type: ignore[attr-defined]
updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened
)
@@ -218,17 +220,17 @@ def scale_by_accelerated_adam(
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
Args:
- b1: (default: :const:`0.9`)
- Decay rate for the exponentially weighted average of grads.
- b2: (default: :const:`0.999`)
- Decay rate for the exponentially weighted average of squared grads.
- eps: (default: :const:`1e-8`)
- Term added to the denominator to improve numerical stability.
- eps_root: (default: :const:`0.0`)
- Term added to the denominator inside the square-root to improve
+ b1 (float, optional): Decay rate for the exponentially weighted average of grads.
+ (default: :const:`0.9`)
+ b2 (float, optional): Decay rate for the exponentially weighted average of squared grads.
+ (default: :const:`0.999`)
+ eps (float, optional): Term added to the denominator to improve numerical stability.
+ (default: :const:`1e-8`)
+ eps_root (float, optional): Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True`, states will be created with flag `requires_grad = True`.
+ (default: :const:`0.0`)
+ moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
+ ``requires_grad = True``. (default: :data:`False`)
Returns:
An (init_fn, update_fn) tuple.
@@ -285,9 +287,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
count_inc = inc_count.impl(updates, state.count, already_flattened=True) # type: ignore[attr-defined]
op = AdamOp(b1=b1, b2=b2, eps=eps, eps_root=eps_root, inplace=inplace)
@@ -303,9 +305,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
count_inc = inc_count.impl(updates, state.count, already_flattened=False) # type: ignore[attr-defined]
treespec = pytree.tree_structure(updates, none_is_leaf=True)
diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py
index 7a685f6b..7a0c8c20 100644
--- a/torchopt/transform/scale_by_rms.py
+++ b/torchopt/transform/scale_by_rms.py
@@ -31,7 +31,9 @@
# ==============================================================================
"""Preset transformations for scaling updates by exponential root mean-squared (RMS)."""
-from typing import NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple
import torch
@@ -61,12 +63,11 @@ def scale_by_rms(
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Args:
- alpha: (default: :const:`0.9`)
- Decay rate for the exponentially weighted average of squared grads.
- eps: (default: :const:`1e-8`)
- Term added to the denominator to improve numerical stability.
- initial_scale: (default: :const:`0.0`)
- Initial value for second moment
+ alpha (float, optional): Decay rate for the exponentially weighted average of squared grads.
+ (default: :const:`0.9`)
+ eps (float, optional): Term added to the denominator to improve numerical stability.
+ (default: :const:`1e-8`)
+ initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`)
Returns:
An (init_fn, update_fn) tuple.
@@ -121,9 +122,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
nu = update_moment.impl( # type: ignore[attr-defined]
updates, state.nu, alpha, order=2, inplace=inplace, already_flattened=already_flattened
)
diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py
index 5556d111..d6e3b0fa 100644
--- a/torchopt/transform/scale_by_schedule.py
+++ b/torchopt/transform/scale_by_schedule.py
@@ -31,7 +31,9 @@
# ==============================================================================
"""Preset transformation for scaling updates by learning rate schedules."""
-from typing import NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple
import torch
@@ -54,9 +56,8 @@ def scale_by_schedule(step_size_fn: Schedule) -> GradientTransformation:
"""Scale updates using a custom schedule for the ``step_size``.
Args:
- step_size_fn:
- A function that takes an update count as input and proposes the ``step_size`` to
- multiply the updates by.
+ step_size_fn (callable): A function that takes an update count as input and proposes the
+ ``step_size`` to multiply the updates by.
Returns:
An ``(init_fn, update_fn)`` tuple.
@@ -90,9 +91,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
if inplace:
def f(g, c): # pylint: disable=invalid-name
diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py
index c15a0d6c..228ed707 100644
--- a/torchopt/transform/scale_by_stddev.py
+++ b/torchopt/transform/scale_by_stddev.py
@@ -33,7 +33,9 @@
# pylint: disable=invalid-name
-from typing import NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple
import torch
@@ -64,12 +66,11 @@ def scale_by_stddev(
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Args:
- alpha: (default: :const:`0.9`)
- Decay rate for the exponentially weighted average of squared grads.
- eps: (default: :const:`1e-8`)
- Term added to the denominator to improve numerical stability.
- initial_scale: (default: :const:`0.0`)
- Initial value for second moment
+ alpha (float, optional): Decay rate for the exponentially weighted average of squared grads.
+ (default: :const:`0.9`)
+ eps (float, optional): Term added to the denominator to improve numerical stability.
+ (default: :const:`1e-8`)
+ initial_scale (float, optional): Initial value for second moment. (default: :const:`0.0`)
Returns:
An (init_fn, update_fn) tuple.
@@ -125,9 +126,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
mu = update_moment.impl( # type: ignore[attr-defined]
updates, state.mu, alpha, order=1, inplace=inplace, already_flattened=already_flattened
)
diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py
index 45e043f0..03d2441d 100644
--- a/torchopt/transform/trace.py
+++ b/torchopt/transform/trace.py
@@ -33,7 +33,9 @@
# pylint: disable=invalid-name
-from typing import NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple
import torch
@@ -65,14 +67,12 @@ def trace(
Both are frequently found in the optimization literature.
Args:
- momentum: (default: :const:`0.9`)
- The decay rate for the trace of past updates.
- dampening: (default: :const:`0.0`)
- Dampening for momentum.
- nesterov: (default: :data:`False`)
- Whether to use Nesterov momentum.
- moment_requires_grad: (default: :data:`False`)
- If :data:`True`, states will be created with flag `requires_grad = True`.
+ momentum (float, optional): The decay rate for the trace of past updates.
+ (default: :const:`0.9`)
+ dampening (float, optional): Dampening for momentum. (default: :const:`0.0`)
+ nesterov (bool, optional): Whether to use Nesterov momentum. (default: :data:`False`)
+ moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
+ ``requires_grad = True``. (default: :data:`False`)
Returns:
An (init_fn, update_fn) tuple.
@@ -139,9 +139,9 @@ def update_fn(
updates: Updates,
state: OptState,
*,
- params: Optional[Params] = None, # pylint: disable=unused-argument
+ params: Params | None = None, # pylint: disable=unused-argument
inplace: bool = True,
- ) -> Tuple[Updates, OptState]:
+ ) -> tuple[Updates, OptState]:
nonlocal first_call
if nesterov:
diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py
index a9f02295..77ba58ca 100644
--- a/torchopt/transform/utils.py
+++ b/torchopt/transform/utils.py
@@ -31,6 +31,8 @@
# ==============================================================================
"""Utilities for the preset transformations."""
+from __future__ import annotations
+
from collections import deque
from typing import Any, Callable, Sequence
diff --git a/torchopt/update.py b/torchopt/update.py
index 3fdd38e1..9485896b 100644
--- a/torchopt/update.py
+++ b/torchopt/update.py
@@ -48,11 +48,11 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) ->
:func:`tree_map` (e.g. if you want to manipulate updates in custom ways before applying them).
Args:
- params: A tree of parameters.
- updates:
- A tree of updates, the tree structure and the shape of the leaf nodes must match that
- of ``params``.
- inplace: If :data:`True`, will update params in a inplace manner.
+ params (tree of Tensor): A tree of parameters.
+ updates (tree of Tensor): A tree of updates, the tree structure and the shape of the leaf
+ nodes must match that of ``params``.
+ inplace (bool, optional): If :data:`True`, will update params in a inplace manner.
+ (default: :data:`True`)
Returns:
Updated parameters, with same structure, shape and type as ``params``.
diff --git a/torchopt/utils.py b/torchopt/utils.py
index 4deaba8b..12adb214 100644
--- a/torchopt/utils.py
+++ b/torchopt/utils.py
@@ -14,21 +14,11 @@
# ==============================================================================
"""Utilities for TorchOpt."""
+from __future__ import annotations
+
import copy
import itertools
-from typing import (
- TYPE_CHECKING,
- Dict,
- List,
- NamedTuple,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
- cast,
- overload,
-)
+from typing import TYPE_CHECKING, NamedTuple, Sequence, cast, overload
from typing_extensions import Literal # Python 3.8+
from typing_extensions import TypeAlias # Python 3.10+
@@ -56,32 +46,30 @@
class ModuleState(NamedTuple):
"""Container for module state."""
- params: Tuple[Dict[str, torch.Tensor], ...]
- buffers: Tuple[Dict[str, torch.Tensor], ...]
- visual_contents: Optional[Dict] = None
+ params: tuple[dict[str, torch.Tensor], ...]
+ buffers: tuple[dict[str, torch.Tensor], ...]
+ visual_contents: dict | None = None
detach_buffers: bool = False
CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone']
-def stop_gradient(target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]) -> None:
+def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) -> None:
"""Stop the gradient for the input object.
- Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the
+ Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the
backpropagated gradient will flow over the tensor and continue flow to the tensors that is
- connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the
+ connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the
computation graph.
Note that the :func:`stop_gradient` operation is in-place.
Args:
- target: The target that to be detached from the computation graph, it could be a
- :class:`nn.Module`, :class:`torchopt.MetaOptimizer`, state of the
- :class:`torchopt.MetaOptimizer`, or just a plain list of tensors.
- inplace: If :data:`True`, the target will be detached in-place. if :data:`Frue`, this
- function will return a detached copy of the target. The in-place operation is fast and
- memory efficient but may raise backpropagation error.
+ target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The target that to be
+ detached from the computation graph, it could be a :class:`nn.Module`,
+ :class:`torchopt.MetaOptimizer`, state of the :class:`torchopt.MetaOptimizer`, or just
+ a plain list of tensors.
"""
# pylint: disable-next=import-outside-toplevel
from torchopt.optim.meta.base import MetaOptimizer
@@ -108,7 +96,7 @@ def extract_state_dict(
target: nn.Module,
*,
by: CopyMode = 'reference',
- device: Optional[Device] = None,
+ device: Device | None = None,
with_buffers: bool = True,
enable_visual: bool = False,
visual_prefix: str = '',
@@ -118,57 +106,62 @@ def extract_state_dict(
@overload
def extract_state_dict(
- target: 'MetaOptimizer',
+ target: MetaOptimizer,
*,
by: CopyMode = 'reference',
- device: Optional[Device] = None,
+ device: Device | None = None,
with_buffers: bool = True,
enable_visual: bool = False,
visual_prefix: str = '',
-) -> Tuple[OptState, ...]: # pragma: no cover
+) -> tuple[OptState, ...]: # pragma: no cover
...
# pylint: disable-next=too-many-branches,too-many-locals
def extract_state_dict(
- target: Union[nn.Module, 'MetaOptimizer'],
+ target: nn.Module | MetaOptimizer,
*,
by: CopyMode = 'reference',
- device: Optional[Device] = None,
+ device: Device | None = None,
with_buffers: bool = True,
detach_buffers: bool = False,
enable_visual: bool = False,
visual_prefix: str = '',
-) -> Union[ModuleState, Tuple[OptState, ...]]:
+) -> ModuleState | tuple[OptState, ...]:
"""Extract target state.
- Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the
+ Since a tensor use ``grad_fn`` to connect itself with the previous computation graph, the
backpropagated gradient will flow over the tensor and continue flow to the tensors that is
- connected by :attr:`grad_fn`. Some algorithms requires manually detaching tensors from the
+ connected by ``grad_fn``. Some algorithms requires manually detaching tensors from the
computation graph.
Note that the extracted state is a reference, which means any in-place operator will affect the
target that the state is extracted from.
Args:
- target: It could be a :class:`nn.Module` or :class:`torchopt.MetaOptimizer`.
- by: The extract policy of tensors in the target.
+ target (nn.Module or MetaOptimizer): It could be a :class:`nn.Module` or
+ :class:`torchopt.MetaOptimizer`.
+ by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`)
- :const:`'reference'`: The extracted tensors will be references to the original
tensors.
- :const:`'copy'`: The extracted tensors will be clones of the original tensors. This
- makes the copied tensors have :attr:`grad_fn` to be a ```` function
- points to the original tensors.
+ makes the copied tensors have ``grad_fn`` to be a ```` function points
+ to the original tensors.
- :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original
tensors. The deep-copied tensors will detach from the original computation graph.
- device: If specified, move the extracted state to the specified device.
- with_buffers: Extract buffer together with parameters, this argument is only used if the
- input target is :class:`nn.Module`.
- detach_buffers: Whether to detach the reference to the buffers, this argument is only used
- if the input target is :class:`nn.Module` and ``by='reference'``.
- enable_visual: Add additional annotations, which could be used in computation graph
- visualization. Currently, this flag only has effect on :class:`nn.Module` but we will
- support :class:`torchopt.MetaOptimizer` later.
- visual_prefix: Prefix for the visualization annotations.
+ device (Device or None, optional): If specified, move the extracted state to the specified
+ device. (default: :const:`None`)
+ with_buffers (bool, optional): Extract buffer together with parameters, this argument is
+ only used if the input target is :class:`nn.Module`. (default: :const:`True`)
+ detach_buffers (bool, optional): Whether to detach the reference to the buffers, this
+ argument is only used if the input target is :class:`nn.Module` and ``by='reference'``.
+ (default: :const:`False`)
+ enable_visual (bool, optional): Add additional annotations, which could be used in
+ computation graph visualization. Currently, this flag only has effect on
+ :class:`nn.Module` but we will support :class:`torchopt.MetaOptimizer` later.
+ (default: :const:`False`)
+ visual_prefix (str, optional): Prefix for the visualization annotations.
+ (default: :const:`''`)
Returns:
State extracted of the input object.
@@ -228,9 +221,9 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
else:
visual_contents = None
- params: List[Dict[str, torch.Tensor]] = []
- buffers: List[Dict[str, torch.Tensor]] = []
- memo: Set[nn.Module] = set()
+ params: list[dict[str, torch.Tensor]] = []
+ buffers: list[dict[str, torch.Tensor]] = []
+ memo: set[nn.Module] = set()
def update_params(container):
if len(container) > 0:
@@ -287,12 +280,12 @@ def get_variable(t):
def extract_module_containers(
module: nn.Module, with_buffers: bool = True
-) -> Tuple[ModuleTensorContainers, ModuleTensorContainers]:
+) -> tuple[ModuleTensorContainers, ModuleTensorContainers]:
"""Extract the references to the containers of parameters and buffers from a module."""
if isinstance(module, nn.Module):
- params: List[TensorContainer] = []
- buffers: List[TensorContainer] = []
- memo: Set[nn.Module] = set()
+ params: list[TensorContainer] = []
+ buffers: list[TensorContainer] = []
+ memo: set[nn.Module] = set()
def update_container(container, items):
if len(items) > 0:
@@ -316,8 +309,8 @@ def update_container(container, items):
def recover_state_dict(
- target: Union[nn.Module, 'MetaOptimizer'],
- state: Union[ModuleState, Sequence[OptState]],
+ target: nn.Module | MetaOptimizer,
+ state: ModuleState | Sequence[OptState],
) -> None:
"""Recover state.
@@ -327,8 +320,8 @@ def recover_state_dict(
modified.
Args:
- target: Target that need to recover.
- state: The recovering state.
+ target (nn.Module or MetaOptimizer): Target that need to recover.
+ state (ModuleState or sequence of tree of Tensor): The recovering state.
"""
# pylint: disable-next=import-outside-toplevel
from torchopt.optim.meta.base import MetaOptimizer
@@ -344,10 +337,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
return nn.Parameter(t.clone().detach_(), requires_grad=t.requires_grad)
return t.clone().detach_().requires_grad_(t.requires_grad)
- buffers = cast(
- Tuple[Dict[str, torch.Tensor], ...],
- pytree.tree_map(clone_detach_, buffers), # type: ignore[arg-type]
- )
+ buffers = pytree.tree_map(clone_detach_, buffers) # type: ignore[assignment,arg-type]
for tgt, src in itertools.chain(
zip(params_containers, params),
@@ -367,19 +357,19 @@ def module_clone(
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Device] = None,
+ device: Device | None = None,
) -> nn.Module: # pragma: no cover
...
@overload
def module_clone(
- target: 'MetaOptimizer',
+ target: MetaOptimizer,
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Device] = None,
-) -> 'MetaOptimizer': # pragma: no cover
+ device: Device | None = None,
+) -> MetaOptimizer: # pragma: no cover
...
@@ -389,34 +379,36 @@ def module_clone(
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Device] = None,
+ device: Device | None = None,
) -> TensorTree: # pragma: no cover
...
# pylint: disable-next=too-many-locals
def module_clone(
- target: Union[nn.Module, 'MetaOptimizer', TensorTree],
+ target: nn.Module | MetaOptimizer | TensorTree,
*,
by: CopyMode = 'reference',
detach_buffers: bool = False,
- device: Optional[Device] = None,
-) -> Union[nn.Module, 'MetaOptimizer', TensorTree]:
+ device: Device | None = None,
+) -> nn.Module | MetaOptimizer | TensorTree:
"""Clone a module.
Args:
- target: The target to be cloned.
- by: The extract policy of tensors in the target.
+ target (nn.Module, MetaOptimizer, or tree of Tensor): The target to be cloned.
+ by (str, optional): The extract policy of tensors in the target. (default: :const:`'reference'`)
- :const:`'reference'`: The extracted tensors will be references to the original
tensors.
- :const:`'copy'`: The extracted tensors will be clones of the original tensors. This
- makes the copied tensors have :attr:`grad_fn` to be a ```` function
- points to the original tensors.
+ makes the copied tensors have ``grad_fn`` to be a ```` function points
+ to the original tensors.
- :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original
tensors. The deep-copied tensors will detach from the original computation graph.
- detach_buffers: Whether to detach the reference to the buffers, this argument is only used
- if the input target is :class:`nn.Module` and ``by='reference'``.
- device: If specified, move the cloned module to the specified device.
+ detach_buffers (bool, optional): Whether to detach the reference to the buffers, this
+ argument is only used if the input target is :class:`nn.Module` and ``by='reference'``.
+ (default: :const:`False`)
+ device (Device or None, optional): If specified, move the cloned module to the specified
+ device. (default: :const:`None`)
Returns:
The cloned module.
@@ -499,7 +491,7 @@ def module_detach_(target: nn.Module) -> nn.Module: # pragma: no cover
@overload
-def module_detach_(target: 'MetaOptimizer') -> 'MetaOptimizer': # pragma: no cover
+def module_detach_(target: MetaOptimizer) -> MetaOptimizer: # pragma: no cover
...
@@ -509,12 +501,13 @@ def module_detach_(target: TensorTree) -> TensorTree: # pragma: no cover
def module_detach_(
- target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]
-) -> Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]:
+ target: ModuleState | nn.Module | MetaOptimizer | TensorTree,
+) -> ModuleState | nn.Module | MetaOptimizer | TensorTree:
"""Detach a module from the computation graph.
Args:
- target: The target to be detached.
+ target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The
+ target to be detached.
Returns:
The detached module.
diff --git a/torchopt/visual.py b/torchopt/visual.py
index e8145240..7afe65a4 100644
--- a/torchopt/visual.py
+++ b/torchopt/visual.py
@@ -17,8 +17,10 @@
# ==============================================================================
"""Computation graph visualization."""
+from __future__ import annotations
+
from collections import namedtuple
-from typing import Generator, Iterable, Mapping, Optional, Union, cast
+from typing import Generator, Iterable, Mapping, cast
import torch
from graphviz import Digraph
@@ -71,14 +73,13 @@ def truncate(s): # pylint: disable=invalid-name
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
def make_dot(
var: TensorOrTensors,
- params: Optional[
- Union[
- Mapping[str, torch.Tensor],
- ModuleState,
- Generator,
- Iterable[Union[Mapping[str, torch.Tensor], ModuleState, Generator]],
- ]
- ] = None,
+ params: (
+ Mapping[str, torch.Tensor]
+ | ModuleState
+ | Generator
+ | Iterable[Mapping[str, torch.Tensor] | ModuleState | Generator]
+ | None
+ ) = None,
show_attrs: bool = False,
show_saved: bool = False,
max_attr_chars: int = 50,
@@ -89,7 +90,7 @@ def make_dot(
and is either blue, orange, or green:
- **Blue**
- Reachable leaf tensors that requires grad (tensors whose :attr:`grad` fields will be
+ Reachable leaf tensors that requires grad (tensors whose ``grad`` fields will be
populated during :meth:`backward`).
- **Orange**
Saved tensors of custom autograd functions as well as those saved by built-in backward
@@ -100,16 +101,16 @@ def make_dot(
If any output is a view, we represent its base tensor with a dark green node.
Args:
- var: Output tensor.
- params: ([dict of (name, tensor) or state_dict])
- Parameters to add names to node that requires grad.
- show_attrs: Whether to display non-tensor attributes of backward nodes
- (Requires PyTorch version >= 1.9)
- show_saved: Whether to display saved tensor nodes that are not by custom autograd
- functions. Saved tensor nodes for custom functions, if present, are always displayed.
- (Requires PyTorch version >= 1.9)
- max_attr_chars: If ``show_attrs`` is :data:`True`, sets max number of characters to display
- for any given attribute.
+ var (Tensor or sequence of Tensor): Output tensor.
+ params: (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional):
+ Parameters to add names to node that requires grad. (default: :data:`None`)
+ show_attrs (bool, optional): Whether to display non-tensor attributes of backward nodes.
+ (default: :data:`False`)
+ show_saved (bool, optional): Whether to display saved tensor nodes that are not by custom
+ autograd functions. Saved tensor nodes for custom functions, if present, are always
+ displayed. (default: :data:`False`)
+ max_attr_chars (int, optional): If ``show_attrs`` is :data:`True`, sets max number of
+ characters to display for any given attribute. (default: :const:`50`)
"""
param_map = {}