feat: add flake8 plugins and enable pre-commit hook (metaopt#138)
XuehaiPan authored Feb 17, 2023
1 parent f5b292c commit 845e572
Showing 15 changed files with 105 additions and 26 deletions.
41 changes: 41 additions & 0 deletions .flake8
@@ -0,0 +1,41 @@
[flake8]
max-line-length = 120
max-doc-length = 100
select = B,C,E,F,W,Y,SIM
ignore =
# E203: whitespace before ':'
# W503: line break before binary operator
# W504: line break after binary operator
# formatted by black
E203,W503,W504,
# E501: line too long
# W505: doc line too long
# too long docstring due to long example blocks
E501,W505,
per-file-ignores =
# F401: module imported but unused
# intentionally unused imports
__init__.py: F401
# F401: module imported but unused
# F403: unable to detect undefined names
# F405: member may be undefined, or defined from star imports
# members populated from optree
torchopt/pytree.py: F401,F403,F405
# E301: expected 1 blank line
# E302: expected 2 blank lines
# E305: expected 2 blank lines after class or function definition
# E701: multiple statements on one line (colon)
# E704: multiple statements on one line (def)
# formatted by black
*.pyi: E301,E302,E305,E701,E704
exclude =
.git,
.vscode,
venv,
third-party,
__pycache__,
docs/source/conf.py,
build,
dist,
examples,
tests
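
A note on the ignore list: E203, W503, and W504 conflict with black's output, so they are disabled globally. A minimal hypothetical sketch (not from the repository) of black-formatted code that plain flake8 would otherwise flag:

    # E203: black puts a space before ':' in slices with complex bounds.
    ham = list(range(10))
    lower, offset, upper = 1, 2, 8
    piece = ham[lower + offset : upper + offset]

    # W503: black breaks long conditions *before* binary operators.
    flag = (
        len(piece) > 3
        and piece[0] % 2 == 0
    )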
26 changes: 21 additions & 5 deletions .pre-commit-config.yaml
@@ -3,8 +3,8 @@
ci:
skip: [pylint]
autofix_prs: true
-  autofix_commit_msg: 'fix: [pre-commit.ci] auto fixes [...]'
-  autoupdate_commit_msg: 'chore(pre-commit): [pre-commit.ci] autoupdate'
+  autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
+  autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -26,8 +26,8 @@ repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v15.0.7
hooks:
-    - id: clang-format
-      stages: [commit, push, manual]
+      - id: clang-format
+        stages: [commit, push, manual]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
@@ -48,6 +48,22 @@ repos:
(?x)(
^examples/
)
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+        additional_dependencies:
+          - flake8-bugbear
+          - flake8-comprehensions
+          - flake8-docstrings
+          - flake8-pyi
+          - flake8-simplify
+        exclude: |
+          (?x)(
+            ^examples/|
+            ^tests/|
+            ^docs/source/conf.py$
+          )
- repo: local
hooks:
- id: pylint
@@ -68,7 +84,7 @@ repos:
rev: 6.3.0
hooks:
- id: pydocstyle
-      additional_dependencies: ['.[toml]']
+      additional_dependencies: [".[toml]"]
exclude: |
(?x)(
^.github/|
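
The newly added hook pulls five plugins into flake8's environment, matching the B/C/SIM/Y codes selected in .flake8. A hedged sketch (hypothetical snippets, not from the repository) of the kinds of patterns each family catches:

    import time

    # flake8-comprehensions (C4xx): prefer literals over dict()/list() calls.
    bad = dict(a=1, b=2)
    good = {'a': 1, 'b': 2}

    # flake8-simplify (SIM108): collapse a trivial if/else into a ternary.
    x = 3
    if x > 0:
        sign = 1
    else:
        sign = -1
    sign = 1 if x > 0 else -1

    # flake8-bugbear (B008): a call in a default argument runs once, at
    # function definition time, not on every call.
    def stamp(value, when=time.time()):
        return (value, when)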
8 changes: 6 additions & 2 deletions Makefile
@@ -46,7 +46,11 @@ pylint-install:

flake8-install:
$(call check_pip_install,flake8)
-	$(call check_pip_install_extra,flake8-bugbear,flake8-bugbear)
+	$(call check_pip_install,flake8-bugbear)
+	$(call check_pip_install,flake8-comprehensions)
+	$(call check_pip_install,flake8-docstrings)
+	$(call check_pip_install,flake8-pyi)
+	$(call check_pip_install,flake8-simplify)

py-format-install:
$(call check_pip_install,isort)
@@ -122,7 +126,7 @@ pylint: pylint-install
$(PYTHON) -m pylint $(PROJECT_PATH)

flake8: flake8-install
-	$(PYTHON) -m flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
+	$(PYTHON) -m flake8 --count --show-source --statistics

py-format: py-format-install
$(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \
6 changes: 5 additions & 1 deletion conda-recipe.yaml
@@ -90,7 +90,11 @@ dependencies:
- mypy >= 0.990
- flake8
- flake8-bugbear
-  - doc8 < 1.0.0a0
+  - flake8-comprehensions
+  - flake8-docstrings
+  - flake8-pyi
+  - flake8-simplify
+  - doc8
- pydocstyle
- clang-format >= 14
- clang-tools >= 14 # clang-tidy
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -74,6 +74,10 @@ lint = [
"mypy >= 0.990",
"flake8",
"flake8-bugbear",
"flake8-comprehensions",
"flake8-docstrings",
"flake8-pyi",
"flake8-simplify",
"doc8 < 1.0.0a0", # unpin this when we drop support for Python 3.7
"pydocstyle[toml]",
"pyenchant",
10 changes: 5 additions & 5 deletions setup.py
@@ -96,16 +96,16 @@ def build_extension(self, ext):
LINUX = platform.system() == 'Linux'
MACOS = platform.system() == 'Darwin'
WINDOWS = platform.system() == 'Windows'
-ext_kwargs = dict(
-    cmdclass={'build_ext': cmake_build_ext},
-    ext_modules=[
+ext_kwargs = {
+    'cmdclass': {'build_ext': cmake_build_ext},
+    'ext_modules': [
CMakeExtension(
'torchopt._C',
source_dir=HERE,
optional=not (LINUX and CIBUILDWHEEL),
)
],
-)
+}

TORCHOPT_NO_EXTENSIONS = (
bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or (MACOS and CIBUILDWHEEL)
@@ -123,7 +123,7 @@ def build_extension(self, ext):
VERSION_FILE.write_text(
data=re.sub(
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
f"__version__ = '{version.__version__}'",
f'__version__ = {version.__version__!r}',
string=VERSION_CONTENT,
),
encoding='UTF-8',
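
Two stylistic fixes land in setup.py: the dict() call becomes a literal (flake8-comprehensions C408), and the version f-string uses !r so Python's repr supplies the quoting instead of hard-coded quote characters. A small sketch of the equivalences, assuming a plain version string:

    version = '0.7.0'

    # C408: the literal builds the same mapping as the dict() call.
    assert dict(name='torchopt') == {'name': 'torchopt'}

    # !r applies repr(), which quotes the value itself.
    assert f'__version__ = {version!r}' == "__version__ = '0.7.0'"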
4 changes: 4 additions & 0 deletions tests/requirements.txt
@@ -16,6 +16,10 @@ pylint[spelling] >= 2.15.0
mypy >= 0.990
flake8
flake8-bugbear
+flake8-comprehensions
+flake8-docstrings
+flake8-pyi
+flake8-simplify
# https://github.com/PyCQA/doc8/issues/112
doc8 < 1.0.0a0
pydocstyle[toml]
2 changes: 0 additions & 2 deletions torchopt/_C/adam_op.pyi
@@ -15,8 +15,6 @@

# pylint: disable=all

-from __future__ import annotations
-
import torch

def forward_(
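
Dropping `from __future__ import annotations` here follows flake8-pyi (the Y codes in select): stub files are never executed, so every annotation in a .pyi is already lazy and the future import is dead weight. A hypothetical stub fragment showing the idiom:

    # .pyi-style stub: no future import needed, bodies are '...'.
    import torch

    def forward_(updates: torch.Tensor, mu: torch.Tensor) -> torch.Tensor: ...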
7 changes: 5 additions & 2 deletions torchopt/diff/implicit/decorator.py
@@ -91,7 +91,7 @@ def _root_vjp(
grad_outputs: TupleOfTensors,
output_is_tensor: bool,
argnums: tuple[int, ...],
-    solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(),
+    solve: Callable[..., TensorOrTensors],
) -> TupleOfOptionalTensors:
if output_is_tensor:

@@ -414,7 +414,7 @@ def custom_root(
optimality_fn: Callable[..., TensorOrTensors],
argnums: int | tuple[int, ...],
has_aux: bool = False,
-    solve: Callable[..., TensorOrTensors] = linear_solve.solve_normal_cg(),
+    solve: Callable[..., TensorOrTensors] | None = None,
) -> Callable[
[Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]]],
Callable[..., TensorOrTensors | tuple[TensorOrTensors, Any]],
@@ -465,6 +465,9 @@ def solver_fn(params, arg1, arg2, ...):
else:
assert 0 not in argnums

+    if solve is None:
+        solve = linear_solve.solve_normal_cg()
+
return functools.partial(
_custom_root,
optimality_fn=optimality_fn,
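
Moving `linear_solve.solve_normal_cg()` out of the signature is the standard fix for bugbear's B008: a call in a default is evaluated once at import time and silently shared across calls, whereas the `None` sentinel defers construction to each invocation. A minimal sketch of the pattern with hypothetical names:

    def make_solver():
        # In the default-argument version this would run at import time.
        return lambda *args: args

    # After the fix: construction is deferred until the function is called.
    def solve_after(problem, solver=None):
        if solver is None:
            solver = make_solver()
        return solver(problem)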
5 changes: 4 additions & 1 deletion torchopt/diff/zero_order/decorator.py
@@ -35,7 +35,10 @@ def __init__(self, sample_fn: SampleFunc) -> None:
"""Wrap a sample function to make it a :class:`Samplable` object."""
self.sample_fn = sample_fn

-    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor | Sequence[Numeric]:
+    def sample(
+        self,
+        sample_shape: torch.Size = torch.Size(),  # noqa: B008
+    ) -> torch.Tensor | Sequence[Numeric]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
return self.sample_fn(sample_shape)
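
The `# noqa: B008` here is a deliberate suppression rather than a fix: `torch.Size()` is an empty, immutable tuple subclass (the conventional default in torch.distributions), so sharing one instance across calls is harmless. A quick check of that assumption:

    import torch

    s = torch.Size()
    assert s == () and isinstance(s, tuple)  # immutable, safe as a default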
3 changes: 2 additions & 1 deletion torchopt/diff/zero_order/nn/module.py
@@ -92,7 +92,8 @@ def forward(self, *args, **kwargs) -> torch.Tensor:

@abc.abstractmethod
def sample(
-        self, sample_shape: torch.Size = torch.Size()  # pylint: disable=unused-argument
+        self,
+        sample_shape: torch.Size = torch.Size(),  # noqa: B008 # pylint: disable=unused-argument
) -> torch.Tensor | Sequence[Numeric]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
4 changes: 2 additions & 2 deletions torchopt/distributed/__init__.py
@@ -18,8 +18,8 @@
import torch.distributed.rpc as rpc

from torchopt.distributed import api, autograd, world
-from torchopt.distributed.api import *
-from torchopt.distributed.world import *
+from torchopt.distributed.api import *  # noqa: F403
+from torchopt.distributed.world import *  # noqa: F403


__all__ = ['is_available', *api.__all__, *world.__all__]
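
The per-line `# noqa: F403` is needed because flake8 cannot see what a star import defines; the module keeps its public surface explicit by assembling `__all__` from each submodule's own `__all__`. A hypothetical two-file sketch of the same pattern:

    # mypkg/api.py
    __all__ = ['fetch']
    def fetch(): ...

    # mypkg/__init__.py
    from mypkg import api
    from mypkg.api import *  # noqa: F403

    __all__ = ['is_ready', *api.__all__]
    def is_ready(): ...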
6 changes: 3 additions & 3 deletions torchopt/distributed/autograd.py
@@ -38,7 +38,7 @@ def is_available():

if is_available():
# pylint: disable-next=unused-import,ungrouped-imports
-    from torch.distributed.autograd import DistAutogradContext, get_gradients
+    from torch.distributed.autograd import DistAutogradContext, get_gradients  # noqa: F401

def backward(
autograd_ctx_id: int,
@@ -69,7 +69,7 @@ def backward(
raise RuntimeError("'inputs' argument to backward() cannot be empty.")
else:
inputs = tuple(inputs)
-        if not all(map(lambda t: t.requires_grad, inputs)):
+        if not all(t.requires_grad for t in inputs):
raise RuntimeError('One of the differentiated Tensors does not require grad')

roots = [tensors] if isinstance(tensors, torch.Tensor) else list(tensors)
@@ -111,7 +111,7 @@ def grad(
"""
outputs = [outputs] if isinstance(outputs, torch.Tensor) else list(outputs)
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
-        if not all(map(lambda t: t.requires_grad, inputs)):
+        if not all(t.requires_grad for t in inputs):
raise RuntimeError('One of the differentiated Tensors does not require grad')

autograd.backward(autograd_ctx_id, roots=outputs, retain_graph=retain_graph)
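
Both call sites swap `map(lambda t: ..., inputs)` for a generator expression, the cleanup flake8-comprehensions' C417 asks for; both forms are lazy and short-circuit inside all(). A tiny equivalence check with a hypothetical stand-in class:

    class T:  # hypothetical stand-in for a tensor with requires_grad
        def __init__(self, requires_grad):
            self.requires_grad = requires_grad

    inputs = (T(True), T(False))
    assert all(map(lambda t: t.requires_grad, inputs)) == all(
        t.requires_grad for t in inputs
    )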
2 changes: 1 addition & 1 deletion torchopt/linalg/ns.py
@@ -117,7 +117,7 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None):
if A.ndim != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f'`A` must be a square matrix, but has shape: {A.shape}')

-    I = torch.eye(*A.shape, out=torch.empty_like(A))
+    I = torch.eye(*A.shape, out=torch.empty_like(A))  # noqa: E741
inv_A_hat = torch.zeros_like(A)
if alpha is not None:
# A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...]
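
The `# noqa: E741` keeps the conventional name `I` for the identity matrix (E741 flags the ambiguous single letters l, O, I). As a side note, a hedged sketch of the Neumann-series inverse that the neighboring comment describes, assuming alpha is chosen so that ||I - alpha*A|| < 1 and the series converges:

    import torch

    def neumann_inv(A: torch.Tensor, alpha: float, maxiter: int = 50) -> torch.Tensor:
        # A^{-1} ≈ alpha * sum_k (I - alpha*A)^k, truncated at maxiter terms.
        I = torch.eye(A.shape[0], dtype=A.dtype)  # noqa: E741
        M = I - alpha * A
        term, acc = I.clone(), I.clone()
        for _ in range(maxiter):
            term = term @ M
            acc = acc + term
        return alpha * acc

    A = torch.tensor([[2.0, 0.0], [0.0, 4.0]])
    assert torch.allclose(neumann_inv(A, alpha=0.2), torch.linalg.inv(A), atol=1e-4)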
3 changes: 2 additions & 1 deletion torchopt/typing.py
@@ -126,7 +126,8 @@ class Samplable(Protocol): # pylint: disable=too-few-public-methods

@abc.abstractmethod
def sample(
-        self, sample_shape: Size = Size()  # pylint: disable=unused-argument
+        self,
+        sample_shape: Size = Size(),  # noqa: B008 # pylint: disable=unused-argument
) -> Union[Tensor, Sequence[Numeric]]:
# pylint: disable-next=line-too-long
"""Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
