test: Don't check old PyTorch versions
lRomul committed Jan 24, 2024
1 parent 89db71a commit 05fffe9
Showing 4 changed files with 26 additions and 56 deletions.
53 changes: 22 additions & 31 deletions argus/callbacks/lr_schedulers.py
@@ -4,7 +4,6 @@
 """
 from typing import Optional, Callable, Iterable, Any, Union, List
 
-import torch
 from torch.optim import Optimizer
 from torch.optim import lr_scheduler as _scheduler
 
@@ -370,16 +369,12 @@ def __init__(self,
                                  List[Callable[[int], float]]],
                 last_epoch: int = -1,
                 step_on_iteration: bool = False):
-        from distutils.version import LooseVersion
-        if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
-            super().__init__(
-                lambda opt: _scheduler.MultiplicativeLR(opt,
-                                                        lr_lambda,
-                                                        last_epoch=last_epoch),
-                step_on_iteration=step_on_iteration
-            )
-        else:
-            raise ImportError("Update torch>=1.4.0 to use 'MultiplicativeLR'")
+        super().__init__(
+            lambda opt: _scheduler.MultiplicativeLR(opt,
+                                                    lr_lambda,
+                                                    last_epoch=last_epoch),
+            step_on_iteration=step_on_iteration
+        )
 
 
 class OneCycleLR(LRScheduler):
@@ -451,23 +446,19 @@ def __init__(self,
                 div_factor: float = 25.,
                 final_div_factor: float = 1e4,
                 last_epoch: int = -1):
-        from distutils.version import LooseVersion
-        if LooseVersion(torch.__version__) >= LooseVersion("1.3.0"):
-            super().__init__(
-                lambda opt: _scheduler.OneCycleLR(opt,
-                                                  max_lr,
-                                                  total_steps=total_steps,
-                                                  epochs=epochs,
-                                                  steps_per_epoch=steps_per_epoch,
-                                                  pct_start=pct_start,
-                                                  anneal_strategy=anneal_strategy,
-                                                  cycle_momentum=cycle_momentum,
-                                                  base_momentum=base_momentum,
-                                                  max_momentum=max_momentum,
-                                                  div_factor=div_factor,
-                                                  final_div_factor=final_div_factor,
-                                                  last_epoch=last_epoch),
-                step_on_iteration=True
-            )
-        else:
-            raise ImportError("Update torch>=1.3.0 to use 'OneCycleLR'")
+        super().__init__(
+            lambda opt: _scheduler.OneCycleLR(opt,
+                                              max_lr,
+                                              total_steps=total_steps,
+                                              epochs=epochs,
+                                              steps_per_epoch=steps_per_epoch,
+                                              pct_start=pct_start,
+                                              anneal_strategy=anneal_strategy,
+                                              cycle_momentum=cycle_momentum,
+                                              base_momentum=base_momentum,
+                                              max_momentum=max_momentum,
+                                              div_factor=div_factor,
+                                              final_div_factor=final_div_factor,
+                                              last_epoch=last_epoch),
+            step_on_iteration=True
+        )
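To illustrate the change, a minimal usage sketch of the two callbacks edited above (an assumption-laden illustration, not code from the repository; it presumes argus and a recent PyTorch are installed). Both wrapped schedulers ship with every PyTorch release argus still supports, so the callbacks can now be built without a version gate; the arguments mirror those exercised in the tests below.

from argus.callbacks.lr_schedulers import MultiplicativeLR, OneCycleLR

# Multiply the learning rate by 0.95 after every epoch (step_on_iteration defaults to False).
multiplicative_lr = MultiplicativeLR(lambda epoch: 0.95)

# One-cycle policy over 10 epochs of 1000 iterations each; this callback always steps per iteration.
one_cycle_lr = OneCycleLR(max_lr=0.01, steps_per_epoch=1000, epochs=10)

# The callbacks are typically passed to an argus Model's fit() call via its callbacks argument.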
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -70,7 +70,7 @@ select = [
 [tool.ruff.per-file-ignores]
 "__init__.py" = ["F401"]
 
-[tool.pytest]
+[tool.pytest.ini_options]
 minversion = 6.0
 addopts = "--cov=argus"
-testpaths = "tests"
+testpaths = ["tests"]
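pytest reads configuration from pyproject.toml only when it is placed under the [tool.pytest.ini_options] table; a bare [tool.pytest] table is silently ignored, so the renamed table is what makes minversion, addopts and testpaths take effect, and args-style options such as testpaths are conventionally written as TOML lists.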
20 changes: 2 additions & 18 deletions tests/callbacks/test_lr_schedulers.py
@@ -1,8 +1,6 @@
 import pytest
 from collections import Counter
-from distutils.version import LooseVersion
 
-import torch
 from torch.optim import lr_scheduler
 from torch.optim.optimizer import Optimizer
 
@@ -104,9 +102,7 @@ def test_cosine_annealing_lr(self, test_engine):
         assert cosine_annealing_lr.scheduler.T_max == 10
         assert cosine_annealing_lr.scheduler.eta_min == 0
 
-    @pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.4.0"),
-                        reason="Requires torch==1.4.0 or higher")
-    def test_multiplicative_lr(self, test_engine, step_on_iteration, monkeypatch):
+    def test_multiplicative_lr(self, test_engine, step_on_iteration):
         multiplicative_lr = MultiplicativeLR(lambda epoch: 0.95,
                                              step_on_iteration=step_on_iteration)
         multiplicative_lr.attach(test_engine)
@@ -115,26 +111,14 @@ def test_multiplicative_lr(self, test_engine, step_on_iteration, monkeypatch):
         assert multiplicative_lr.scheduler.lr_lambdas[0](1) == 0.95
         assert multiplicative_lr.step_on_iteration == step_on_iteration
 
-        from argus.callbacks.lr_schedulers import torch
-        monkeypatch.setattr(torch, "__version__", '1.3.0')
-        with pytest.raises(ImportError):
-            MultiplicativeLR(lambda epoch: 0.95)
-
-    @pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.3.0"),
-                        reason="Requires torch==1.3.0 or higher")
-    def test_one_cycle_lr(self, test_engine, monkeypatch):
+    def test_one_cycle_lr(self, test_engine):
         one_cycle_lr = OneCycleLR(max_lr=0.01, steps_per_epoch=1000, epochs=10)
         one_cycle_lr.attach(test_engine)
         one_cycle_lr.start(test_engine.state)
         assert isinstance(one_cycle_lr.scheduler, lr_scheduler.OneCycleLR)
         assert one_cycle_lr.scheduler.total_steps == 10000
         assert one_cycle_lr.step_on_iteration
 
-        from argus.callbacks.lr_schedulers import torch
-        monkeypatch.setattr(torch, "__version__", '1.1.0')
-        with pytest.raises(ImportError):
-            OneCycleLR(max_lr=0.01, steps_per_epoch=1000, epochs=10)
-
     def test_cosine_annealing_warm_restarts(self, test_engine, step_on_iteration):
         warm_restarts = CosineAnnealingWarmRestarts(T_0=1, T_mult=1, eta_min=0,
                                                     step_on_iteration=step_on_iteration)
5 changes: 0 additions & 5 deletions tests/test_optimizer.py
@@ -1,6 +1,3 @@
-import pytest
-from distutils.version import LooseVersion
-
 import torch
 
 from argus.optimizer import get_pytorch_optimizers, _is_pytorch_optimizer
@@ -19,8 +16,6 @@ def test_is_pytorch_optimizer():
     assert not _is_pytorch_optimizer(torch.nn.BCELoss)
 
 
-@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.7.0"),
-                    reason="Requires torch==1.7.0 or higher")
 def test_is_multi_tensor_optimizer():
     from torch.optim import _multi_tensor
     assert not _is_pytorch_optimizer(_multi_tensor.SGD)
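The private torch.optim._multi_tensor namespace used by this test has shipped with PyTorch since version 1.7, the threshold the removed marker checked, so the test can now run unconditionally on every PyTorch release argus supports.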