feat: add ruff integration (metaopt#139)
XuehaiPan authored on Feb 19, 2023 · 1 parent c677243 · commit 6fa85d6
Showing 68 changed files with 408 additions and 358 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
@@ -71,6 +71,10 @@ jobs:
run: |
make pre-commit
- name: ruff
run: |
make ruff
- name: flake8
run: |
make flake8
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -28,6 +28,12 @@ repos:
hooks:
- id: clang-format
stages: [commit, push, manual]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.247
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
stages: [commit, push, manual]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
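Note: with --fix, the hook rewrites offending files in place, and --exit-non-zero-on-fix makes it exit with failure whenever it changed something, so pre-commit halts the commit and lets the autofixed files be restaged. Outside of a commit, the hook can be exercised with `pre-commit run ruff --all-files` (assuming pre-commit itself is installed).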
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139).

### Changed

16 changes: 13 additions & 3 deletions Makefile
@@ -56,6 +56,9 @@ py-format-install:
$(call check_pip_install,isort)
$(call check_pip_install_extra,black,black[jupyter])

ruff-install:
$(call check_pip_install,ruff)

mypy-install:
$(call check_pip_install,mypy)

@@ -79,7 +82,7 @@ docs-install:
$(call check_pip_install,sphinxcontrib-bibtex)
$(call check_pip_install,sphinx-autodoc-typehints)
$(call check_pip_install,myst-nb)
$(call check_pip_install_extra,sphinxcontrib.spelling,sphinxcontrib.spelling pyenchant)
$(call check_pip_install_extra,sphinxcontrib-spelling,sphinxcontrib-spelling pyenchant)

pytest-install:
$(call check_pip_install,pytest)
@@ -132,6 +135,12 @@ py-format: py-format-install
$(PYTHON) -m isort --project $(PROJECT_NAME) --check $(PYTHON_FILES) && \
$(PYTHON) -m black --check $(PYTHON_FILES) tutorials

ruff: ruff-install
$(PYTHON) -m ruff check .

ruff-fix: ruff-install
$(PYTHON) -m ruff check . --fix --exit-non-zero-on-fix

mypy: mypy-install
$(PYTHON) -m mypy $(PROJECT_PATH)

@@ -181,11 +190,12 @@ clean-docs:

# Utility functions

lint: flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling
lint: ruff flake8 py-format mypy pylint clang-format clang-tidy cpplint addlicense docstyle spelling

format: py-format-install clang-format-install addlicense-install
format: py-format-install ruff-install clang-format-install addlicense-install
$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
$(PYTHON) -m black $(PYTHON_FILES) tutorials
$(PYTHON) -m ruff check . --fix --exit-zero
$(CLANG_FORMAT) -style=file -i $(CXX_FILES) $(CUDA_FILES)
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS)

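Note: the three entry points treat autofixes differently: `make ruff` only checks, `make ruff-fix` applies fixes and still exits non-zero when anything changed (so CI can flag it), and the umbrella `format` target runs ruff with `--exit-zero` so an autofix pass never aborts the remaining formatting steps.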
1 change: 1 addition & 0 deletions conda-recipe.yaml
@@ -94,6 +94,7 @@ dependencies:
- flake8-docstrings
- flake8-pyi
- flake8-simplify
- ruff
- doc8
- pydocstyle
- clang-format >= 14
7 changes: 4 additions & 3 deletions docs/source/conf.py
@@ -30,6 +30,7 @@
import pathlib
import sys

import sphinx
import sphinxcontrib.katex as katex


@@ -39,7 +40,7 @@

def get_version() -> str:
sys.path.insert(0, str(PROJECT_ROOT / 'torchopt'))
import version # noqa
import version

return version.__version__

@@ -51,7 +52,7 @@ def get_version() -> str:
else:

class RecursiveForwardRefFilter(logging.Filter):
def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
if (
"name 'TensorTree' is not defined" in record.getMessage()
or "name 'OptionalTensorTree' is not defined" in record.getMessage()
@@ -191,7 +192,7 @@ def filter(self, record):
html_logo = '_static/images/logo.png'


def setup(app):
def setup(app: sphinx.application.Sphinx) -> None:
app.add_js_file('https://cdn.jsdelivr.net/npm/[email protected]')
app.add_js_file('https://cdn.jsdelivr.net/npm/[email protected]')
app.add_js_file('https://cdn.jsdelivr.net/npm/[email protected]')
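Note: the annotations follow the standard logging.Filter protocol — filter() receives a LogRecord and returns True to keep the record — and the new `import sphinx` apparently exists only so setup() can be annotated with sphinx.application.Sphinx. The `# noqa` on the local `import version` is dropped, presumably because ruff's configuration no longer flags that import. A minimal sketch of the filter pattern (shortened; not the full filter from this file):

    import logging

    class ForwardRefFilter(logging.Filter):
        def filter(self, record: logging.LogRecord) -> bool:
            # Return False to suppress a record, True to let it through.
            return "is not defined" not in record.getMessage()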
2 changes: 1 addition & 1 deletion docs/source/developer/contributing.rst
@@ -51,7 +51,7 @@ Lint Check

We use several tools to secure code quality, including:

* PEP8 code style: ``black``, ``isort``, ``pylint``, ``flake8``
* Python code style: ``black``, ``isort``, ``pylint``, ``flake8``, ``ruff``
* Type hint check: ``mypy``
* C++ Google-style: ``cpplint``, ``clang-format``, ``clang-tidy``
* License: ``addlicense``
2 changes: 1 addition & 1 deletion examples/FuncTorch/maml_omniglot_vmap.py
@@ -224,7 +224,7 @@ def test(db, net, device, epoch, log):
qry_losses = []
qry_accs = []

for batch_idx in range(n_test_iter):
for _ in range(n_test_iter):
x_spt, y_spt, x_qry, y_qry = db.next('test')
task_num, setsz, c_, h, w = x_spt.size()

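Note: renaming an unused loop counter to `_` is the conventional fix for the "loop control variable not used within loop body" warning — flake8-bugbear's B007, which ruff implements (that this exact rule fired is an inference; the commit does not say). A minimal illustration:

    # before: `batch_idx` is flagged because the body never reads it
    for batch_idx in range(3):
        pass

    # after: `_` signals the value is intentionally discarded
    for _ in range(3):
        pass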
2 changes: 0 additions & 2 deletions examples/FuncTorch/parallel_train_torchopt.py
@@ -15,8 +15,6 @@

import argparse
import math
from collections import namedtuple
from typing import Any, NamedTuple

import functorch
import torch
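Note: a pure-deletion hunk like this is the standard autofix for unused imports (pyflakes F401): `namedtuple`, `Any`, and `NamedTuple` are presumably never referenced in the module, so the import lines simply go.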
2 changes: 1 addition & 1 deletion examples/L2R/helpers/model.py
@@ -35,7 +35,7 @@

class LeNet5(nn.Module):
def __init__(self, args):
super(LeNet5, self).__init__()
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 16, 5),
nn.ReLU(),
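Note: the two-argument super(LeNet5, self) is a Python 2 holdover; the pyupgrade-derived rule UP008 in ruff rewrites it to bare super(), which resolves the class and instance implicitly. A self-contained sketch:

    class Base:
        def __init__(self):
            self.ready = True

    class Child(Base):
        def __init__(self):
            # was: super(Child, self).__init__()
            super().__init__()

    assert Child().ready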
9 changes: 1 addition & 8 deletions examples/L2R/helpers/utils.py
@@ -89,16 +89,10 @@ def get_imbalance_dataset(
y_val_subset = np.concatenate([np.zeros([x_val_0.shape[0]]), np.ones([x_val_1.shape[0]])])
y_test_subset = np.concatenate([np.zeros([x_test_0.shape[0]]), np.ones([x_test_1.shape[0]])])

y_train_pos_subset = np.ones([x_train_1.shape[0]])
y_train_neg_subset = np.zeros([x_train_0.shape[0]])

x_train_subset = np.concatenate([x_train_0, x_train_1], axis=0)[:, None, :, :]
x_val_subset = np.concatenate([x_val_0, x_val_1], axis=0)[:, None, :, :]
x_test_subset = np.concatenate([x_test_0, x_test_1], axis=0)[:, None, :, :]

x_train_pos_subset = x_train_1[:, None, :, :]
x_train_neg_subset = x_train_0[:, None, :, :]

# Final shuffle.
idx = np.arange(x_train_subset.shape[0])
np.random.shuffle(idx)
@@ -146,7 +140,7 @@ def set_seed(seed, cudnn=True):
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
# note: the below slows down the code but makes it reproducible
# Sets the seed for generating random numbers on all GPUs. Its safe to
# Sets the seed for generating random numbers on all GPUs. It's safe to
# call this function if CUDA is not available; in that case, it is
# silently ignored.
torch.cuda.manual_seed_all(seed)
@@ -157,7 +151,6 @@ def set_seed(seed, cudnn=True):

def plot(baseline, l2r):
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

sns.set(style='darkgrid')
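Note: three independent cleanups land in this file: the *_pos_subset/*_neg_subset arrays were assigned but never read (the unused-variable pattern pyflakes reports as F841), the local `import numpy as np` inside plot() duplicated what is presumably a module-level import, and the comment typo "Its" becomes "It's".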
12 changes: 4 additions & 8 deletions examples/L2R/l2r.py
@@ -51,14 +51,13 @@ def run_baseline(args, mnist_train, mnist_test):
ntest = args.ntest
epoch = args.epoch

folder = './result/baseline/'
writer = SummaryWriter('./result/baseline')
with open('./result/baseline/config.json', 'w') as f:
json.dump(args.__dict__, f)

args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_set, val_set, test_set = get_imbalance_dataset(
train_set, _, test_set = get_imbalance_dataset(
mnist_train,
mnist_test,
pos_ratio=pos_ratio,
@@ -67,7 +66,6 @@
ntest=ntest,
)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
model = LeNet5(args).to(args.device)

@@ -91,7 +89,7 @@

if step % 10 == 0 and step > 0:
running_train_mean = np.mean(np.array(running_train_loss))
print('EPOCH: {}, BATCH: {}, LOSS: {}'.format(_epoch, idx, running_train_mean))
print(f'EPOCH: {_epoch}, BATCH: {idx}, LOSS: {running_train_mean}')
writer.add_scalar('running_train_loss', running_train_mean, step)
running_train_loss = []

@@ -106,7 +104,7 @@
writer.add_scalar('train_acc', train_acc, _epoch)
writer.add_scalar('test_acc', test_acc, _epoch)
test_acc_result.append(test_acc)
print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc))
print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}')
return test_acc_result


@@ -120,7 +118,6 @@ def run_L2R(args, mnist_train, mnist_test):
ntest = args.ntest
epoch = args.epoch

folder = './result/l2r/'
writer = SummaryWriter('./result/l2r/log')
with open('./result/l2r/config.json', 'w') as f:
json.dump(args.__dict__, f)
@@ -143,7 +140,6 @@
real_model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

step = 0
time_bp = 0
running_valid_loss = []
valid = iter(valid_loader)
running_train_loss = []
@@ -222,7 +218,7 @@
writer.add_scalar('train_acc', train_acc, _epoch)
writer.add_scalar('test_acc', test_acc, _epoch)
test_acc_result.append(test_acc)
print('EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}'.format(_epoch, train_acc, test_acc))
print(f'EPOCH: {_epoch}, TRAIN_ACC: {train_acc}, TEST_ACC: {test_acc}')
return test_acc_result


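Note: the print rewrites are pyupgrade's f-string conversion (UP032 in ruff), and the dropped `folder`, `valid_loader`, and `time_bp` bindings are unused-variable cleanups. The conversion in miniature:

    epoch, idx, loss = 1, 10, 0.25
    print('EPOCH: {}, BATCH: {}, LOSS: {}'.format(epoch, idx, loss))  # before
    print(f'EPOCH: {epoch}, BATCH: {idx}, LOSS: {loss}')              # after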
4 changes: 2 additions & 2 deletions examples/LOLA/helpers/utils.py
@@ -27,7 +27,7 @@ def step(ipd, theta1, theta2, values1, values2, args):
(s1, s2), _ = ipd.reset()
score1 = 0
score2 = 0
for t in range(args.len_rollout):
for _ in range(args.len_rollout):
a1, lp1, v1 = act(s1, theta1, values1)
a2, lp2, v2 = act(s2, theta2, values2)
(s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
@@ -109,7 +109,7 @@ def sample(ipd, policy, value, args):
(s1, s2), _ = ipd.reset()
memory_agent1 = Memory(args)
memory_agent2 = Memory(args)
for t in range(args.len_rollout):
for _ in range(args.len_rollout):
a1, lp1, v1 = act(s1, theta1, value1)
a2, lp2, v2 = act(s2, theta2, value2)
(s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
7 changes: 3 additions & 4 deletions examples/LOLA/lola_dice.py
@@ -96,25 +96,24 @@ def main(args):
score = step(ipd, agent1.theta, agent2.theta, agent1.values, agent2.values, args)
joint_scores.append(0.5 * (score[0] + score[1]))

# print
if update % 10 == 0:
p1 = [p.item() for p in torch.sigmoid(agent1.theta)]
p2 = [p.item() for p in torch.sigmoid(agent2.theta)]
print(
'update',
update,
'score (%.3f,%.3f)' % (score[0], score[1]),
f'score ({score[0]:.3f},{score[1]:.3f})',
'policy (agent1) = {%.3f, %.3f, %.3f, %.3f, %.3f}'
% (p1[0], p1[1], p1[2], p1[3], p1[4]),
' (agent2) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p2[0], p2[1], p2[2], p2[3], p2[4]),
f' (agent2) = {{{p2[0]:.3f}, {p2[1]:.3f}, {p2[2]:.3f}, {p2[3]:.3f}, {p2[4]:.3f}}}',
)

return joint_scores


if __name__ == '__main__':
args = parse_args()
joint_score = dict()
joint_score = {}
for nla in range(3):
args.n_lookaheads = nla
joint_score[nla] = main(args)
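Note: two fixes meet here: %-interpolation becomes an f-string, and dict() becomes the {} literal (flake8-comprehensions C408 in ruff). The triple braces are the subtle part — inside an f-string, {{ and }} emit literal braces, so the replacement field sits between them:

    p2 = [0.25, 0.5]
    print(f'{{{p2[0]:.3f}}}')  # prints: {0.250}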
14 changes: 8 additions & 6 deletions examples/MAML-RL/func_maml.py
@@ -103,9 +103,10 @@ def evaluate(env, seed, task_num, fpolicy, params):
inner_opt = torchopt.MetaSGD(lr=0.5)
env = gym.make(
'TabularMDP-v0',
**dict(
num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
),
num_states=STATE_DIM,
num_actions=ACTION_DIM,
max_episode_steps=TRAJ_LEN,
seed=args.seed,
)
tasks = env.sample_tasks(num_tasks=task_num)

@@ -131,9 +132,10 @@ def main(args):
# Env
env = gym.make(
'TabularMDP-v0',
**dict(
num_states=STATE_DIM, num_actions=ACTION_DIM, max_episode_steps=TRAJ_LEN, seed=args.seed
),
num_states=STATE_DIM,
num_actions=ACTION_DIM,
max_episode_steps=TRAJ_LEN,
seed=args.seed,
)
# Policy
policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
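Note: building a dict(...) only to splat it with ** is equivalent to passing the keyword arguments directly, so the call is flattened (likely via one of ruff's flake8-comprehensions/flake8-pie style rules; the precise rule code is an inference). The equivalence:

    def make(name, **kwargs):
        return name, kwargs

    assert make('TabularMDP-v0', **dict(num_states=4, seed=0)) == \
        make('TabularMDP-v0', num_states=4, seed=0)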
2 changes: 0 additions & 2 deletions examples/MAML-RL/helpers/policy_torchrl.py
@@ -13,9 +13,7 @@
# limitations under the License.
# ==============================================================================

import torch
import torch.nn as nn
from torch.distributions import Categorical
from torchrl.modules import (
ActorValueOperator,
OneHotCategorical,
6 changes: 1 addition & 5 deletions examples/MAML-RL/helpers/tabular_mdp.py
@@ -93,7 +93,6 @@ def reset_task(self, task):
self._rewards_mean = task['rewards_mean']

def reset(self):
# From [1]: "an episode always starts on the first state"
self._state = 0
observation = np.zeros(self.num_states, dtype=np.float32)
observation[self._state] = 1.0
@@ -112,8 +111,5 @@ def step(self, action):
observation = np.zeros(self.num_states, dtype=np.float32)
observation[self._state] = 1.0
self._elapsed_steps += 1
if self._elapsed_steps >= self.max_episode_steps:
done = True
else:
done = False
done = self._elapsed_steps >= self.max_episode_steps
return observation, reward, done, {'task': self._task}
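Note: an if/else whose branches only assign True or False collapses to the comparison itself, since >= already yields a bool — the kind of rewrite ruff's flake8-simplify (SIM) rules suggest (which SIM code fired here is an inference). Minimal form:

    elapsed_steps, max_episode_steps = 5, 5
    # was: if elapsed_steps >= max_episode_steps: done = True ... else: done = False
    done = elapsed_steps >= max_episode_steps
    assert done is True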
20 changes: 8 additions & 12 deletions examples/MAML-RL/maml.py
@@ -108,12 +108,10 @@ def evaluate(env, seed, task_num, policy):
inner_opt = torchopt.MetaSGD(policy, lr=0.1)
env = gym.make(
'TabularMDP-v0',
**dict(
num_states=STATE_DIM,
num_actions=ACTION_DIM,
max_episode_steps=TRAJ_LEN,
seed=args.seed,
),
num_states=STATE_DIM,
num_actions=ACTION_DIM,
max_episode_steps=TRAJ_LEN,
seed=args.seed,
)
tasks = env.sample_tasks(num_tasks=task_num)
policy_state_dict = torchopt.extract_state_dict(policy)
@@ -141,12 +139,10 @@ def main(args):
# Env
env = gym.make(
'TabularMDP-v0',
**dict(
num_states=STATE_DIM,
num_actions=ACTION_DIM,
max_episode_steps=TRAJ_LEN,
seed=args.seed,
),
num_states=STATE_DIM,
num_actions=ACTION_DIM,
max_episode_steps=TRAJ_LEN,
seed=args.seed,
)
# Policy
policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
[Diff truncated — the remaining 50 of the 68 changed files are not shown.]
