chore(pre-commit): update pre-commit hooks
XuehaiPan committed Jun 17, 2024
1 parent 605929a commit 5bc8133
Showing 75 changed files with 519 additions and 280 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -26,11 +26,11 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.5
rev: v18.1.6
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.4.9
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
@@ -43,7 +43,7 @@ repos:
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.2
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@
^examples/
)
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
rev: 7.1.0
hooks:
- id: flake8
additional_dependencies:
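
These are routine hook version bumps (clang-format v18.1.5 → v18.1.6, ruff v0.4.7 → v0.4.9, pyupgrade v3.15.2 → v3.16.0, flake8 7.0.0 → 7.1.0). Bumps like these are typically generated with pre-commit autoupdate and verified with pre-commit run --all-files; the lint-driven edits in the rest of this commit are consistent with re-running the updated hooks across the repository.
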
17 changes: 14 additions & 3 deletions pyproject.toml
@@ -235,6 +235,7 @@ extend-exclude = ["examples"]
select = [
"E", "W", # pycodestyle
"F", # pyflakes
"C90", # mccabe
"UP", # pyupgrade
"ANN", # flake8-annotations
"S", # flake8-bandit
@@ -243,14 +244,21 @@ select = [
"COM", # flake8-commas
"C4", # flake8-comprehensions
"EXE", # flake8-executable
"FA", # flake8-future-annotations
"LOG", # flake8-logging
"ISC", # flake8-implicit-str-concat
"INP", # flake8-no-pep420
"PIE", # flake8-pie
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RSE", # flake8-raise
"RET", # flake8-return
"SIM", # flake8-simplify
"TID", # flake8-tidy-imports
"TCH", # flake8-type-checking
"PERF", # perflint
"FURB", # refurb
"TRY", # tryceratops
"RUF", # ruff
]
ignore = [
@@ -268,9 +276,9 @@
# S101: use of `assert` detected
# internal use and may never raise at runtime
"S101",
# PLR0402: use from {module} import {name} in lieu of alias
# use alias for import convention (e.g., `import torch.nn as nn`)
"PLR0402",
# TRY003: avoid specifying long messages outside the exception class
# long messages are necessary for clarity
"TRY003",
]
typing-modules = ["torchopt.typing"]

@@ -296,6 +304,9 @@ typing-modules = ["torchopt.typing"]
"F401", # unused-import
"F811", # redefined-while-unused
]
"docs/source/conf.py" = [
"INP001", # flake8-no-pep420
]

[tool.ruff.lint.flake8-annotations]
allow-star-arg-any = true
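
The ruff configuration above now selects additional rule families, including mccabe (C90) complexity checks and the tryceratops (TRY) rules, swaps the old PLR0402 ignore for TRY003 (long messages outside the exception class), and exempts docs/source/conf.py from flake8-no-pep420 (INP001). The import reshuffling in the files below matches the flake8-type-checking (TCH) convention of guarding typing-only imports behind if TYPE_CHECKING:, and the new # noqa: C901 comments suppress the mccabe check on a few long test functions. A minimal sketch of the TYPE_CHECKING pattern, using a hypothetical helper that is not part of this repository:

    from __future__ import annotations  # annotations stay as unevaluated strings at runtime

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only by static type checkers; no import cost or cycle risk at runtime.
        from torchopt.typing import TensorTree


    def describe(tree: TensorTree) -> str:  # hypothetical function, for illustration only
        # Thanks to the future import, `TensorTree` is never looked up at runtime,
        # so the guarded import above is sufficient for the annotation.
        return f'pytree of type {type(tree).__name__}'
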
7 changes: 5 additions & 2 deletions tests/helpers.py
@@ -20,7 +20,7 @@
import itertools
import os
import random
from typing import Iterable
from typing import TYPE_CHECKING, Iterable

import numpy as np
import pytest
@@ -30,7 +30,7 @@
from torch.utils import data

from torchopt import pytree
from torchopt.typing import TensorTree


if TYPE_CHECKING:
from torchopt.typing import TensorTree


BATCH_SIZE = 64
7 changes: 5 additions & 2 deletions tests/test_alias.py
@@ -15,7 +15,7 @@

from __future__ import annotations

from typing import Callable
from typing import TYPE_CHECKING, Callable

import functorch
import pytest
@@ -26,7 +26,10 @@
import torchopt
from torchopt import pytree
from torchopt.alias.utils import _set_use_chain_flat
from torchopt.typing import TensorTree


if TYPE_CHECKING:
from torchopt.typing import TensorTree


@helpers.parametrize(
23 changes: 16 additions & 7 deletions tests/test_implicit.py
@@ -18,7 +18,7 @@
import copy
import re
from collections import OrderedDict
from types import FunctionType
from typing import TYPE_CHECKING

import functorch
import numpy as np
@@ -47,6 +47,10 @@
HAS_JAX = False


if TYPE_CHECKING:
from types import FunctionType


BATCH_SIZE = 8
NUM_UPDATES = 3

@@ -123,7 +127,7 @@ def get_rr_dataset_torch() -> data.DataLoader:
inner_lr=[2e-2, 2e-3],
inner_update=[20, 50, 100],
)
def test_imaml_solve_normal_cg(
def test_imaml_solve_normal_cg( # noqa: C901
dtype: torch.dtype,
lr: float,
inner_lr: float,
@@ -251,7 +255,7 @@ def outer_level(p, xs, ys):
inner_update=[20, 50, 100],
ns=[False, True],
)
def test_imaml_solve_inv(
def test_imaml_solve_inv( # noqa: C901
dtype: torch.dtype,
lr: float,
inner_lr: float,
@@ -375,7 +379,12 @@ def outer_level(p, xs, ys):
inner_lr=[2e-2, 2e-3],
inner_update=[20, 50, 100],
)
def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None:
def test_imaml_module( # noqa: C901
dtype: torch.dtype,
lr: float,
inner_lr: float,
inner_update: int,
) -> None:
np_dtype = helpers.dtype_torch2numpy(dtype)

jax_model, jax_params = get_model_jax(dtype=np_dtype)
@@ -763,7 +772,7 @@ def solve(self):
make_optimality_from_objective(MyModule2)


def test_module_abstract_methods() -> None:
def test_module_abstract_methods() -> None: # noqa: C901
class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
def objective(self):
return torch.tensor(0.0)
@@ -809,7 +818,7 @@ def solve(self):

class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
@classmethod
def optimality(self):
def optimality(cls):
return ()

def solve(self):
@@ -846,7 +855,7 @@ def solve(self):

class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
@classmethod
def objective(self):
def objective(cls):
return ()

def solve(self):
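
Two patterns recur in this file: the new # noqa: C901 markers silence the freshly enabled mccabe complexity check on intentionally long test functions, and the @classmethod-decorated stubs now name their first parameter cls rather than self, since that parameter receives the class object. A small sketch of both, on a hypothetical class and function:

    class Example:  # hypothetical class, for illustration only
        @classmethod
        def make(cls) -> 'Example':
            # A classmethod's first parameter is the class itself, hence `cls`.
            return cls()


    def route(flag: int) -> str:  # noqa: C901  # suppresses mccabe for this function only
        # A real offender would have many more branches; the inline comment is the point here.
        if flag == 0:
            return 'zero'
        if flag == 1:
            return 'one'
        return 'many'
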
6 changes: 4 additions & 2 deletions tests/test_utils.py
@@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================

import operator

import torch

import torchopt
@@ -80,7 +82,7 @@ def test_module_clone() -> None:
assert y.is_cuda


def test_extract_state_dict():
def test_extract_state_dict(): # noqa: C901
fc = torch.nn.Linear(1, 1)
state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta'))
for param_dict in state_dict.params:
@@ -121,7 +123,7 @@ def test_extract_state_dict():
loss = fc(torch.ones(1, 1)).sum()
optim.step(loss)
state_dict = torchopt.extract_state_dict(optim)
same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups))
same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups))
assert all(pytree.tree_flatten(same)[0])


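
The identity-comparison lambda passed to pytree.tree_map is replaced by operator.is_, its functional equivalent from the standard library (a substitution the newly enabled refurb rules tend to suggest). A quick demonstration of the equivalence:

    import operator

    a, b = object(), object()
    # `operator.is_` is the functional form of the `is` test, so it can stand in
    # for `lambda x, y: x is y` anywhere a two-argument callable is expected.
    assert operator.is_(a, a) is True
    assert operator.is_(a, b) is False
    assert operator.is_(a, b) == (lambda x, y: x is y)(a, b)
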
66 changes: 33 additions & 33 deletions torchopt/__init__.py
@@ -81,50 +81,50 @@


__all__ = [
'accelerated_op_available',
'adam',
'adamax',
'adadelta',
'radam',
'adamw',
'adagrad',
'rmsprop',
'sgd',
'clip_grad_norm',
'nan_to_num',
'register_hook',
'chain',
'Optimizer',
'SGD',
'Adam',
'AdaMax',
'Adamax',
'AdaDelta',
'Adadelta',
'RAdam',
'AdamW',
'AdaGrad',
'AdaMax',
'Adadelta',
'Adagrad',
'RMSProp',
'RMSprop',
'MetaOptimizer',
'MetaSGD',
'MetaAdam',
'MetaAdaMax',
'MetaAdamax',
'Adam',
'AdamW',
'Adamax',
'FuncOptimizer',
'MetaAdaDelta',
'MetaAdadelta',
'MetaRAdam',
'MetaAdamW',
'MetaAdaGrad',
'MetaAdaMax',
'MetaAdadelta',
'MetaAdagrad',
'MetaAdam',
'MetaAdamW',
'MetaAdamax',
'MetaOptimizer',
'MetaRAdam',
'MetaRMSProp',
'MetaRMSprop',
'FuncOptimizer',
'MetaSGD',
'Optimizer',
'RAdam',
'RMSProp',
'RMSprop',
'accelerated_op_available',
'adadelta',
'adagrad',
'adam',
'adamax',
'adamw',
'apply_updates',
'chain',
'clip_grad_norm',
'extract_state_dict',
'recover_state_dict',
'stop_gradient',
'module_clone',
'module_detach_',
'nan_to_num',
'radam',
'recover_state_dict',
'register_hook',
'rmsprop',
'sgd',
'stop_gradient',
]
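
The __all__ entries above are re-sorted into plain ASCII order (uppercase class names before lowercase function aliases); the 33 additions and 33 deletions reported for this file appear to be the same names in a new order rather than any change to the public API.
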
9 changes: 6 additions & 3 deletions torchopt/accelerated_op/__init__.py
@@ -16,12 +16,15 @@

from __future__ import annotations

from typing import Iterable
from typing import TYPE_CHECKING, Iterable

import torch

from torchopt.accelerated_op.adam_op import AdamOp
from torchopt.typing import Device


if TYPE_CHECKING:
from torchopt.typing import Device


def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
@@ -42,6 +45,6 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
return False
updates = torch.tensor(1.0, device=device)
op(updates, updates, updates, 1)
return True
except Exception: # noqa: BLE001 # pylint: disable=broad-except
return False
return True
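
Moving return True out of the try body so that it runs after the except clause keeps the success path outside the handler's scope while preserving behavior; this is the shape the tryceratops rules (TRY300) nudge code toward. A hedged sketch with a hypothetical availability probe:

    import os


    def probe_feature() -> bool:  # hypothetical probe, for illustration only
        try:
            with open(os.devnull, 'rb') as handle:  # stand-in for setup work that may raise
                handle.read(0)
        except Exception:  # noqa: BLE001  # any failure simply means "not available"
            return False
        # The success return now sits after the try/except (an `else:` block also works),
        # keeping the happy path out of the exception handler.
        return True
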
6 changes: 5 additions & 1 deletion torchopt/accelerated_op/_src/adam_op.py
@@ -18,7 +18,11 @@

from __future__ import annotations

import torch
from typing import TYPE_CHECKING


if TYPE_CHECKING:
import torch


def forward_(
11 changes: 10 additions & 1 deletion torchopt/alias/__init__.py
@@ -41,4 +41,13 @@
from torchopt.alias.sgd import sgd


__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd']
__all__ = [
'adadelta',
'adagrad',
'adam',
'adamax',
'adamw',
'radam',
'rmsprop',
'sgd',
]
7 changes: 6 additions & 1 deletion torchopt/alias/adadelta.py
@@ -16,14 +16,19 @@

from __future__ import annotations

from typing import TYPE_CHECKING

from torchopt.alias.utils import (
_get_use_chain_flat,
flip_sign_and_add_weight_decay,
scale_by_neg_lr,
)
from torchopt.combine import chain
from torchopt.transform import scale_by_adadelta
from torchopt.typing import GradientTransformation, ScalarOrSchedule


if TYPE_CHECKING:
from torchopt.typing import GradientTransformation, ScalarOrSchedule


__all__ = ['adadelta']
(The diffs for the remaining changed files are not shown here.)
