From 6fa85d6de8a4739f1bf680c37b0dee2878fe44d6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 20 Feb 2023 00:46:25 +0800 Subject: [PATCH] feat: add `ruff` integration (#139) --- .github/workflows/lint.yml | 4 + .pre-commit-config.yaml | 6 ++ CHANGELOG.md | 2 +- Makefile | 16 +++- conda-recipe.yaml | 1 + docs/source/conf.py | 7 +- docs/source/developer/contributing.rst | 2 +- examples/FuncTorch/maml_omniglot_vmap.py | 2 +- examples/FuncTorch/parallel_train_torchopt.py | 2 - examples/L2R/helpers/model.py | 2 +- examples/L2R/helpers/utils.py | 9 +-- examples/L2R/l2r.py | 12 +-- examples/LOLA/helpers/utils.py | 4 +- examples/LOLA/lola_dice.py | 7 +- examples/MAML-RL/func_maml.py | 14 ++-- examples/MAML-RL/helpers/policy_torchrl.py | 2 - examples/MAML-RL/helpers/tabular_mdp.py | 6 +- examples/MAML-RL/maml.py | 20 ++--- examples/MAML-RL/maml_torchrl.py | 26 +++---- examples/MGRL/mgrl.py | 2 +- .../few-shot/helpers/omniglot_loaders.py | 29 +++---- examples/few-shot/helpers/omniglot_loaders.py | 29 +++---- examples/few-shot/maml_omniglot.py | 2 +- examples/iMAML/helpers/omniglot_loaders.py | 29 +++---- examples/iMAML/imaml_omniglot.py | 2 +- examples/iMAML/imaml_omniglot_functional.py | 2 +- pyproject.toml | 75 +++++++++++++++++++ setup.py | 4 +- tests/requirements.txt | 1 + torchopt/accelerated_op/__init__.py | 2 +- torchopt/alias/adam.py | 11 +-- torchopt/alias/adamw.py | 11 +-- torchopt/alias/rmsprop.py | 15 ++-- torchopt/alias/sgd.py | 6 +- torchopt/alias/utils.py | 34 +++++---- torchopt/base.py | 2 +- torchopt/clip.py | 4 +- torchopt/combine.py | 10 +-- torchopt/diff/implicit/decorator.py | 16 ++-- torchopt/diff/implicit/nn/module.py | 30 ++++---- torchopt/diff/zero_order/__init__.py | 7 +- torchopt/diff/zero_order/decorator.py | 34 +++++---- torchopt/diff/zero_order/nn/module.py | 14 ++-- torchopt/distributed/api.py | 5 +- torchopt/distributed/autograd.py | 2 +- torchopt/distributed/world.py | 6 +- torchopt/hook.py | 4 +- torchopt/linalg/cg.py | 10 +-- torchopt/linalg/ns.py | 2 +- torchopt/linear_solve/cg.py | 8 +- torchopt/linear_solve/inv.py | 8 +- torchopt/linear_solve/normal_cg.py | 13 ++-- torchopt/nn/module.py | 6 +- torchopt/nn/stateless.py | 10 +-- torchopt/optim/base.py | 6 +- torchopt/optim/func/base.py | 3 +- torchopt/transform/add_decayed_weights.py | 12 +-- torchopt/transform/nan_to_num.py | 6 +- torchopt/transform/scale.py | 6 +- torchopt/transform/scale_by_adam.py | 20 +++-- torchopt/transform/scale_by_rms.py | 8 +- torchopt/transform/scale_by_schedule.py | 6 +- torchopt/transform/scale_by_stddev.py | 8 +- torchopt/transform/trace.py | 16 ++-- torchopt/transform/utils.py | 16 ++-- torchopt/update.py | 8 +- torchopt/utils.py | 30 ++++---- torchopt/visual.py | 32 ++++---- 68 files changed, 408 insertions(+), 358 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 55dee661..009589c7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -71,6 +71,10 @@ jobs: run: | make pre-commit + - name: ruff + run: | + make ruff + - name: flake8 run: | make flake8 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d7c1c7f0..4e5ddd14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,12 @@ repos: hooks: - id: clang-format stages: [commit, push, manual] + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.247 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + stages: [commit, push, manual] - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 927cb1db..dd9bc4bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139). ### Changed diff --git a/Makefile b/Makefile index 93bb53e7..750b9d9f 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,9 @@ py-format-install: $(call check_pip_install,isort) $(call check_pip_install_extra,black,black[jupyter]) +ruff-install: + $(call check_pip_install,ruff) + mypy-install: $(call check_pip_install,mypy) @@ -79,7 +82,7 @@ docs-install: $(call check_pip_install,sphinxcontrib-bibtex) $(call check_pip_install,sphinx-autodoc-typehints) $(call check_pip_install,myst-nb) - $(call check_pip_install_extra,sphinxcontrib.spelling,sphinxcontrib.spelling pyenchant) + $(call check_pip_install_extra,sphinxcontrib-spelling,sphinxcontrib-spelling pyenchant) pytest-install: $(call check_pip_install,pytest) @@ -132,6 +135,12 @@ py-format: py-format-install $(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \ $(PYTHON) -m black --check $(PYTHON_FILES) tutorials +ruff: ruff-install + $(PYTHON) -m ruff check . + +ruff-fix: ruff-install + $(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix + mypy: mypy-install $(PYTHON) -m mypy $(PROJECT_PATH) @@ -181,11 +190,12 @@ clean-docs: # Utility functions -lint: flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling +lint: ruff flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling -format: py-format-install clang-format-install addlicense-install +format: py-format-install ruff-install clang-format-install addlicense-install $(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES) $(PYTHON) -m black $(PYTHON_FILES) tutorials + $(PYTHON) -m ruff check . 
--fix --exit-zero $(CLANG_FORMAT) -style=file -i $(CXX_FILES) $(CUDA_FILES) addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS) diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 939d43e0..48b6b164 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -94,6 +94,7 @@ dependencies: - flake8-docstrings - flake8-pyi - flake8-simplify + - ruff - doc8 - pydocstyle - clang-format >= 14 diff --git a/docs/source/conf.py b/docs/source/conf.py index d8233da7..f5d206c7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,6 +30,7 @@ import pathlib import sys +import sphinx import sphinxcontrib.katex as katex @@ -39,7 +40,7 @@ def get_version() -> str: sys.path.insert(0, str(PROJECT_ROOT / 'torchopt')) - import version # noqa + import version return version.__version__ @@ -51,7 +52,7 @@ def get_version() -> str: else: class RecursiveForwardRefFilter(logging.Filter): - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: if ( "name 'TensorTree' is not defined" in record.getMessage() or "name 'OptionalTensorTree' is not defined" in record.getMessage() @@ -191,7 +192,7 @@ def filter(self, record): html_logo = '_static/images/logo.png' -def setup(app): +def setup(app: sphinx.application.Sphinx) -> None: app.add_js_file('https://cdn.jsdelivr.net/npm/vega@5.20.2') app.add_js_file('https://cdn.jsdelivr.net/npm/vega-lite@5.1.0') app.add_js_file('https://cdn.jsdelivr.net/npm/vega-embed@6.17.0') diff --git a/docs/source/developer/contributing.rst b/docs/source/developer/contributing.rst index ee66f560..ce11e20e 100644 --- a/docs/source/developer/contributing.rst +++ b/docs/source/developer/contributing.rst @@ -51,7 +51,7 @@ Lint Check We use several tools to secure code quality, including: - * PEP8 code style: ``black``, ``isort``, ``pylint``, ``flake8`` + * Python code style: ``black``, ``isort``, ``pylint``, ``flake8``, ``ruff`` * Type hint check: ``mypy`` * C++ Google-style: ``cpplint``, ``clang-format``, ``clang-tidy`` * License: ``addlicense`` diff --git a/examples/FuncTorch/maml_omniglot_vmap.py b/examples/FuncTorch/maml_omniglot_vmap.py index 0933b44d..89a5a505 100644 --- a/examples/FuncTorch/maml_omniglot_vmap.py +++ b/examples/FuncTorch/maml_omniglot_vmap.py @@ -224,7 +224,7 @@ def test(db, net, device, epoch, log): qry_losses = [] qry_accs = [] - for batch_idx in range(n_test_iter): + for _ in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') task_num, setsz, c_, h, w = x_spt.size() diff --git a/examples/FuncTorch/parallel_train_torchopt.py b/examples/FuncTorch/parallel_train_torchopt.py index 640763cb..d762b986 100644 --- a/examples/FuncTorch/parallel_train_torchopt.py +++ b/examples/FuncTorch/parallel_train_torchopt.py @@ -15,8 +15,6 @@ import argparse import math -from collections import namedtuple -from typing import Any, NamedTuple import functorch import torch diff --git a/examples/L2R/helpers/model.py b/examples/L2R/helpers/model.py index 80fae8ac..deaf061c 100644 --- a/examples/L2R/helpers/model.py +++ b/examples/L2R/helpers/model.py @@ -35,7 +35,7 @@ class LeNet5(nn.Module): def __init__(self, args): - super(LeNet5, self).__init__() + super().__init__() self.model = nn.Sequential( nn.Conv2d(1, 16, 5), nn.ReLU(), diff --git a/examples/L2R/helpers/utils.py b/examples/L2R/helpers/utils.py index fe923860..7e95ca6f 100644 --- a/examples/L2R/helpers/utils.py +++ b/examples/L2R/helpers/utils.py @@ -89,16 +89,10 @@ def get_imbalance_dataset( y_val_subset = 
np.concatenate([np.zeros([x_val_0.shape[0]]), np.ones([x_val_1.shape[0]])]) y_test_subset = np.concatenate([np.zeros([x_test_0.shape[0]]), np.ones([x_test_1.shape[0]])]) - y_train_pos_subset = np.ones([x_train_1.shape[0]]) - y_train_neg_subset = np.zeros([x_train_0.shape[0]]) - x_train_subset = np.concatenate([x_train_0, x_train_1], axis=0)[:, None, :, :] x_val_subset = np.concatenate([x_val_0, x_val_1], axis=0)[:, None, :, :] x_test_subset = np.concatenate([x_test_0, x_test_1], axis=0)[:, None, :, :] - x_train_pos_subset = x_train_1[:, None, :, :] - x_train_neg_subset = x_train_0[:, None, :, :] - # Final shuffle. idx = np.arange(x_train_subset.shape[0]) np.random.shuffle(idx) @@ -146,7 +140,7 @@ def set_seed(seed, cudnn=True): torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) # note: the below slows down the code but makes it reproducible - # Sets the seed for generating random numbers on all GPUs. It’s safe to + # Sets the seed for generating random numbers on all GPUs. It's safe to # call this function if CUDA is not available; in that case, it is # silently ignored. torch.cuda.manual_seed_all(seed) @@ -157,7 +151,6 @@ def set_seed(seed, cudnn=True): def plot(baseline, l2r): import matplotlib.pyplot as plt - import numpy as np import seaborn as sns sns.set(style='darkgrid') diff --git a/examples/L2R/l2r.py b/examples/L2R/l2r.py index 5ce4839d..8866ea9b 100644 --- a/examples/L2R/l2r.py +++ b/examples/L2R/l2r.py @@ -51,14 +51,13 @@ def run_baseline(args, mnist_train, mnist_test): ntest = args.ntest epoch = args.epoch - folder = './result/baseline/' writer = SummaryWriter('./result/baseline') with open('./result/baseline/config.json', 'w') as f: json.dump(args.__dict__, f) args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - train_set, val_set, test_set = get_imbalance_dataset( + train_set, _, test_set = get_imbalance_dataset( mnist_train, mnist_test, pos_ratio=pos_ratio, @@ -67,7 +66,6 @@ def run_baseline(args, mnist_train, mnist_test): ntest=ntest, ) train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4) - valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1) test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1) model = LeNet5(args).to(args.device) @@ -91,7 +89,7 @@ def run_baseline(args, mnist_train, mnist_test): if step % 10 == 0 and step > 0: running_train_mean = np.mean(np.array(running_train_loss)) - print('EPOCH: {}, BATCH: {}, LOSS: {}'.format(_epoch, idx, running_train_mean)) + print(f'EPOCH: {_epoch}, BATCH: {idx}, LOSS: {running_train_mean}') writer.add_scalar('running_train_loss', running_train_mean, step) running_train_loss = [] @@ -106,7 +104,7 @@ def run_baseline(args, mnist_train, mnist_test): writer.add_scalar('train_acc', train_acc, _epoch) writer.add_scalar('test_acc', test_acc, _epoch) test_acc_result.append(test_acc) - print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc)) + print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}') return test_acc_result @@ -120,7 +118,6 @@ def run_L2R(args, mnist_train, mnist_test): ntest = args.ntest epoch = args.epoch - folder = './result/l2r/' writer = SummaryWriter('./result/l2r/log') with open('./result/l2r/config.json', 'w') as f: json.dump(args.__dict__, f) @@ -143,7 +140,6 @@ def run_L2R(args, mnist_train, mnist_test): real_model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) step = 0 - time_bp = 0 running_valid_loss = [] 
valid = iter(valid_loader) running_train_loss = [] @@ -222,7 +218,7 @@ def run_L2R(args, mnist_train, mnist_test): writer.add_scalar('train_acc', train_acc, _epoch) writer.add_scalar('test_acc', test_acc, _epoch) test_acc_result.append(test_acc) - print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc)) + print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}') return test_acc_result diff --git a/examples/LOLA/helpers/utils.py b/examples/LOLA/helpers/utils.py index afa9e609..149ce70e 100644 --- a/examples/LOLA/helpers/utils.py +++ b/examples/LOLA/helpers/utils.py @@ -27,7 +27,7 @@ def step(ipd, theta1, theta2, values1, values2, args): (s1, s2), _ = ipd.reset() score1 = 0 score2 = 0 - for t in range(args.len_rollout): + for _ in range(args.len_rollout): a1, lp1, v1 = act(s1, theta1, values1) a2, lp2, v2 = act(s2, theta2, values2) (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) @@ -109,7 +109,7 @@ def sample(ipd, policy, value, args): (s1, s2), _ = ipd.reset() memory_agent1 = Memory(args) memory_agent2 = Memory(args) - for t in range(args.len_rollout): + for _ in range(args.len_rollout): a1, lp1, v1 = act(s1, theta1, value1) a2, lp2, v2 = act(s2, theta2, value2) (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py index 20c0ff0e..6dbaaf24 100644 --- a/examples/LOLA/lola_dice.py +++ b/examples/LOLA/lola_dice.py @@ -96,17 +96,16 @@ def main(args): score = step(ipd, agent1.theta, agent2.theta, agent1.values, agent2.values, args) joint_scores.append(0.5 * (score[0] + score[1])) - # print if update % 10 == 0: p1 = [p.item() for p in torch.sigmoid(agent1.theta)] p2 = [p.item() for p in torch.sigmoid(agent2.theta)] print( 'update', update, - 'score (%.3f,%.3f)' % (score[0], score[1]), + f'score ({score[0]:.3f},{score[1]:.3f})', 'policy (agent1) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p1[0], p1[1], p1[2], p1[3], p1[4]), - ' (agent2) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p2[0], p2[1], p2[2], p2[3], p2[4]), + f' (agent2) = {{{p2[0]:.3f}, {p2[1]:.3f}, {p2[2]:.3f}, {p2[3]:.3f}, {p2[4]:.3f}}}', ) return joint_scores @@ -114,7 +113,7 @@ def main(args): if __name__ == '__main__': args = parse_args() - joint_score = dict() + joint_score = {} for nla in range(3): args.n_lookaheads = nla joint_score[nla] = main(args) diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py index 2534caeb..02f6792d 100644 --- a/examples/MAML-RL/func_maml.py +++ b/examples/MAML-RL/func_maml.py @@ -103,9 +103,10 @@ def evaluate(env, seed, task_num, fpolicy, params): inner_opt = torchopt.MetaSGD(lr=0.5) env = gym.make( 'TabularMDP-v0', - **dict( - num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed - ), + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, ) tasks = env.sample_tasks(num_tasks=task_num) @@ -131,9 +132,10 @@ def main(args): # Env env = gym.make( 'TabularMDP-v0', - **dict( - num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed - ), + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, ) # Policy policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM) diff --git a/examples/MAML-RL/helpers/policy_torchrl.py b/examples/MAML-RL/helpers/policy_torchrl.py index 103a4ec5..91bdb269 100644 --- a/examples/MAML-RL/helpers/policy_torchrl.py +++ b/examples/MAML-RL/helpers/policy_torchrl.py @@ -13,9 +13,7 @@ # limitations under the 
License. # ============================================================================== -import torch import torch.nn as nn -from torch.distributions import Categorical from torchrl.modules import ( ActorValueOperator, OneHotCategorical, diff --git a/examples/MAML-RL/helpers/tabular_mdp.py b/examples/MAML-RL/helpers/tabular_mdp.py index 3a6bee60..3310cd1e 100644 --- a/examples/MAML-RL/helpers/tabular_mdp.py +++ b/examples/MAML-RL/helpers/tabular_mdp.py @@ -93,7 +93,6 @@ def reset_task(self, task): self._rewards_mean = task['rewards_mean'] def reset(self): - # From [1]: "an episode always starts on the first state" self._state = 0 observation = np.zeros(self.num_states, dtype=np.float32) observation[self._state] = 1.0 @@ -112,8 +111,5 @@ def step(self, action): observation = np.zeros(self.num_states, dtype=np.float32) observation[self._state] = 1.0 self._elapsed_steps += 1 - if self._elapsed_steps >= self.max_episode_steps: - done = True - else: - done = False + done = self._elapsed_steps >= self.max_episode_steps return observation, reward, done, {'task': self._task} diff --git a/examples/MAML-RL/maml.py b/examples/MAML-RL/maml.py index d4aa8c3c..336cb4a5 100644 --- a/examples/MAML-RL/maml.py +++ b/examples/MAML-RL/maml.py @@ -108,12 +108,10 @@ def evaluate(env, seed, task_num, policy): inner_opt = torchopt.MetaSGD(policy, lr=0.1) env = gym.make( 'TabularMDP-v0', - **dict( - num_states=STATE_DIM, - num_actions=ACTION_DIM, - max_episode_steps=TRAJ_LEN, - seed=args.seed, - ), + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, ) tasks = env.sample_tasks(num_tasks=task_num) policy_state_dict = torchopt.extract_state_dict(policy) @@ -141,12 +139,10 @@ def main(args): # Env env = gym.make( 'TabularMDP-v0', - **dict( - num_states=STATE_DIM, - num_actions=ACTION_DIM, - max_episode_steps=TRAJ_LEN, - seed=args.seed, - ), + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + seed=args.seed, ) # Policy policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM) diff --git a/examples/MAML-RL/maml_torchrl.py b/examples/MAML-RL/maml_torchrl.py index 3cb72b49..8f68c89c 100644 --- a/examples/MAML-RL/maml_torchrl.py +++ b/examples/MAML-RL/maml_torchrl.py @@ -14,9 +14,7 @@ # ============================================================================== import argparse -import time -import numpy as np import torch import torch.optim as optim import tqdm @@ -60,8 +58,6 @@ def a2c_loss(traj, policy_module, value_module, value_coef): next_traj = step_tensordict(traj) next_value = value_module(next_traj).get('state_value').detach() - # tderror = TDEstimate(GAMMA, value_module, gradient_mode=True) - # tderror = TDLambdaEstimate(GAMMA, LAMBDA, value_module, gradient_mode=True) advantage = td_lambda_advantage_estimate(GAMMA, LAMBDA, value, next_value, reward, done) action_loss = -(advantage.detach() * log_probs.view_as(advantage)).mean() value_error = advantage @@ -131,14 +127,17 @@ def main(args): # init training torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) + # Env - lambda_env = lambda: GymEnv( - 'TabularMDP-v0', - num_states=STATE_DIM, - num_actions=ACTION_DIM, - max_episode_steps=TRAJ_LEN, - device=device, - ) + def lambda_env(): + return GymEnv( + 'TabularMDP-v0', + num_states=STATE_DIM, + num_actions=ACTION_DIM, + max_episode_steps=TRAJ_LEN, + device=device, + ) + if args.parallel: env = ParallelEnv( NUM_ENVS, @@ -171,8 +170,7 @@ def main(args): dummy_env.set_seed(args.seed) pbar = 
tqdm.tqdm(range(outer_iters)) - for i in pbar: - # print("i: ", i) + for _ in pbar: tasks = dummy_env.sample_tasks(num_tasks=TASK_NUM) train_pre_reward_ls = [] train_post_reward_ls = [] @@ -184,7 +182,7 @@ def main(args): env.reset_task(tasks[idx]) policy_module = actor_critic_module.get_policy_operator() value_module = actor_critic_module.get_value_operator() - for k in range(inner_iters): + for __ in range(inner_iters): with set_exploration_mode('random'), torch.no_grad(): pre_traj_td = ( env.rollout( diff --git a/examples/MGRL/mgrl.py b/examples/MGRL/mgrl.py index 152e4177..49eb79c4 100644 --- a/examples/MGRL/mgrl.py +++ b/examples/MGRL/mgrl.py @@ -55,7 +55,7 @@ def forward(self, x): meta_optimizer = torchopt.SGD([gamma], lr=5e-1) net_state = torchopt.extract_state_dict(net) for i in range(outer_iters): - for j in range(inner_iters): + for _ in range(inner_iters): trajectory, state = Rollout.get() backup = Rollout.rollout(trajectory, torch.sigmoid(gamma)) pred_value = net(state.float()) diff --git a/examples/distributed/few-shot/helpers/omniglot_loaders.py b/examples/distributed/few-shot/helpers/omniglot_loaders.py index e8f02042..8ff7e0c2 100644 --- a/examples/distributed/few-shot/helpers/omniglot_loaders.py +++ b/examples/distributed/few-shot/helpers/omniglot_loaders.py @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for root, dirs, files in os.walk(root_dir): + for root, _, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -171,7 +171,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} temp = {} for img, label in self.x: - if label in temp.keys(): + if label in temp: temp[label].append(img) else: temp[label] = [img] @@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non self.x = np.load(os.path.join(root, 'omniglot.npy')) print('load from omniglot.npy.') - # [1623, 20, 84, 84, 1] - # TODO: can not shuffle here, we must keep training and test set distinct! + # NOTE: do not shuffle here, we must keep training and test set distinct! 
self.x_train, self.x_test = self.x[:1200], self.x[1200:] - # self.normalization() - self.batchsz = batchsz self.n_cls = self.x.shape[0] # 1623 self.n_way = n_way # n way @@ -230,7 +227,6 @@ def normalization(self): self.std = np.std(self.x_train) self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) self.x_train = (self.x_train - self.mean) / self.std self.x_test = (self.x_test - self.mean) / self.std @@ -239,8 +235,6 @@ def normalization(self): self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) - def load_data_cache(self, data_pack): """ Collects several batches data for N-shot learning @@ -253,10 +247,9 @@ def load_data_cache(self, data_pack): querysz = self.k_query * self.n_way data_cache = [] - # print('preload next 50 caches of batchsz of batch.') - for sample in range(10): # num of episodes + for _sample in range(10): # num of episodes x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] - for i in range(self.batchsz): # one batch means one set + for _ in range(self.batchsz): # one batch means one set x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) @@ -287,20 +280,20 @@ def load_data_cache(self, data_pack): x_qrys.append(x_qry) y_qrys.append(y_qry) - # [b, setsz, 1, 84, 84] x_spts = np.array(x_spts, dtype=np.float32).reshape( self.batchsz, setsz, 1, self.resize, self.resize - ) - y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz) - # [b, qrysz, 1, 84, 84] + ) # [b, setsz, 1, 84, 84] + y_spts = np.array(y_spts, dtype=np.int).reshape( + self.batchsz, setsz + ) # [b, qrysz, 1, 84, 84] x_qrys = np.array(x_qrys, dtype=np.float32).reshape( self.batchsz, querysz, 1, self.resize, self.resize ) y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) - x_spts, y_spts, x_qrys, y_qrys = [ + x_spts, y_spts, x_qrys, y_qrys = ( torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] - ] + ) data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) diff --git a/examples/few-shot/helpers/omniglot_loaders.py b/examples/few-shot/helpers/omniglot_loaders.py index e8f02042..8ff7e0c2 100644 --- a/examples/few-shot/helpers/omniglot_loaders.py +++ b/examples/few-shot/helpers/omniglot_loaders.py @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for root, dirs, files in os.walk(root_dir): + for root, _, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -171,7 +171,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total} temp = {} for img, label in self.x: - if label in temp.keys(): + if label in temp: temp[label].append(img) else: temp[label] = [img] @@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non self.x = np.load(os.path.join(root, 'omniglot.npy')) print('load from omniglot.npy.') - # [1623, 20, 84, 84, 1] - # TODO: can not shuffle here, we must keep training and test set distinct! + # NOTE: do not shuffle here, we must keep training and test set distinct! 
self.x_train, self.x_test = self.x[:1200], self.x[1200:] - # self.normalization() - self.batchsz = batchsz self.n_cls = self.x.shape[0] # 1623 self.n_way = n_way # n way @@ -230,7 +227,6 @@ def normalization(self): self.std = np.std(self.x_train) self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) self.x_train = (self.x_train - self.mean) / self.std self.x_test = (self.x_test - self.mean) / self.std @@ -239,8 +235,6 @@ def normalization(self): self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) - def load_data_cache(self, data_pack): """ Collects several batches data for N-shot learning @@ -253,10 +247,9 @@ def load_data_cache(self, data_pack): querysz = self.k_query * self.n_way data_cache = [] - # print('preload next 50 caches of batchsz of batch.') - for sample in range(10): # num of episodes + for _sample in range(10): # num of episodes x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] - for i in range(self.batchsz): # one batch means one set + for _ in range(self.batchsz): # one batch means one set x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) @@ -287,20 +280,20 @@ def load_data_cache(self, data_pack): x_qrys.append(x_qry) y_qrys.append(y_qry) - # [b, setsz, 1, 84, 84] x_spts = np.array(x_spts, dtype=np.float32).reshape( self.batchsz, setsz, 1, self.resize, self.resize - ) - y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz) - # [b, qrysz, 1, 84, 84] + ) # [b, setsz, 1, 84, 84] + y_spts = np.array(y_spts, dtype=np.int).reshape( + self.batchsz, setsz + ) # [b, qrysz, 1, 84, 84] x_qrys = np.array(x_qrys, dtype=np.float32).reshape( self.batchsz, querysz, 1, self.resize, self.resize ) y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) - x_spts, y_spts, x_qrys, y_qrys = [ + x_spts, y_spts, x_qrys, y_qrys = ( torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] - ] + ) data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) diff --git a/examples/few-shot/maml_omniglot.py b/examples/few-shot/maml_omniglot.py index 17172bdd..2e031c76 100644 --- a/examples/few-shot/maml_omniglot.py +++ b/examples/few-shot/maml_omniglot.py @@ -204,7 +204,7 @@ def test(db, net, epoch, log): qry_losses = [] qry_accs = [] - for batch_idx in range(n_test_iter): + for _ in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') task_num = x_spt.size(0) diff --git a/examples/iMAML/helpers/omniglot_loaders.py b/examples/iMAML/helpers/omniglot_loaders.py index e8f02042..8ff7e0c2 100644 --- a/examples/iMAML/helpers/omniglot_loaders.py +++ b/examples/iMAML/helpers/omniglot_loaders.py @@ -118,7 +118,7 @@ def download(self): def find_classes(root_dir): retour = [] - for root, dirs, files in os.walk(root_dir): + for root, _, files in os.walk(root_dir): for f in files: if f.endswith('png'): r = root.split('/') @@ -171,7 +171,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 
1623 labels in total} temp = {} for img, label in self.x: - if label in temp.keys(): + if label in temp: temp[label].append(img) else: temp[label] = [img] @@ -196,12 +196,9 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=Non self.x = np.load(os.path.join(root, 'omniglot.npy')) print('load from omniglot.npy.') - # [1623, 20, 84, 84, 1] - # TODO: can not shuffle here, we must keep training and test set distinct! + # NOTE: do not shuffle here, we must keep training and test set distinct! self.x_train, self.x_test = self.x[:1200], self.x[1200:] - # self.normalization() - self.batchsz = batchsz self.n_cls = self.x.shape[0] # 1623 self.n_way = n_way # n way @@ -230,7 +227,6 @@ def normalization(self): self.std = np.std(self.x_train) self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) self.x_train = (self.x_train - self.mean) / self.std self.x_test = (self.x_test - self.mean) / self.std @@ -239,8 +235,6 @@ def normalization(self): self.max = np.max(self.x_train) self.min = np.min(self.x_train) - # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) - def load_data_cache(self, data_pack): """ Collects several batches data for N-shot learning @@ -253,10 +247,9 @@ def load_data_cache(self, data_pack): querysz = self.k_query * self.n_way data_cache = [] - # print('preload next 50 caches of batchsz of batch.') - for sample in range(10): # num of episodes + for _sample in range(10): # num of episodes x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] - for i in range(self.batchsz): # one batch means one set + for _ in range(self.batchsz): # one batch means one set x_spt, y_spt, x_qry, y_qry = [], [], [], [] selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False) @@ -287,20 +280,20 @@ def load_data_cache(self, data_pack): x_qrys.append(x_qry) y_qrys.append(y_qry) - # [b, setsz, 1, 84, 84] x_spts = np.array(x_spts, dtype=np.float32).reshape( self.batchsz, setsz, 1, self.resize, self.resize - ) - y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz) - # [b, qrysz, 1, 84, 84] + ) # [b, setsz, 1, 84, 84] + y_spts = np.array(y_spts, dtype=np.int).reshape( + self.batchsz, setsz + ) # [b, qrysz, 1, 84, 84] x_qrys = np.array(x_qrys, dtype=np.float32).reshape( self.batchsz, querysz, 1, self.resize, self.resize ) y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz) - x_spts, y_spts, x_qrys, y_qrys = [ + x_spts, y_spts, x_qrys, y_qrys = ( torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys] - ] + ) data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py index 09344900..fe6608e1 100644 --- a/examples/iMAML/imaml_omniglot.py +++ b/examples/iMAML/imaml_omniglot.py @@ -222,7 +222,7 @@ def test(db, net, epoch, log, args): n_inner_iter = args.inner_steps reg_param = args.reg_params - for batch_idx in range(n_test_iter): + for _ in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') task_num = x_spt.size(0) diff --git a/examples/iMAML/imaml_omniglot_functional.py b/examples/iMAML/imaml_omniglot_functional.py index 1c0a089a..7985f6e8 100644 --- a/examples/iMAML/imaml_omniglot_functional.py +++ b/examples/iMAML/imaml_omniglot_functional.py @@ -196,7 +196,7 @@ def test(db, model, epoch, log, args): qry_losses = [] qry_accs = [] - for batch_idx in range(n_test_iter): + for _ in 
range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') task_num = x_spt.size(0) diff --git a/pyproject.toml b/pyproject.toml index fddc2b26..1dd131a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ lint = [ "flake8-docstrings", "flake8-pyi", "flake8-simplify", + "ruff", "doc8 < 1.0.0a0", # unpin this when we drop support for Python 3.7 "pydocstyle[toml]", "pyenchant", @@ -213,6 +214,80 @@ convention = "google" [tool.doc8] max-line-length = 500 +[tool.ruff] +# Sync with requires-python +target-version = "py37" +line-length = 100 +show-source = true +src = ["torchopt", "tests"] +extend-exclude = ["examples", "tests"] +select = [ + "E", "W", # pycodestyle + "F", # pyflakes + "UP", # pyupgrade + "ANN", # flake8-annotations + "S", # flake8-bandit + "BLE", # flake8-blind-except + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "EXE", # flake8-executable + "ISC", # flake8-implicit-str-concat + "PIE", # flake8-pie + "PYI", # flake8-pyi + "RSE", # flake8-raise + "RET", # flake8-return + "SIM", # flake8-simplify + "TID", # flake8-tidy-imports + "RUF", # ruff +] +ignore = [ + # E501: line too long + # W505: doc line too long + # too long docstring due to long example blocks + "E501", + "W505", + # ANN101: missing type annotation for `self` in method + # ANN102: missing type annotation for `cls` in classmethod + "ANN101", + "ANN102", + # ANN401: dynamically typed expressions (typing.Any) are disallowed + "ANN401", + # S101: use of `assert` detected + # internal use and may never raise at runtime + "S101", + # PLR0402: use from {module} import {name} in lieu of alias + # use alias for import convention (e.g., `import torch.nn as nn`) + "PLR0402", +] +typing-modules = ["torchopt.typing"] + +[tool.ruff.per-file-ignores] +"__init__.py" = [ + "F401", # unused-import +] +"torchopt/pytree.py" = [ + "F401", # unused-import + "F403", # import-star + "F405", # import-star-usage +] +"setup.py" = [ + "ANN", # flake8-annotations +] + +[tool.ruff.flake8-annotations] +allow-star-arg-any = true + +[tool.ruff.flake8-quotes] +docstring-quotes = "double" +multiline-quotes = "double" +inline-quotes = "single" + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.pylint] +allow-magic-value-types = ["int", "str", "float"] + [tool.pytest.ini_options] filterwarnings = [ "error", diff --git a/setup.py b/setup.py index 2142e96a..2571eb7f 100644 --- a/setup.py +++ b/setup.py @@ -85,9 +85,9 @@ def build_extension(self, ext): try: os.chdir(build_temp) - self.spawn([cmake, ext.source_dir] + cmake_args) + self.spawn([cmake, ext.source_dir, *cmake_args]) if not self.dry_run: - self.spawn([cmake, '--build', '.'] + build_args) + self.spawn([cmake, '--build', '.', *build_args]) finally: os.chdir(HERE) diff --git a/tests/requirements.txt b/tests/requirements.txt index 5d424001..b0fa5e51 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -20,6 +20,7 @@ flake8-comprehensions flake8-docstrings flake8-pyi flake8-simplify +ruff # https://github.com/PyCQA/doc8/issues/112 doc8 < 1.0.0a0 pydocstyle[toml] diff --git a/torchopt/accelerated_op/__init__.py b/torchopt/accelerated_op/__init__.py index ede60009..3ac943e3 100644 --- a/torchopt/accelerated_op/__init__.py +++ b/torchopt/accelerated_op/__init__.py @@ -43,5 +43,5 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool: updates = torch.tensor(1.0, device=device) op(updates, updates, updates, 1) return True - except Exception: # pylint: disable=broad-except + except Exception: # noqa: BLE001 
# pylint: disable=broad-except return False diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index 08654577..be58e49e 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -96,24 +96,21 @@ def adam( """ b1, b2 = betas # pylint: disable=invalid-name # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): # pragma: no cover + if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') if not 0.0 <= b2 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: # pragma: no cover + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not chain_fn = chain flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay - if use_accelerated_op: - adam_scaler_fn = scale_by_accelerated_adam - else: - adam_scaler_fn = scale_by_adam + adam_scaler_fn = scale_by_accelerated_adam if use_accelerated_op else scale_by_adam scale_by_neg_lr_fn = scale_by_neg_lr if _get_use_chain_flat(): # default behavior diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index 21ef84ef..4503381c 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -109,24 +109,21 @@ def adamw( """ b1, b2 = betas # pylint: disable=invalid-name # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): # pragma: no cover + if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') if not 0.0 <= b2 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 1: {b2}') - if not 0.0 <= weight_decay: # pragma: no cover + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not chain_fn = chain flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay - if use_accelerated_op: - adam_scaler_fn = scale_by_accelerated_adam - else: - adam_scaler_fn = scale_by_adam + adam_scaler_fn = scale_by_accelerated_adam if use_accelerated_op else scale_by_adam add_decayed_weights_fn = add_decayed_weights scale_by_neg_lr_fn = scale_by_neg_lr diff --git a/torchopt/alias/rmsprop.py b/torchopt/alias/rmsprop.py index f0eb92cd..96092548 100644 --- a/torchopt/alias/rmsprop.py +++ b/torchopt/alias/rmsprop.py @@ -96,24 +96,21 @@ def rmsprop( The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. 
""" # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): # pragma: no cover + if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= alpha: # pragma: no cover + if not alpha >= 0.0: # pragma: no cover raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= momentum: # pragma: no cover + if not momentum >= 0.0: # pragma: no cover raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: # pragma: no cover + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not chain_fn = chain flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay - if centered: - rmsprop_scaler_fn = scale_by_stddev - else: - rmsprop_scaler_fn = scale_by_rms + rmsprop_scaler_fn = scale_by_stddev if centered else scale_by_rms trace_fn = trace scale_by_neg_lr_fn = scale_by_neg_lr diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 7d86b538..c2d37292 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -85,11 +85,11 @@ def sgd( The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. """ # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): # pragma: no cover + if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= momentum: # pragma: no cover + if not momentum >= 0.0: # pragma: no cover raise ValueError(f'Invalid momentum value: {momentum}') - if not 0.0 <= weight_decay: # pragma: no cover + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover raise ValueError('Nesterov momentum requires a momentum and zero dampening') diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index b5088164..751446e7 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -17,11 +17,13 @@ import threading +import torch + from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform import scale, scale_by_schedule from torchopt.transform.utils import tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, ScalarOrSchedule, Updates +from torchopt.typing import Numeric, OptState, Params, ScalarOrSchedule, Updates __all__ = ['flip_sign_and_add_weight_decay', 'scale_by_neg_lr'] @@ -43,7 +45,7 @@ def _get_use_chain_flat() -> bool: # only used for testing purposes def flip_sign_and_add_weight_decay( - weight_decay: float = 0.0, maximize=False + weight_decay: float = 0.0, maximize: bool = False ) -> GradientTransformation: """Flip the sign of the updates and adds weight decay.""" return _flip_sign_and_add_weight_decay( @@ -54,7 +56,7 @@ def flip_sign_and_add_weight_decay( def _flip_sign_and_add_weight_decay_flat( - weight_decay: float = 0.0, maximize=False + weight_decay: float = 0.0, maximize: bool = False ) -> GradientTransformation: """Flip the sign of the updates and adds weight decay.""" return _flip_sign_and_add_weight_decay( @@ -66,13 +68,13 @@ def _flip_sign_and_add_weight_decay_flat( def _flip_sign_and_add_weight_decay( weight_decay: float = 0.0, - maximize=False, + maximize: bool = False, *, already_flattened: bool = False, ) -> GradientTransformation: 
"""Flip the sign of the updates and adds weight decay.""" # pylint: disable-next=unneeded-not - if not 0.0 <= weight_decay: # pragma: no cover + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') if not maximize and weight_decay == 0.0: @@ -104,7 +106,7 @@ def update_fn( if inplace: - def f(g, p): + def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) @@ -113,7 +115,7 @@ def f(g, p): else: - def f(g, p): + def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return g.add(p, alpha=weight_decay) updates = tree_map(f, updates, params) @@ -132,14 +134,14 @@ def update_fn( ) -> tuple[Updates, OptState]: if inplace: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.neg_() updates = tree_map_(f, updates) else: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.neg() updates = tree_map(f, updates) @@ -162,7 +164,7 @@ def update_fn( if inplace: - def f(g, p): + def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: if g.requires_grad: return g.neg_().add_(p, alpha=weight_decay) return g.neg_().add_(p.data, alpha=weight_decay) @@ -171,7 +173,7 @@ def f(g, p): else: - def f(g, p): + def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return g.neg().add_(p, alpha=weight_decay) updates = tree_map(f, updates, params) @@ -194,13 +196,17 @@ def _scale_by_neg_lr_flat(lr: ScalarOrSchedule) -> GradientTransformation: return _scale_by_neg_lr(lr=lr, already_flattened=True) -def _scale_by_neg_lr(lr: ScalarOrSchedule, *, already_flattened=False) -> GradientTransformation: - if not (callable(lr) or 0.0 <= lr): # pragma: no cover +def _scale_by_neg_lr( + lr: ScalarOrSchedule, + *, + already_flattened: bool = False, +) -> GradientTransformation: + if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') if callable(lr): - def schedule_wrapper(count): + def schedule_wrapper(count: Numeric) -> Numeric: return -lr(count) # type: ignore[operator] return scale_by_schedule.impl( # type: ignore[attr-defined] diff --git a/torchopt/base.py b/torchopt/base.py index b250c387..dd0bd925 100644 --- a/torchopt/base.py +++ b/torchopt/base.py @@ -236,7 +236,7 @@ def __reduce__(self) -> tuple[Callable, tuple[tuple[GradientTransformation, ...] 
class IdentityGradientTransformation(GradientTransformation): """A gradient transformation that does nothing.""" - def __new__(cls): + def __new__(cls) -> IdentityGradientTransformation: """Create a new gradient transformation that does nothing.""" return super().__new__(cls, init=cls.init_fn, update=cls.update_fn) diff --git a/torchopt/clip.py b/torchopt/clip.py index b2aafb48..2bf8237e 100644 --- a/torchopt/clip.py +++ b/torchopt/clip.py @@ -88,12 +88,12 @@ def update_fn( clip_coefficient_clamped = min(clip_coefficient, 1.0) if inplace: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.mul_(clip_coefficient_clamped) else: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.mul(clip_coefficient_clamped) new_updates = pytree.tree_map(f, updates) diff --git a/torchopt/combine.py b/torchopt/combine.py index 0f1ed8ec..fc1a7152 100644 --- a/torchopt/combine.py +++ b/torchopt/combine.py @@ -74,10 +74,7 @@ def chain_flat(*transformations: GradientTransformation) -> GradientTransformati """ if len(transformations) == 0: return identity() - if len(transformations) == 1: - inner = transformations[0] - else: - inner = chain(*transformations) + inner = transformations[0] if len(transformations) == 1 else chain(*transformations) def init_fn(params: Params) -> OptState: return inner.init(pytree.tree_leaves(params, none_is_leaf=True)) @@ -90,10 +87,7 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: flat_updates, treespec = pytree.tree_flatten(updates, none_is_leaf=True) - if params is not None: - flat_params = pytree.tree_leaves(params, none_is_leaf=True) - else: - flat_params = None + flat_params = pytree.tree_leaves(params, none_is_leaf=True) if params is not None else None flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace) updates: Updates diff --git a/torchopt/diff/implicit/decorator.py b/torchopt/diff/implicit/decorator.py index 8b70a35a..4f042214 100644 --- a/torchopt/diff/implicit/decorator.py +++ b/torchopt/diff/implicit/decorator.py @@ -131,10 +131,7 @@ def matvec(u: TupleOfTensors) -> TupleOfTensors: ) output: TupleOfTensors - if output_is_tensor: - output = optimality_vjp_fn(u[0]) - else: - output = optimality_vjp_fn(u) + output = optimality_vjp_fn(u[0]) if output_is_tensor else optimality_vjp_fn(u) # Prepend None as the vjp for init_params. 
true_output: ListOfOptionalTensors = [None] @@ -179,7 +176,7 @@ def _signature_bind_and_match( mapping = [(was_kwarg, ref) for was_kwarg, ref, _ in bound.args] - def map_args_back(out_args): + def map_args_back(out_args: Args) -> tuple[Args, KwArgs]: src_args = [None] * len(args) src_kwargs = {} for (was_kwarg, ref), out_arg in zip(mapping, out_args): @@ -187,7 +184,7 @@ def map_args_back(out_args): src_kwargs[ref] = out_arg else: src_args[ref] = out_arg - return src_args, src_kwargs + return tuple(src_args), src_kwargs out_args = tuple(v for _, _, v in bound.args) out_kwargs = {k: v for k, (_, _, v) in bound.kwargs.items()} @@ -349,7 +346,7 @@ def backward( # pylint: disable=too-many-locals ) args_vjps, kwargs_vjps = map_args_back(vjps) - ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs.keys()) + ordered_vjps = tuple(args_vjps) + tuple(kwargs_vjps[k] for k in kwargs) true_vjps = [] for (_, _, arg_seq_type), vjp in zip(args_signs, ordered_vjps): if arg_seq_type is not None: @@ -399,10 +396,7 @@ def wrapped_solver_fn( result = make_custom_vjp_solver_fn(solver_fn, keys, args_signs).apply(*flat_args, *vals) *output, aux, output_is_tensor, output_type = result - if output_is_tensor: - output = output[0] - else: - output = output_type(output) + output = output[0] if output_is_tensor else output_type(output) if has_aux: return output, aux return output diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index bbae37c9..2aceb656 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -41,8 +41,8 @@ def _stateless_objective_fn( __params_names: Iterable[str], __meta_params_names: Iterable[str], self: ImplicitMetaGradientModule, - *input, - **kwargs, + *input: Any, + **kwargs: Any, ) -> torch.Tensor: with reparametrize( self, @@ -60,8 +60,8 @@ def _stateless_optimality_fn( __params_names: Iterable[str], __meta_params_names: Iterable[str], self: ImplicitMetaGradientModule, - *input, - **kwargs, + *input: Any, + **kwargs: Any, ) -> TupleOfTensors: with reparametrize( self, @@ -83,12 +83,12 @@ def make_optimality_from_objective( ): raise TypeError('The objective function is not defined.') - def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTensors: + def optimality(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> TupleOfTensors: params_names, flat_params = tuple(zip(*self.named_parameters())) meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) objective_grad_fn = functorch.grad(_stateless_objective_fn, argnums=0) - flat_grads = objective_grad_fn( + return objective_grad_fn( flat_params, flat_meta_params, params_names, @@ -97,7 +97,6 @@ def optimality(self: ImplicitMetaGradientModule, *input, **kwargs) -> TupleOfTen *input, **kwargs, ) - return flat_grads cls.optimality = optimality # type: ignore[assignment] return cls @@ -111,10 +110,7 @@ def enable_implicit_gradients( if getattr(cls_solve, '__implicit_gradients_enabled__', False): raise TypeError('Implicit gradients are already enabled for the `solve` method.') - if cls.linear_solve is not None: - solve_kwargs = {'solve': cls.linear_solve} - else: - solve_kwargs = {} + solve_kwargs = {'solve': cls.linear_solve} if cls.linear_solve is not None else {} @custom_root(_stateless_optimality_fn, argnums=1, has_aux=True, **solve_kwargs) def stateless_solver_fn( @@ -125,8 +121,8 @@ def stateless_solver_fn( __meta_params_names: Iterable[str], # pylint: enable=unused-argument self: 
ImplicitMetaGradientModule, - *input, - **kwargs, + *input: Any, + **kwargs: Any, ) -> tuple[TupleOfTensors, Any]: """Solve the optimization problem.""" output = cls_solve(self, *input, **kwargs) @@ -134,7 +130,7 @@ def stateless_solver_fn( return flat_optimal_params, output @functools.wraps(cls_solve) - def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any: + def wrapped(self: ImplicitMetaGradientModule, *input: Any, **kwargs: Any) -> Any: """Solve the optimization problem.""" params_names, flat_params = tuple(zip(*self.named_parameters())) meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())) @@ -197,7 +193,7 @@ def __init_subclass__(cls, linear_solve: LinearSolver | None = None) -> None: enable_implicit_gradients(cls) @abc.abstractmethod - def solve(self, *input, **kwargs) -> Any: + def solve(self, *input: Any, **kwargs: Any) -> Any: """Solve the inner optimization problem. .. warning:: @@ -224,7 +220,7 @@ def solve(self, batch, labels): """ raise NotImplementedError # update parameters - def optimality(self, *input, **kwargs) -> TupleOfTensors: + def optimality(self, *input: Any, **kwargs: Any) -> TupleOfTensors: r"""Compute the optimality residual. This method stands for the optimality residual to the optimal parameters after solving the @@ -267,7 +263,7 @@ def optimality(self, *input, **kwargs) -> TupleOfTensors: """ # pylint: disable=line-too-long raise NotImplementedError - def objective(self, *input, **kwargs) -> torch.Tensor: + def objective(self, *input: Any, **kwargs: Any) -> torch.Tensor: """Compute the objective function value. This method is used to calculate the :meth:`optimality` if it is not implemented. diff --git a/torchopt/diff/zero_order/__init__.py b/torchopt/diff/zero_order/__init__.py index 5b85d03d..16e39f7c 100644 --- a/torchopt/diff/zero_order/__init__.py +++ b/torchopt/diff/zero_order/__init__.py @@ -16,6 +16,9 @@ import sys as _sys from types import ModuleType as _ModuleType +from typing import Any, Callable + +import torch from torchopt.diff.zero_order import nn from torchopt.diff.zero_order.decorator import zero_order @@ -26,7 +29,9 @@ class _CallableModule(_ModuleType): # pylint: disable=too-few-public-methods - def __call__(self, *args, **kwargs): + def __call__( + self, *args: Any, **kwargs: Any + ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: return self.zero_order(*args, **kwargs) diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index a8f712b9..2f677518 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -49,7 +49,7 @@ def _zero_order_naive( # pylint: disable=too-many-statements distribution: Samplable, argnums: tuple[int, ...], num_samples: int, - sigma: Numeric, + sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements @@ -112,15 +112,17 @@ def backward( # pylint: disable=too-many-locals args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] - def add_perturbation(tensor, noises): - return tensor.add(noises, alpha=sigma) + def add_perturbation( + tensor: torch.Tensor, noise: torch.Tensor | Numeric + ) -> torch.Tensor: + return tensor.add(noise, alpha=sigma) param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] 
flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) + add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] ] noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params @@ -150,7 +152,7 @@ def _zero_order_forward( # pylint: disable=too-many-statements distribution: Samplable, argnums: tuple[int, ...], num_samples: int, - sigma: Numeric, + sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements @@ -214,15 +216,15 @@ def backward( # pylint: disable=too-many-locals args: list[Any] = pytree.tree_unflatten(ctx.args_treespec, flat_args) # type: ignore[assignment] - def add_perturbation(tensor, noises): - return tensor.add(noises, alpha=sigma) + def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + return tensor.add(noise, alpha=sigma) param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] flat_noisy_params = [ - add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) + add_perturbation(t, n) for t, n in zip(flat_diff_params, noises) # type: ignore[arg-type] ] noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, flat_noisy_params @@ -253,7 +255,7 @@ def _zero_order_antithetic( # pylint: disable=too-many-statements distribution: Samplable, argnums: tuple[int, ...], num_samples: int, - sigma: Numeric, + sigma: float, ) -> Callable[..., torch.Tensor]: @functools.wraps(fn) def apply(*args: Any) -> torch.Tensor: # pylint: disable=too-many-statements @@ -295,7 +297,9 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> torch.Tensor: return output @staticmethod - def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals + def backward( # pylint: disable=too-many-locals + ctx: Any, *grad_outputs: Any + ) -> TupleOfOptionalTensors: saved_tensors = ctx.saved_tensors flat_diff_params = saved_tensors[: ctx.len_params] tensors = saved_tensors[ctx.len_params :] @@ -316,7 +320,9 @@ def backward(ctx: Any, *grad_outputs: Any): # pylint: disable=too-many-locals param_grads: ListOfTensors = [0.0 for _ in range(len(flat_diff_params))] # type: ignore[misc] - def get_output(add_perturbation_fn, noises) -> torch.Tensor: + def get_output( + add_perturbation_fn: Callable, noises: Sequence[torch.Tensor | Numeric] + ) -> torch.Tensor: flat_noisy_params = [ add_perturbation_fn(t, n, alpha=sigma) for t, n in zip(flat_diff_params, noises) @@ -332,7 +338,7 @@ def get_output(add_perturbation_fn, noises) -> torch.Tensor: for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] - output = get_output(torch.add, noises) - get_output(torch.sub, noises) + output = get_output(torch.add, noises) - get_output(torch.sub, noises) # type: ignore[arg-type] weighted_grad = grad_outputs[0].mul(output).mul_(0.5 / sigma) for i, noise in enumerate(noises): @@ -356,7 +362,7 @@ def zero_order( method: Method = 'naive', argnums: int | tuple[int, ...] = (0,), num_samples: int = 1, - sigma: Numeric = 1.0, + sigma: float = 1.0, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Return a decorator for applying zero-order differentiation. @@ -372,7 +378,7 @@ def zero_order( respect to. 
(default: :const:`0`) num_samples (int, optional): The number of sample to get the averaged estimated gradient. (default: :const:`1`) - sigma (float or Tensor, optional): The standard deviation of the perturbation. + sigma (float, optional): The standard deviation of the perturbation. (default: :const:`1.0`) Returns: diff --git a/torchopt/diff/zero_order/nn/module.py b/torchopt/diff/zero_order/nn/module.py index 6b84f83d..aa75890c 100644 --- a/torchopt/diff/zero_order/nn/module.py +++ b/torchopt/diff/zero_order/nn/module.py @@ -20,7 +20,7 @@ import abc import functools -from typing import Sequence +from typing import Any, Sequence import torch import torch.nn as nn @@ -37,7 +37,7 @@ def enable_zero_order_gradients( cls: type[ZeroOrderGradientModule], method: Method = 'naive', num_samples: int = 1, - sigma: Numeric = 1.0, + sigma: float = 1.0, ) -> type[ZeroOrderGradientModule]: """Enable zero-order gradient estimation for the :func:`forward` method.""" cls_forward = cls.forward @@ -47,15 +47,15 @@ def enable_zero_order_gradients( ) @functools.wraps(cls_forward) - def wrapped(self: ZeroOrderGradientModule, *input, **kwargs) -> torch.Tensor: + def wrapped(self: ZeroOrderGradientModule, *input: Any, **kwargs: Any) -> torch.Tensor: """Do the forward pass calculation.""" params_names, flat_params = tuple(zip(*self.named_parameters())) @zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma) def forward_fn( __flat_params: TupleOfTensors, - *input, - **kwargs, + *input: Any, + **kwargs: Any, ) -> torch.Tensor: with reparametrize(self, zip(params_names, __flat_params)): return cls_forward(self, *input, **kwargs) @@ -74,7 +74,7 @@ def __init_subclass__( # pylint: disable=arguments-differ cls, method: Method = 'naive', num_samples: int = 1, - sigma: Numeric = 1.0, + sigma: float = 1.0, ) -> None: """Validate and initialize the subclass.""" super().__init_subclass__() @@ -86,7 +86,7 @@ def __init_subclass__( # pylint: disable=arguments-differ ) @abc.abstractmethod - def forward(self, *args, **kwargs) -> torch.Tensor: + def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: """Do the forward pass of the model.""" raise NotImplementedError diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index b46ad67e..b2fd17cd 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -54,10 +54,7 @@ ] -if rpc.is_available(): - UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT -else: - UNSET_RPC_TIMEOUT = -1.0 +UNSET_RPC_TIMEOUT = rpc.api.UNSET_RPC_TIMEOUT if rpc.is_available() else -1.0 T = TypeVar('T') diff --git a/torchopt/distributed/autograd.py b/torchopt/distributed/autograd.py index 5f6d929f..6c005239 100644 --- a/torchopt/distributed/autograd.py +++ b/torchopt/distributed/autograd.py @@ -31,7 +31,7 @@ LOCK = Lock() -def is_available(): +def is_available() -> bool: """Check if distributed autograd module is available.""" return autograd.is_available() diff --git a/torchopt/distributed/world.py b/torchopt/distributed/world.py index 804d4b9d..a9821ee0 100644 --- a/torchopt/distributed/world.py +++ b/torchopt/distributed/world.py @@ -166,7 +166,7 @@ def wrapper(func: F) -> F: @record @functools.wraps(func) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: rpc.init_rpc( name=world_info.worker_name, rank=world_info.rank, @@ -193,7 +193,7 @@ def wrapper(func: F) -> F: world_rank = get_world_info().world_rank @functools.wraps(func) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: 
if inverse: if world_rank not in ranks: return func(*args, **kwargs) @@ -211,7 +211,7 @@ def on_rank(*ranks: int) -> Callable[[F], F]: return __on_ranks(ranks=ranks, inverse=False) -def not_on_rank(*ranks) -> Callable[[F], F]: +def not_on_rank(*ranks: int) -> Callable[[F], F]: """Return a decorator to mark a function to be executed only on non given ranks.""" return __on_ranks(ranks=ranks, inverse=True) diff --git a/torchopt/hook.py b/torchopt/hook.py index f188415c..6f6e1753 100644 --- a/torchopt/hook.py +++ b/torchopt/hook.py @@ -45,7 +45,7 @@ def hook(g: torch.Tensor) -> torch.Tensor: return hook -def register_hook(hook) -> GradientTransformation: +def register_hook(hook: Callable[[torch.Tensor], torch.Tensor | None]) -> GradientTransformation: """Stateless identity transformation that leaves input gradients untouched. This function passes through the *gradient updates* unchanged. @@ -64,7 +64,7 @@ def update_fn( params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, # pylint: disable=unused-argument ) -> tuple[Updates, OptState]: - def f(g): + def f(g: torch.Tensor) -> torch.utils.hooks.RemovableHandle: return g.register_hook(hook) pytree.tree_map_(f, updates) diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 5456f076..309384f3 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -70,12 +70,14 @@ def _cg_solve( b2 = tree_vdot_real(b, b) atol2 = max(rtol**2 * b2, atol**2) - def cond_fn(value): + def cond_fn(value: tuple[TensorTree, TensorTree, float, TensorTree, int]) -> bool: _, r, gamma, _, k = value rs = gamma if M is _identity else tree_vdot_real(r, r) return rs > atol2 and k < maxiter - def body_fn(value): + def body_fn( + value: tuple[TensorTree, TensorTree, float, TensorTree, int] + ) -> tuple[TensorTree, TensorTree, float, TensorTree, int]: x, r, gamma, p, k = value Ap = A(p) alpha = gamma / tree_vdot_real(p, Ap) @@ -129,9 +131,7 @@ def _isolve( ) isolve_solve = partial(_isolve_solve, x0=x0, rtol=rtol, atol=atol, maxiter=maxiter, M=M) - - x = isolve_solve(A, b) - return x + return isolve_solve(A, b) def cg( diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 4880a842..747ad3cf 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -112,7 +112,7 @@ def ns( return inv_A_hat_b -def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None): +def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch.Tensor: """Use Neumann Series iteration to solve ``A^{-1}``.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}') diff --git a/torchopt/linear_solve/cg.py b/torchopt/linear_solve/cg.py index 844c9407..fd44ad75 100644 --- a/torchopt/linear_solve/cg.py +++ b/torchopt/linear_solve/cg.py @@ -36,11 +36,11 @@ from __future__ import annotations import functools -from typing import Callable +from typing import Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_ridge_matvec -from torchopt.typing import TensorTree +from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_cg'] @@ -51,7 +51,7 @@ def _solve_cg( b: TensorTree, ridge: float | None = None, init: TensorTree | None = None, - **kwargs, + **kwargs: Any, ) -> TensorTree: """Solve ``A x = b`` using conjugate gradient. 
@@ -78,7 +78,7 @@ def _solve_cg( return linalg.cg(matvec, b, x0=init, **kwargs) -def solve_cg(**kwargs): +def solve_cg(**kwargs: Any) -> LinearSolver: """Return a solver function to solve ``A x = b`` using conjugate gradient. This assumes that ``A`` is a hermitian, positive definite matrix. diff --git a/torchopt/linear_solve/inv.py b/torchopt/linear_solve/inv.py index 399a0ef9..460a2463 100644 --- a/torchopt/linear_solve/inv.py +++ b/torchopt/linear_solve/inv.py @@ -36,13 +36,13 @@ from __future__ import annotations import functools -from typing import Callable +from typing import Any, Callable import torch from torchopt import linalg, pytree from torchopt.linear_solve.utils import make_ridge_matvec, materialize_matvec -from torchopt.typing import TensorTree +from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_inv'] @@ -53,7 +53,7 @@ def _solve_inv( b: TensorTree, ridge: float | None = None, ns: bool = False, - **kwargs, + **kwargs: Any, ) -> TensorTree: """Solve ``A x = b`` using matrix inversion. @@ -91,7 +91,7 @@ def _solve_inv( return tree_unravel(pytree.tree_map(torch.linalg.solve, A, tree_ravel(b))) -def solve_inv(**kwargs): +def solve_inv(**kwargs: Any) -> LinearSolver: """Return a solver function to solve ``A x = b`` using matrix inversion. If ``ns = False``, this assumes the matrix ``A`` is a constant matrix and will materialize it diff --git a/torchopt/linear_solve/normal_cg.py b/torchopt/linear_solve/normal_cg.py index 8d38f77a..05136509 100644 --- a/torchopt/linear_solve/normal_cg.py +++ b/torchopt/linear_solve/normal_cg.py @@ -36,11 +36,11 @@ from __future__ import annotations import functools -from typing import Callable +from typing import Any, Callable from torchopt import linalg from torchopt.linear_solve.utils import make_normal_matvec, make_ridge_matvec, make_rmatvec -from torchopt.typing import TensorTree +from torchopt.typing import LinearSolver, TensorTree __all__ = ['solve_normal_cg'] @@ -51,7 +51,7 @@ def _solve_normal_cg( b: TensorTree, ridge: float | None = None, init: TensorTree | None = None, - **kwargs, + **kwargs: Any, ) -> TensorTree: """Solve the normal equation ``A^T A x = A^T b`` using conjugate gradient. @@ -71,10 +71,7 @@ def _solve_normal_cg( Returns: The solution with the same structure as ``b``. """ - if init is None: - example_x = b # This assumes that matvec is a square linear operator. - else: - example_x = init + example_x = b if init is None else init rmatvec = make_rmatvec(matvec, example_x) # (x) -> A.T @ x normal_matvec = make_normal_matvec(matvec) # (x) -> A.T @ A @ x @@ -90,7 +87,7 @@ def _solve_normal_cg( return linalg.cg(normal_matvec, rhs, x0=init, **kwargs) -def solve_normal_cg(**kwargs): +def solve_normal_cg(**kwargs: Any) -> LinearSolver: """Return a solver function to solve ``A^T A x = A^T b`` using conjugate gradient. 
This can be used to solve ``A x = b`` using conjugate gradient when ``A`` is not hermitian, diff --git a/torchopt/nn/module.py b/torchopt/nn/module.py index f8804864..383c5728 100644 --- a/torchopt/nn/module.py +++ b/torchopt/nn/module.py @@ -40,7 +40,7 @@ class MetaGradientModule(nn.Module): # pylint: disable=abstract-method _meta_parameters: TensorContainer _meta_modules: dict[str, nn.Module | None] - def __new__(cls, *args, **kwargs) -> MetaGradientModule: + def __new__(cls, *args: Any, **kwargs: Any) -> MetaGradientModule: """Create a new module instance.""" instance = super().__new__(cls) flat_args: list[Any] @@ -56,7 +56,7 @@ def __new__(cls, *args, **kwargs) -> MetaGradientModule: instance._meta_modules: dict[str, nn.Module | None] = OrderedDict() # type: ignore[misc] return instance - def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument """Initialize a new module instance.""" super().__init__() @@ -88,7 +88,7 @@ def __getattr__(self, name: str) -> torch.Tensor | nn.Module: def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None: """Set an attribute of the module.""" - def remove_from(*dicts_or_sets): + def remove_from(*dicts_or_sets: dict[str, Any] | set[str]) -> None: for dict_or_set in dicts_or_sets: if name in dict_or_set: if isinstance(dict_or_set, dict): diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index 9391352f..e547b5cb 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -57,10 +57,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: prefix, _, attr = path.rpartition('.') mod = get_submodule(prefix) - if allow_missing: - orig = getattr(mod, attr, MISSING) - else: - orig = getattr(mod, attr) + orig = getattr(mod, attr, MISSING) if allow_missing else getattr(mod, attr) # pylint: disable=protected-access if value is MISSING: @@ -77,10 +74,7 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: return orig - orig_named_tensors = { - name: recursive_setattr(name, tensor) for name, tensor in named_tensors.items() - } - return orig_named_tensors + return {name: recursive_setattr(name, tensor) for name, tensor in named_tensors.items()} @contextlib.contextmanager diff --git a/torchopt/optim/base.py b/torchopt/optim/base.py index aac3a782..d0be2fd1 100644 --- a/torchopt/optim/base.py +++ b/torchopt/optim/base.py @@ -67,12 +67,12 @@ def zero_grad(self, set_to_none: bool = False) -> None: """ if set_to_none: - def f(p): + def f(p: torch.Tensor) -> None: p.grad = None else: - def f(p): + def f(p: torch.Tensor) -> None: if p.grad is None: return if p.grad.grad_fn is not None: @@ -110,7 +110,7 @@ def step(self, closure: Callable[[], torch.Tensor] | None = None) -> torch.Tenso with torch.enable_grad(): loss = closure() - def f(p): + def f(p: torch.Tensor) -> torch.Tensor | None: return p.grad for i, (params, state) in enumerate(zip(self.param_groups, self.state_groups)): diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 9dce3412..0b567eb8 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -88,8 +88,7 @@ def step( updates, self.optim_state = self.impl.update( grads, self.optim_state, params=params, inplace=inplace ) - new_params = apply_updates(params, updates, inplace=inplace) - return new_params + return apply_updates(params, updates, inplace=inplace) def state_dict(self) -> OptState: """Extract the references of the 
optimizer states. diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 14745766..12ea5e8f 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -36,6 +36,8 @@ from typing import Any, Callable, NamedTuple +import torch + from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation, identity from torchopt.transform.utils import tree_map_flat, tree_map_flat_ @@ -103,12 +105,12 @@ def _masked( *, already_flattened: bool = False, ) -> GradientTransformation: - if already_flattened: + if already_flattened: # noqa: SIM108 tree_map = tree_map_flat else: tree_map = pytree.tree_map # type: ignore[assignment] - def tree_mask(params, mask_tree): + def tree_mask(params: Params, mask_tree: OptState) -> Params: return tree_map(lambda p, m: p if m else MaskedNode(), params, mask_tree) def init_fn(params: Params) -> OptState: @@ -188,7 +190,7 @@ def _add_decayed_weights( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable-next=unneeded-not - if not 0.0 <= weight_decay: # pragma: no cover + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') if weight_decay == 0.0 and mask is None: @@ -218,7 +220,7 @@ def update_fn( if inplace: - def f(g, p): + def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) @@ -227,7 +229,7 @@ def f(g, p): else: - def f(g, p): + def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return g.add(p, alpha=weight_decay) updates = tree_map(f, updates, params) diff --git a/torchopt/transform/nan_to_num.py b/torchopt/transform/nan_to_num.py index 804f8219..27d87499 100644 --- a/torchopt/transform/nan_to_num.py +++ b/torchopt/transform/nan_to_num.py @@ -16,6 +16,8 @@ from __future__ import annotations +import torch + from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation from torchopt.typing import OptState, Params, Updates @@ -44,12 +46,12 @@ def update_fn( ) -> tuple[Updates, OptState]: if inplace: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf) else: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) new_updates = pytree.tree_map(f, updates) diff --git a/torchopt/transform/scale.py b/torchopt/transform/scale.py index 639c903e..c731003c 100644 --- a/torchopt/transform/scale.py +++ b/torchopt/transform/scale.py @@ -33,6 +33,8 @@ from __future__ import annotations +import torch + from torchopt import pytree from torchopt.base import EmptyState, GradientTransformation from torchopt.transform.utils import tree_map_flat, tree_map_flat_ @@ -85,14 +87,14 @@ def update_fn( ) -> tuple[Updates, OptState]: if inplace: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.mul_(step_size) updates = tree_map_(f, updates) else: - def f(g): + def f(g: torch.Tensor) -> torch.Tensor: return g.mul(step_size) updates = tree_map(f, updates) diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index 36f30be9..c1f6274a 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -69,7 +69,7 @@ def _bias_correction( ) -> Updates: """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" - def f(t, c): # pylint: disable=invalid-name + def f(t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name return t.div(1 - pow(decay, c)) if already_flattened: @@ -142,7 +142,7 @@ def _scale_by_adam( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') @@ -150,7 +150,7 @@ def _scale_by_adam( raise ValueError(f'Invalid beta parameter at index 1: {b2}') # pylint: enable=unneeded-not - if already_flattened: + if already_flattened: # noqa: SIM108 tree_map = tree_map_flat else: tree_map = pytree.tree_map # type: ignore[assignment] @@ -187,12 +187,20 @@ def update_fn( if inplace: - def f(g, m, v): # pylint: disable=unused-argument + def f( + g: torch.Tensor, # pylint: disable=unused-argument + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: return m.div_(v.add_(eps_root).sqrt_().add(eps)) else: - def f(g, m, v): # pylint: disable=unused-argument + def f( + g: torch.Tensor, # pylint: disable=unused-argument + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: return m.div(v.add(eps_root).sqrt_().add(eps)) updates = tree_map(f, updates, mu_hat, nu_hat) @@ -272,7 +280,7 @@ def _scale_by_accelerated_adam( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') if not 0.0 <= b1 < 1.0: # pragma: no cover raise ValueError(f'Invalid beta parameter at index 0: {b1}') diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index 7a0c8c20..d415522a 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -101,9 +101,9 @@ def _scale_by_rms( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= alpha: # pragma: no cover + if not alpha >= 0.0: # pragma: no cover raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') # pylint: enable=unneeded-not @@ -131,14 +131,14 @@ def update_fn( if inplace: - def f(g, n): # pylint: disable=invalid-name + def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name return g.div_(n.sqrt().add_(eps)) updates = tree_map_(f, updates, nu) else: - def f(g, n): # pylint: disable=invalid-name + def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name return g.div(n.sqrt().add(eps)) updates = tree_map(f, updates, nu) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index d6e3b0fa..8347a05c 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -40,7 +40,7 @@ from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import inc_count, tree_map_flat, tree_map_flat_ -from torchopt.typing import OptState, Params, Schedule, SequenceOfTensors, Updates +from torchopt.typing import Numeric, OptState, Params, Schedule, SequenceOfTensors, Updates __all__ = ['scale_by_schedule'] @@ -96,7 +96,7 @@ def update_fn( ) -> tuple[Updates, 
OptState]: if inplace: - def f(g, c): # pylint: disable=invalid-name + def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name step_size = step_size_fn(c) return g.mul_(step_size) @@ -104,7 +104,7 @@ def f(g, c): # pylint: disable=invalid-name else: - def f(g, c): # pylint: disable=invalid-name + def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name step_size = step_size_fn(c) return g.mul(step_size) diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index 228ed707..06aaa90c 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -104,9 +104,9 @@ def _scale_by_stddev( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= alpha: # pragma: no cover + if not alpha >= 0.0: # pragma: no cover raise ValueError(f'Invalid alpha value: {alpha}') - if not 0.0 <= eps: # pragma: no cover + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') # pylint: enable=unneeded-not @@ -138,14 +138,14 @@ def update_fn( if inplace: - def f(g, m, n): + def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) updates = tree_map_(f, updates, mu, nu) else: - def f(g, m, n): + def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) updates = tree_map(f, updates, mu, nu) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 03d2441d..98a48e26 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -110,7 +110,7 @@ def _trace( already_flattened: bool = False, ) -> GradientTransformation: # pylint: disable=unneeded-not - if not 0.0 <= momentum: # pragma: no cover + if not momentum >= 0.0: # pragma: no cover raise ValueError(f'Invalid momentum value: {momentum}') if nesterov and (momentum <= 0.0 or dampening != 0.0): # pragma: no cover raise ValueError('Nesterov momentum requires a momentum and zero dampening') @@ -147,12 +147,12 @@ def update_fn( if nesterov: if inplace: - def f1(g, t): + def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: if first_call: return t.add_(g) return t.mul_(momentum).add_(g) - def f2(g, t): + def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: return g.add_(t, alpha=momentum) new_trace = tree_map(f1, updates, state.trace) @@ -160,12 +160,12 @@ def f2(g, t): else: - def f1(g, t): + def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: if first_call: return t.add(g) return t.mul(momentum).add_(g) - def f2(g, t): + def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: return g.add(t, alpha=momentum) new_trace = tree_map(f1, updates, state.trace) @@ -174,12 +174,12 @@ def f2(g, t): else: if inplace: - def f(g, t): + def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: if first_call: return t.add_(g) return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - def copy_(g, t): + def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: return g.copy_(t) new_trace = tree_map(f, updates, state.trace) @@ -187,7 +187,7 @@ def copy_(g, t): else: - def f(g, t): + def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: if first_call: return t.add(g) return t.mul(momentum).add_(g, alpha=1.0 - dampening) diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 77ba58ca..57abe7ec 100644 --- a/torchopt/transform/utils.py +++ 
b/torchopt/transform/utils.py @@ -59,7 +59,7 @@ def tree_map_flat( fn = func else: - def fn(x, *xs): + def fn(x: Any | None, *xs: Any) -> Any | None: return func(x, *xs) if x is not None else None return flat_arg.__class__(map(fn, flat_arg, *flat_args)) # type: ignore[call-arg] @@ -76,7 +76,7 @@ def tree_map_flat_( fn = func else: - def fn(x, *xs): + def fn(x: Any | None, *xs: Any) -> Any | None: return func(x, *xs) if x is not None else None flat_results = map(fn, flat_arg, *flat_args) @@ -111,7 +111,7 @@ def _inc_count( *, already_flattened: bool = False, ) -> TensorTree: - def f(c, g): # pylint: disable=invalid-name + def f(c: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor: # pylint: disable=invalid-name return c + (c != INT64_MAX).to(torch.int64) if g is not None else c if already_flattened: @@ -167,30 +167,30 @@ def _update_moment( *, order: int, inplace: bool = True, - already_flattened=False, + already_flattened: bool = False, ) -> TensorTree: assert order in (1, 2) if inplace: if order == 2: - def f(g, t): + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t else: - def f(g, t): + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t else: if order == 2: - def f(g, t): + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t else: - def f(g, t): + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t if already_flattened: diff --git a/torchopt/update.py b/torchopt/update.py index 9485896b..3a2a6984 100644 --- a/torchopt/update.py +++ b/torchopt/update.py @@ -31,6 +31,10 @@ # ============================================================================== """Helper functions for applying updates.""" +from __future__ import annotations + +import torch + from torchopt import pytree from torchopt.typing import Params, Updates @@ -59,14 +63,14 @@ def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> """ if inplace: - def f(p, u): + def f(p: torch.Tensor, u: torch.Tensor | None) -> torch.Tensor: if u is not None: p.data.add_(u) return p else: - def f(p, u): + def f(p: torch.Tensor, u: torch.Tensor | None) -> torch.Tensor: return p.add(u) if u is not None else p return pytree.tree_map(f, params, updates) diff --git a/torchopt/utils.py b/torchopt/utils.py index 12adb214..6afe08e7 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -18,7 +18,7 @@ import copy import itertools -from typing import TYPE_CHECKING, NamedTuple, Sequence, cast, overload +from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, cast, overload from typing_extensions import Literal # Python 3.8+ from typing_extensions import TypeAlias # Python 3.10+ @@ -46,8 +46,8 @@ class ModuleState(NamedTuple): """Container for module state.""" - params: tuple[dict[str, torch.Tensor], ...] - buffers: tuple[dict[str, torch.Tensor], ...] + params: tuple[TensorContainer, ...] + buffers: tuple[TensorContainer, ...] 
visual_contents: dict | None = None detach_buffers: bool = False @@ -74,7 +74,7 @@ def stop_gradient(target: ModuleState | nn.Module | MetaOptimizer | TensorTree) # pylint: disable-next=import-outside-toplevel from torchopt.optim.meta.base import MetaOptimizer - def fn_(obj): + def fn_(obj: Any) -> None: if isinstance(obj, torch.Tensor): requires_grad = obj.requires_grad obj.detach_().requires_grad_(requires_grad) @@ -221,11 +221,11 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: else: visual_contents = None - params: list[dict[str, torch.Tensor]] = [] - buffers: list[dict[str, torch.Tensor]] = [] + params: list[TensorContainer] = [] + buffers: list[TensorContainer] = [] memo: set[nn.Module] = set() - def update_params(container): + def update_params(container: TensorContainer) -> None: if len(container) > 0: params.append( type(container)( @@ -235,7 +235,7 @@ def update_params(container): ) ) - def update_buffers(container): + def update_buffers(container: TensorContainer) -> None: if len(container) > 0: fn = clone_detach_ if detach_buffers else replicate buffers.append( @@ -245,14 +245,14 @@ def update_buffers(container): ) # pylint: disable=protected-access - update_params(target._parameters) + update_params(target._parameters) # type: ignore[arg-type] if with_buffers: update_buffers(target._buffers) memo.add(target) for submodule in target.modules(): if submodule in memo: continue - update_params(submodule._parameters) + update_params(submodule._parameters) # type: ignore[arg-type] if with_buffers: update_buffers(submodule._buffers) memo.add(submodule) @@ -264,10 +264,10 @@ def update_buffers(container): detach_buffers=detach_buffers, ) - elif isinstance(target, MetaOptimizer): + if isinstance(target, MetaOptimizer): state = target.state_dict() - def get_variable(t): + def get_variable(t: torch.Tensor | None) -> torch.Tensor | None: if isinstance(t, torch.Tensor): return replicate(t) return t @@ -287,19 +287,19 @@ def extract_module_containers( buffers: list[TensorContainer] = [] memo: set[nn.Module] = set() - def update_container(container, items): + def update_container(container: list[TensorContainer], items: TensorContainer) -> None: if len(items) > 0: container.append(items) # we need references to original dictionaries # pylint: disable=protected-access - update_container(params, module._parameters) + update_container(params, module._parameters) # type: ignore[arg-type] if with_buffers: update_container(buffers, module._buffers) memo.add(module) for submodule in module.modules(): if submodule in memo: continue - update_container(params, submodule._parameters) + update_container(params, submodule._parameters) # type: ignore[arg-type] if with_buffers: update_container(buffers, submodule._buffers) memo.add(submodule) diff --git a/torchopt/visual.py b/torchopt/visual.py index 7afe65a4..493ffbab 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -20,12 +20,13 @@ from __future__ import annotations from collections import namedtuple -from typing import Generator, Iterable, Mapping, cast +from typing import Any, Generator, Iterable, Mapping, cast import torch from graphviz import Digraph -from torchopt.typing import TensorOrTensors +from torchopt import pytree +from torchopt.typing import TensorTree from torchopt.utils import ModuleState @@ -38,7 +39,7 @@ SAVED_PREFIX = '_saved_' -def get_fn_name(fn, show_attrs, max_attr_chars): +def get_fn_name(fn: Any, show_attrs: bool, max_attr_chars: int) -> str: """Return function name.""" name = str(type(fn).__name__) if not 
show_attrs: @@ -63,7 +64,7 @@ def get_fn_name(fn, show_attrs, max_attr_chars): sep = '-' * max(col1width + col2width + 2, len(name)) attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's' - def truncate(s): # pylint: disable=invalid-name + def truncate(s: str) -> str: # pylint: disable=invalid-name return s[: col2width - 3] + '...' if len(s) > col2width else s params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items()) @@ -72,7 +73,7 @@ def truncate(s): # pylint: disable=invalid-name # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals def make_dot( - var: TensorOrTensors, + var: TensorTree, params: ( Mapping[str, torch.Tensor] | ModuleState @@ -142,20 +143,20 @@ def make_dot( dot = Digraph(node_attr=node_attr, graph_attr={'size': '12,12'}) seen = set() - def size_to_str(size): + def size_to_str(size: tuple[int, ...]) -> str: return '(' + (', ').join(map(str, size)) + ')' - def get_var_name(var, name=None): + def get_var_name(var: torch.Tensor, name: str | None = None) -> str: if not name: name = param_map[var] if var in param_map else '' return f'{name}\n{size_to_str(var.size())}' - def get_var_name_with_flag(var): + def get_var_name_with_flag(var: torch.Tensor) -> str | None: if var in param_map: return f'{param_map[var][0]}\n{size_to_str(param_map[var][1].size())}' return None - def add_nodes(fn): # pylint: disable=too-many-branches + def add_nodes(fn: Any) -> None: # pylint: disable=too-many-branches assert not isinstance(fn, torch.Tensor) if fn in seen: return @@ -210,7 +211,10 @@ def add_nodes(fn): # pylint: disable=too-many-branches dot.edge(str(id(t)), str(id(fn))) dot.node(str(id(t)), get_var_name(t), fillcolor='orange') - def add_base_tensor(v, color='darkolivegreen1'): # pylint: disable=invalid-name + def add_base_tensor( + v: torch.Tensor, # pylint: disable=invalid-name + color: str = 'darkolivegreen1', + ) -> None: if v in seen: return seen.add(v) @@ -220,15 +224,11 @@ def add_base_tensor(v, color='darkolivegreen1'): # pylint: disable=invalid-name dot.edge(str(id(v.grad_fn)), str(id(v))) # pylint: disable=protected-access if v._is_view(): - add_base_tensor(v._base, color='darkolivegreen3') + add_base_tensor(v._base, color='darkolivegreen3') # type: ignore[arg-type] dot.edge(str(id(v._base)), str(id(v)), style='dotted') # handle multiple outputs - if isinstance(var, (tuple, list)): - for v in var: # pylint: disable=invalid-name - add_base_tensor(v) - else: - add_base_tensor(var) + pytree.tree_map_(add_base_tensor, var) resize_graph(dot)
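For reference, the `zero_order` decorator whose signature and annotations are tightened above is typically applied as in the sketch below. This is illustrative only and not part of the patch: the perturbation distribution, tensor shapes, sample count, and loss function are placeholders.

import torch
import torch.nn.functional as F

from torchopt.diff.zero_order import zero_order

# Anything exposing ``sample(sample_shape)`` can serve as the perturbation
# distribution; a standard normal is used here purely for illustration.
distribution = torch.distributions.Normal(loc=0.0, scale=1.0)

@zero_order(distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)
def loss_fn(weight: torch.Tensor, batch: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Only ``weight`` (argnums=0) receives zero-order gradient estimates.
    return F.mse_loss(F.linear(batch, weight), labels)

weight = torch.randn(4, 8, requires_grad=True)
batch, labels = torch.randn(16, 8), torch.randn(16, 4)
loss_fn(weight, batch, labels).backward()  # weight.grad now holds the estimated gradient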
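Similarly, the `solve_cg`, `solve_inv`, and `solve_normal_cg` factories above now declare that they return a `LinearSolver`, i.e. a callable of shape ``(matvec, b) -> x``. A minimal sketch of that pattern with `solve_normal_cg`, assuming a small hand-written non-symmetric system; the matrix, right-hand side, and `maxiter` value are made up for illustration.

import torch

from torchopt.linear_solve import solve_normal_cg

A = torch.tensor([[3.0, 1.0, 0.0],
                  [0.0, 2.0, 1.0],
                  [1.0, 0.0, 4.0]])  # square but not symmetric

def matvec(x: torch.Tensor) -> torch.Tensor:
    return A @ x

x_true = torch.tensor([1.0, -2.0, 0.5])
b = matvec(x_true)

solver = solve_normal_cg(maxiter=50)  # the factory captures the solver options
x = solver(matvec, b)                 # solves A^T A x = A^T b by conjugate gradient
print(torch.allclose(x, x_true, atol=1e-4))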
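Finally, `apply_updates` and the per-transform closures typed throughout this patch are the pieces of TorchOpt's functional init/update/apply loop. A short sketch of that loop, assuming a toy linear model and Adam with a placeholder learning rate:

import torch
import torchopt

model = torch.nn.Linear(8, 4)
params = tuple(model.parameters())

optimizer = torchopt.adam(lr=1e-3)   # a GradientTransformation
opt_state = optimizer.init(params)

loss = model(torch.randn(16, 8)).square().mean()
grads = torch.autograd.grad(loss, params)

# ``update`` returns the transformed updates plus the new optimizer state;
# ``apply_updates`` then adds the updates onto the parameters.
updates, opt_state = optimizer.update(grads, opt_state, params=params)
params = torchopt.apply_updates(params, updates)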