Skip to content

Commit

Permalink
Merge pull request #3 from NVlabs/ruff-linting
Browse files Browse the repository at this point in the history
Add ruff linting and pre-commit hooks
  • Loading branch information
nbren12 authored May 31, 2024
2 parents e66f4fe + 00b4901 commit 9661ddb
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 31 deletions.
18 changes: 8 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,21 @@ repos:
- id: check-merge-conflict
- id: check-yaml
args: [ --unsafe ]
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.8.0
hooks:
- id: isort
args: [ "--filter-files" ]
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
additional_dependencies: ["click==8.0.4"]
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
exclude: conf.py
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
exclude: tests/
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.5
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
exclude: ^docs/
15 changes: 10 additions & 5 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ def to_pyvista(self):
points = healpix_bare.corners(nside, torch.from_numpy(pix), True).numpy()
out = einops.rearrange(points, "n d s -> (n s) d")
unique_points, inverse = np.unique(out, return_inverse=True, axis=0)
assert unique_points.ndim == 2
assert unique_points.shape[1] == 3
if unique_points.ndim != 2:
raise ValueError(f"unique_points.ndim should be 2, got {unique_points.ndim}.")
if unique_points.shape[1] != 3:
raise ValueError(f"unique_points.shape[1] should be 3, got {unique_points.shape[1]}.")
inverse = einops.rearrange(inverse, "(n s) -> n s", n=pix.size)
n, s = inverse.shape
cells = np.ones_like(inverse, shape=(n, s + 1))
Expand Down Expand Up @@ -335,7 +337,8 @@ def _rotate_index(nside: int, rotations: int, i):
# Reduce k to its equivalent in the range [0, 3]
k = rotations % 4

assert 0 <= k < 4
if k < 0 or k >= 4:
raise ValueError(f"k not in [0, 3], got {k}")

# Apply the rotation based on k
if k == 1: # 90 degrees counterclockwise
Expand Down Expand Up @@ -399,13 +402,15 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
"""
px, py = padding
assert px == py
if px != py:
raise ValueError(f"Padding should be equal in x and y, got px={px}, py={py}")

n, c, x, y = input.shape
npix = input.size(-1)
nside2 = npix // 12
nside = int(math.sqrt(nside2))
assert nside**2 * 12 == npix
if nside**2 * 12 != npix:
raise ValueError(f"Incompatible npix ({npix}) and nside ({nside})")

input = einops.rearrange(input, "n c () (f x y) -> (n c) f x y", f=12, x=nside)
input = pad(input, px)
Expand Down
6 changes: 3 additions & 3 deletions earth2grid/third_party/zephyr/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def pn(
c: th.Tensor,
t: th.Tensor,
tl: th.Tensor,
l: th.Tensor,
l: th.Tensor, # noqa: E741
bl: th.Tensor,
b: th.Tensor,
br: th.Tensor,
Expand Down Expand Up @@ -156,7 +156,7 @@ def pe(
c: th.Tensor,
t: th.Tensor,
tl: th.Tensor,
l: th.Tensor,
l: th.Tensor, # noqa: E741
bl: th.Tensor,
b: th.Tensor,
br: th.Tensor,
Expand Down Expand Up @@ -193,7 +193,7 @@ def ps(
c: th.Tensor,
t: th.Tensor,
tl: th.Tensor,
l: th.Tensor,
l: th.Tensor, # noqa: E741
bl: th.Tensor,
b: th.Tensor,
br: th.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license:
python tests/_license/header_check.py

format: license
isort $(sources) tests
ruff check --fix $(sources) tests
black $(sources) tests

lint: license
Expand Down
74 changes: 65 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ dev = [
"pip>=20.3.1",
"twine>=3.3.0",
"toml>=0.10.2",
"bump2version>=1.0.1"
"bump2version>=1.0.1",
"ruff>=0.1.5"
]
doc = [
"sphinx",
Expand Down Expand Up @@ -79,11 +80,66 @@ exclude = '''
)/
'''

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
skip_gitignore = true
[tool.ruff]
# Enable flake8/pycodestyle (`E`), Pyflakes (`F`), flake8-bandit (`S`),
# isort (`I`), and performance 'PERF' rules.
select = ["E", "F", "S", "I", "PERF"]
fixable = ["ALL"]

# Never enforce `E402`, `E501` (line length violations),
# and `S311` (random number generators)
ignore = ["E501", "S311"]

# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

# Same as Black.
line-length = 120
indent-width = 4

target-version = 'py38'

[tool.ruff.per-file-ignores]
# Ignore `F401` (import violations) in all `__init__.py` files, and in `docs/*.py`.
"__init__.py" = ["F401", "E402"]
"docs/*.py" = ["F401"]
"**/{tests,docs,tools}/*" = ["E402"]
"**/tests/**/*.py" = [
# at least this three should be fine in tests:
"S101", # asserts allowed in tests...
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
]
"**/test_*.py" = [
# at least this three should be fine in tests:
"S101", # asserts allowed in tests...
"ARG", # Unused function args -> fixtures nevertheless are functionally relevant...
"FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize()
]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_compiler():
try:
# Try to get the compiler path from the CC environment variable
# If not set, it will default to gcc (which could be symlinked to clang or g++)
compiler = subprocess.check_output(["gcc", "--version"], universal_newlines=True)
compiler = subprocess.check_output(["gcc", "--version"], universal_newlines=True) # noqa: S603, S607

if "clang" in compiler:
return "clang"
Expand Down
3 changes: 1 addition & 2 deletions tests/test_latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.import torch
# limitations under the License.
import torch

from earth2grid.latlon import equiangular_lat_lon_grid


Expand Down

0 comments on commit 9661ddb

Please sign in to comment.