fix: fix optree compatibility for multi-tree-map with None values #195

Merged · 11 commits · Nov 9, 2023
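Context for the diff below: `optree.tree_map` (and the `torchopt.pytree` wrappers built on it) takes the tree *structure* from its first tree argument and broadcasts the remaining trees against it. Gradient trees can contain `None` leaves (e.g. frozen parameters that received no gradient), so when such a tree came first the map either lost those entries or raised a structure mismatch. This PR therefore reorders the multi-tree-map calls so a fully materialized tree (`params`, `mu`, `nu`, ...) leads, and makes every mapped function pass `None` through explicitly. A minimal sketch of the intended behavior, assuming `optree` is installed and `None` is treated as a leaf (as the reworked callbacks below imply):

```python
import optree

params = {'w': 1.0, 'b': 2.0}  # fully materialized: defines the structure
grads = {'w': 0.5, 'b': None}  # e.g. a frozen parameter without a gradient

# The structure comes from `params` (the first tree); the None leaf in
# `grads` reaches the callback, which must handle it explicitly.
out = optree.tree_map(
    lambda p, g: p - 0.1 * g if g is not None else p,
    params,
    grads,
    none_is_leaf=True,
)
print(out)  # {'w': 0.95, 'b': 2.0}
```

With the gradient tree first, the `None` leaf would instead reach callbacks that assumed a tensor, or, where `None` marks an empty subtree, break the structure match against `params`.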
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
@@ -29,10 +29,10 @@ jobs:
           submodules: "recursive"
           fetch-depth: 1
 
-      - name: Set up Python 3.8
+      - name: Set up Python 3.9
        uses: actions/setup-python@v4
        with:
-          python-version: "3.8"
+          python-version: "3.9"
          update-environment: true
 
      - name: Setup CUDA Toolkit
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -9,7 +9,7 @@ ci:
 default_stages: [commit, push, manual]
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-symlinks
       - id: destroyed-symlinks
@@ -26,11 +26,11 @@ repos:
       - id: debug-statements
       - id: double-quote-string-fixer
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
+    rev: v17.0.4
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.287
+    rev: v0.1.5
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -39,11 +39,11 @@
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.11.0
     hooks:
       - id: black-jupyter
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.10.1
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus] # sync with requires-python
@@ -68,7 +68,7 @@
         ^docs/source/conf.py$
       )
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.5
+    rev: v2.2.6
     hooks:
       - id: codespell
         additional_dependencies: [".[toml]"]
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
--
+- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
 
 ### Fixed
 
--
+- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
 
 ### Removed
 
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent
 project(torchopt LANGUAGES CXX)
 
 include(FetchContent)
-set(PYBIND11_VERSION v2.10.3)
+set(PYBIND11_VERSION v2.11.1)
 
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE Release)
 endif()
 
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 find_package(Threads REQUIRED) # -pthread
1 change: 0 additions & 1 deletion conda-recipe.yaml
@@ -77,7 +77,6 @@ dependencies:
   - hunspell-en
   - myst-nb
   - ipykernel
-  - pandoc
   - docutils
 
 # Testing
1 change: 0 additions & 1 deletion docs/conda-recipe.yaml
@@ -67,5 +67,4 @@ dependencies:
   - hunspell-en
   - myst-nb
   - ipykernel
-  - pandoc
   - docutils
10 changes: 5 additions & 5 deletions docs/requirements.txt
@@ -4,17 +4,17 @@ torch >= 1.13
 
 --requirement ../requirements.txt
 
-sphinx >= 5.2.1
+sphinx >= 5.2.1, < 7.0.0a0
+sphinxcontrib-bibtex >= 2.4
+sphinx-autodoc-typehints >= 1.20
+myst-nb >= 0.15
 
 sphinx-autoapi
 sphinx-autobuild
 sphinx-copybutton
 sphinx-rtd-theme
 sphinxcontrib-katex
-sphinxcontrib-bibtex
-sphinx-autodoc-typehints >= 1.19.2
 IPython
 ipykernel
-pandoc
-myst-nb
 docutils
 matplotlib
1 change: 1 addition & 0 deletions torchopt/alias/sgd.py
@@ -44,6 +44,7 @@
 __all__ = ['sgd']
 
 
+# pylint: disable-next=too-many-arguments
 def sgd(
     lr: ScalarOrSchedule,
     momentum: float = 0.0,
26 changes: 15 additions & 11 deletions torchopt/alias/utils.py
@@ -108,19 +108,21 @@ def update_fn(
 
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-                return g.add(p, alpha=weight_decay)
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.add(p, alpha=weight_decay) if g is not None else g
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
 

@@ -139,7 +141,7 @@ def update_fn(
             def f(g: torch.Tensor) -> torch.Tensor:
                 return g.neg_()
 
-            updates = tree_map_(f, updates)
+            tree_map_(f, updates)
 
         else:
 
@@ -166,19 +168,21 @@ def update_fn(
 
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.neg_().add_(p, alpha=weight_decay)
                 return g.neg_().add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-                return g.neg().add_(p, alpha=weight_decay)
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.neg().add_(p, alpha=weight_decay) if g is not None else g
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state

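Two details in the hunks above are easy to miss. First, `f` now takes the parameter first to match the new tree order. Second, the in-place branch calls `tree_map_` purely for its side effect: `f` mutates the gradient leaves via `Tensor.add_`, and `tree_map_` returns its *first* tree, which is now `params`, so keeping the old `updates = tree_map_(...)` rebinding would have silently replaced the updates with the parameters. A sketch of the in-place branch using `optree` directly (torchopt's `tree_map_` wrapper is assumed to behave analogously):

```python
from __future__ import annotations

import optree
import torch

weight_decay = 0.01
params = {'w': torch.ones(2)}
updates = {'w': torch.full((2,), 0.5)}

def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    # Mutates g in place; missing gradients pass through untouched.
    return g.add_(p, alpha=weight_decay) if g is not None else g

# Called for its side effect only: the return value is the first tree
# (`params`), not a new updates tree.
optree.tree_map_(f, params, updates, none_is_leaf=True)
print(updates['w'])  # tensor([0.5100, 0.5100])
```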
2 changes: 2 additions & 0 deletions torchopt/distributed/api.py
@@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
     return torch.sum(torch.stack(tuple(results), dim=0), dim=0)
 
 
+# pylint: disable-next=too-many-arguments
 def remote_async_call(
     func: Callable[..., T],
     *,
@@ -328,6 +329,7 @@ def remote_async_call(
     return future
 
 
+# pylint: disable-next=too-many-arguments
 def remote_sync_call(
     func: Callable[..., T],
     *,
4 changes: 3 additions & 1 deletion torchopt/linalg/cg.py
@@ -53,7 +53,7 @@ def _identity(x: TensorTree) -> TensorTree:
     return x
 
 
-# pylint: disable-next=too-many-locals
+# pylint: disable-next=too-many-arguments,too-many-locals
 def _cg_solve(
     A: Callable[[TensorTree], TensorTree],
     b: TensorTree,
@@ -102,6 +102,7 @@ def body_fn(
     return x_final
 
 
+# pylint: disable-next=too-many-arguments
 def _isolve(
     _isolve_solve: Callable,
     A: TensorTree | Callable[[TensorTree], TensorTree],
@@ -134,6 +135,7 @@ def _isolve(
     return isolve_solve(A, b)
 
 
+# pylint: disable-next=too-many-arguments
 def cg(
     A: TensorTree | Callable[[TensorTree], TensorTree],
     b: TensorTree,
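The `cg.py` hunks only widen the pylint pragmas to cover the keyword-heavy signatures. For orientation, a hedged sketch of calling the solver these helpers back; the keyword name `maxiter` is an assumption based on the signatures visible above:

```python
import torch
import torchopt

# Solve A x = b for a symmetric positive-definite operator given as a
# callable, matching the `A: TensorTree | Callable[...]` hint above.
def A(x: torch.Tensor) -> torch.Tensor:
    return 4.0 * x

b = torch.tensor([1.0, 2.0])
x = torchopt.linalg.cg(A, b, maxiter=10)
print(x)  # ≈ tensor([0.2500, 0.5000])
```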
2 changes: 2 additions & 0 deletions torchopt/linalg/ns.py
@@ -123,12 +123,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch.Tensor:
         # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...]
         M = I - alpha * A
         for rank in range(maxiter):
+            # pylint: disable-next=not-callable
             inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
         inv_A_hat = alpha * inv_A_hat
     else:
         # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ...
         M = I - A
         for rank in range(maxiter):
+            # pylint: disable-next=not-callable
             inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
     return inv_A_hat
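The new pragmas silence a pylint false positive on `torch.linalg.matrix_power`. For reference, a condensed sketch of the truncated Neumann series that `_ns_inv` computes; the zero initialization of `inv_A_hat` happens above the visible hunk and is assumed here:

```python
import torch

def ns_inv(A: torch.Tensor, maxiter: int, alpha: float) -> torch.Tensor:
    # A^{-1} = a [I + (I - a A) + (I - a A)^2 + ...] for suitable a.
    I = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
    M = I - alpha * A
    inv_A_hat = torch.zeros_like(A)
    for rank in range(maxiter):
        inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
    return alpha * inv_A_hat

A = torch.diag(torch.tensor([2.0, 4.0]))
print(ns_inv(A, maxiter=50, alpha=0.1))  # ≈ diag(0.5, 0.25)
```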
4 changes: 2 additions & 2 deletions torchopt/nn/stateless.py
@@ -66,8 +66,8 @@
         mod._parameters[attr] = value # type: ignore[assignment]
     elif hasattr(mod, '_buffers') and attr in mod._buffers:
         mod._buffers[attr] = value
-    elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator]
-        mod._meta_parameters[attr] = value # type: ignore[operator,index]
+    elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters:
+        mod._meta_parameters[attr] = value

     else:
         setattr(mod, attr, value)
     # pylint: enable=protected-access
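For orientation, a hypothetical sketch of the attribute-swapping pattern this helper implements: walk a dotted name to the owning submodule, then write through `_parameters`/`_buffers` so `nn.Module` bookkeeping stays intact (`set_tensor` is illustrative, not torchopt's public API):

```python
import torch
from torch import nn

def set_tensor(module: nn.Module, name: str, value: torch.Tensor) -> None:
    *path, attr = name.split('.')
    mod = module
    for part in path:  # walk to the owning submodule
        mod = getattr(mod, part)
    if attr in mod._parameters:  # registered parameter
        mod._parameters[attr] = value  # type: ignore[assignment]
    elif attr in mod._buffers:  # registered buffer
        mod._buffers[attr] = value
    else:  # plain attribute
        setattr(mod, attr, value)

net = nn.Linear(2, 2)
set_tensor(net, 'weight', torch.zeros(2, 2))
```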
12 changes: 7 additions & 5 deletions torchopt/transform/add_decayed_weights.py
@@ -226,19 +226,21 @@ def update_fn(
 
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-                return g.add(p, alpha=weight_decay)
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.add(p, alpha=weight_decay) if g is not None else g
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state

18 changes: 5 additions & 13 deletions torchopt/transform/scale_by_adadelta.py
@@ -129,23 +129,15 @@ def update_fn(
 
         if inplace:
 
-            def f(
-                g: torch.Tensor, # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_())
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g
 
         else:
 
-            def f(
-                g: torch.Tensor, # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return g.mul(v.add(eps).div_(m.add(eps)).sqrt_())
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g
 
-        updates = tree_map(f, updates, mu, state.nu)
+        updates = tree_map(f, mu, state.nu, updates)
 
         nu = update_moment.impl(  # type: ignore[attr-defined]
             updates,
20 changes: 7 additions & 13 deletions torchopt/transform/scale_by_adam.py
@@ -132,6 +132,7 @@ def _scale_by_adam_flat(
     )
 
 
+# pylint: disable-next=too-many-arguments
 def _scale_by_adam(
     b1: float = 0.9,
     b2: float = 0.999,
@@ -200,23 +201,15 @@ def update_fn(
 
         if inplace:
 
-            def f(
-                g: torch.Tensor, # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return m.div_(v.add_(eps_root).sqrt_().add(eps))
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g
 
         else:
 
-            def f(
-                g: torch.Tensor, # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
-                return m.div(v.add(eps_root).sqrt_().add(eps))
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g
 
-        updates = tree_map(f, updates, mu_hat, nu_hat)
+        updates = tree_map(f, mu_hat, nu_hat, updates)
         return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc)
 
     return GradientTransformation(init_fn, update_fn)
@@ -283,6 +276,7 @@ def _scale_by_accelerated_adam_flat(
     )
 
 
+# pylint: disable-next=too-many-arguments
 def _scale_by_accelerated_adam(
     b1: float = 0.9,
     b2: float = 0.999,
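Although the Adam math no longer reads `g`, the callback still receives it: the gradient leaf acts purely as a mask, so parameters without gradients yield `None` updates instead of a bogus step. A sketch with `optree`, mirroring the out-of-place branch above:

```python
from __future__ import annotations

import optree
import torch

eps, eps_root = 1e-8, 0.0

def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    # g only signals missing gradients; the update itself uses m and v.
    return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g

mu_hat = {'w': torch.tensor([0.2]), 'b': torch.tensor([0.1])}
nu_hat = {'w': torch.tensor([0.04]), 'b': torch.tensor([0.01])}
grads = {'w': torch.tensor([1.0]), 'b': None}

updates = optree.tree_map(f, mu_hat, nu_hat, grads, none_is_leaf=True)
print(updates)  # {'w': tensor([1.0000]), 'b': None}
```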
18 changes: 6 additions & 12 deletions torchopt/transform/scale_by_adamax.py
@@ -137,23 +137,17 @@ def update_fn(
             already_flattened=already_flattened,
         )
 
-        def update_nu(
-            g: torch.Tensor,
-            n: torch.Tensor,
-        ) -> torch.Tensor:
-            return torch.max(n.mul(b2), g.abs().add_(eps))
+        def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+            return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g
 
-        nu = tree_map(update_nu, updates, state.nu)
+        nu = tree_map(update_nu, state.nu, updates)
 
         one_minus_b1_pow_t = 1 - b1**state.t
 
-        def f(
-            n: torch.Tensor,
-            m: torch.Tensor,
-        ) -> torch.Tensor:
-            return m.div(n).div_(one_minus_b1_pow_t)
+        def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor:
+            return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m
 
-        updates = tree_map(f, nu, mu)
+        updates = tree_map(f, mu, nu)
 
         return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1)

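The Adamax hunks show the same reordering plus one extra wrinkle: `update_nu` returns `None` where the gradient is `None`, so the second-moment tree inherits the gradient tree's `None` pattern, and the final map guards on `n` rather than `g`. A float-leaf sketch of that propagation (bias correction omitted; `optree` assumed):

```python
import optree

b2, eps = 0.999, 1e-8
grads = {'w': 0.5, 'b': None}  # 'b' received no gradient this step
old_nu = {'w': 1.0, 'b': 2.0}
mu = {'w': 0.1, 'b': 0.2}

# update_nu: None where g is None, so nu mirrors the gradient tree.
nu = optree.tree_map(
    lambda n, g: max(b2 * n, abs(g) + eps) if g is not None else g,
    old_nu, grads, none_is_leaf=True,
)
# f: guard on n instead of g, falling back to the first moment.
updates = optree.tree_map(
    lambda m, n: m / n if n is not None else m,
    mu, nu, none_is_leaf=True,
)
print(nu['b'], updates['b'])  # None 0.2
```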