MNT: compat Python 3.9
olivierverdier committed Sep 29, 2024
1 parent 46d6070 commit 63b6942
Showing 5 changed files with 15 additions and 15 deletions.
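
The diffs below replace the PEP 604 union syntax `X | Y` in annotations with `typing.Union` and `typing.Optional`. The reason, inferred from the commit title, is that the `|` operator between types only exists from Python 3.10 onward, and function annotations are evaluated when the `def` statement runs, so the old spelling fails at import time on 3.9. A minimal sketch, not taken from the repository, of the failure and the fix:

```python
# Minimal sketch (not part of the commit) of why PEP 604 unions break on 3.9:
# annotations are evaluated when the `def` statement executes, and `type | type`
# only exists from Python 3.10 onward.
from typing import Union

class Diffeo:
    """Hypothetical stand-in for the repository's Diffeo class."""

# On Python 3.9 the old spelling fails at import time:
#   def act(q: float, g: float | Diffeo) -> float: ...
#   TypeError: unsupported operand type(s) for |: 'type' and 'type'

# The 3.9-compatible spelling used throughout this commit:
def act(q: float, g: Union[float, Diffeo]) -> float:
    return q if isinstance(g, Diffeo) else q * g

print(act(2.0, 3.0))  # 6.0
```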
6 changes: 3 additions & 3 deletions diffeopt/group/ddmatch/action/density.py
@@ -1,4 +1,4 @@
- from typing import Callable, Any
+ from typing import Callable, Any, Union
import torch
import numpy as np
from ddmatch.core import ( # type: ignore
@@ -10,7 +10,7 @@

# TODO: implement forward or backward in numba??

- def get_density_action(shape: torch.Size, compute_id: bool=False) -> Callable[[torch.Tensor, torch.Tensor | ForwardDiffeo], torch.Tensor]:
+ def get_density_action(shape: torch.Size, compute_id: bool=False) -> Callable[[torch.Tensor, Union[torch.Tensor, ForwardDiffeo]], torch.Tensor]:
image = np.zeros(shape)
compute_grad = generate_optimized_image_gradient(image)
compute_pullback = generate_optimized_image_composition(image)
@@ -20,7 +20,7 @@ class DensityAction(torch.autograd.Function):
"""

@staticmethod
- def forward(ctx: Any, x: torch.Tensor, g: torch.Tensor | ForwardDiffeo) -> torch.Tensor:
+ def forward(ctx: Any, x: torch.Tensor, g: Union[torch.Tensor, ForwardDiffeo]) -> torch.Tensor:
ctx.save_for_backward(x)
if isinstance(g, ForwardDiffeo):
torch_data = g.forward
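Since the same long `Callable[...]` type now appears in several signatures, one way to keep them readable is a module-level alias. This is only a sketch, not part of the commit; `ForwardDiffeo` is replaced by a stand-in class so the snippet runs without ddmatch installed (in the repository it presumably comes from the truncated `ddmatch.core` import shown above):

```python
# Sketch only (not in the commit): a module-level alias for the repeated
# Callable[...] type used by the density action.
from typing import Callable, Union
import torch

class ForwardDiffeo:
    """Stand-in for ddmatch's ForwardDiffeo, holding the deformation tensor."""
    def __init__(self, forward: torch.Tensor) -> None:
        self.forward = forward

DensityActionFn = Callable[[torch.Tensor, Union[torch.Tensor, ForwardDiffeo]], torch.Tensor]

def identity_action(x: torch.Tensor, g: Union[torch.Tensor, ForwardDiffeo]) -> torch.Tensor:
    # Mirror the isinstance check in DensityAction.forward: a ForwardDiffeo
    # wraps its tensor data in `.forward`, a raw tensor is used directly.
    _data = g.forward if isinstance(g, ForwardDiffeo) else g
    return x  # trivial action; only the signature matters for this sketch

action: DensityActionFn = identity_action
```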
8 changes: 4 additions & 4 deletions diffeopt/group/ddmatch/action/function.py
@@ -1,4 +1,4 @@
- from typing import Any, Callable
+ from typing import Any, Callable, Union, Optional
import torch
import numpy as np

@@ -8,7 +8,7 @@

from ..group import DiffeoGroup

- def get_composition_action(shape: torch.Size, compute_id:bool=False) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]:
+ def get_composition_action(shape: torch.Size, compute_id:bool=False) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]:
"""
compute_id: right composition with the identity
is the identity map, but it takes time to compute.
@@ -32,7 +32,7 @@ class CompositionAction(torch.autograd.Function):
and its derivative wrt g.
"""
@staticmethod
- def forward(ctx: Any, q: torch.Tensor, g: torch.Tensor | Diffeo) -> torch.Tensor:
+ def forward(ctx: Any, q: torch.Tensor, g: Union[torch.Tensor, Diffeo]) -> torch.Tensor:
if isinstance(g, Diffeo):
torch_data = g.forward
to_save = g.backward
@@ -52,7 +52,7 @@ def forward(ctx: Any, q: torch.Tensor, g: torch.Tensor | Diffeo) -> torch.Tensor
return res

@staticmethod
- def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: # type: ignore[override]
+ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: # type: ignore[override]
"""
This is the adjoint of the derivative only
if g was the identity.
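Note that in `backward` only the `| None` part needed rewriting: the built-in `tuple[...]` generic already works on 3.9 (PEP 585), so it is kept while the union becomes `Optional`. A small self-contained sketch of the same return-type pattern:

```python
# Sketch (not the repository's code) of the backward-style return type after
# the change: builtin tuple[...] subscription is fine on 3.9, `| None` is not.
from typing import Optional

def backward_like(grad_output: list, g_needs_grad: bool) -> tuple[list, Optional[list]]:
    # The gradient with respect to g is None when g does not require one.
    return grad_output, (grad_output if g_needs_grad else None)

print(backward_like([1.0, 2.0], False))  # ([1.0, 2.0], None)
```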
6 changes: 3 additions & 3 deletions diffeopt/group/ddmatch/representation.py
@@ -1,4 +1,4 @@
- from typing import Callable
+ from typing import Callable, Union
import torch

from ..representation import Representation
@@ -7,11 +7,11 @@
from .action.function import get_composition_action

class FunctionRepresentation(Representation):
- def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]:
+ def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]:
return get_composition_action(group.shape)

from .action.density import get_density_action

Check failure on line 13 in diffeopt/group/ddmatch/representation.py
GitHub Actions / build (3.9, 3.10, 3.11, 3.12): Ruff (E402)
diffeopt/group/ddmatch/representation.py:13:1: E402 Module level import not at top of file

class DensityRepresentation(Representation):
- def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]:
+ def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]:
return get_density_action(group.shape)
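
The E402 failure above is pre-existing: `from .action.density import get_density_action` sits at line 13 of the module, after the first class definition, and this commit does not move it (possibly to avoid a circular import). A generic sketch of what E402 flags and the usual fixes, not specific to this repository:

```python
# Generic E402 sketch (not the repository's file): module-level imports must
# precede other statements, or carry an explicit noqa if the late import is
# deliberate (for instance to dodge a circular import).
import math            # fine: import at the top of the module

TWO_PI = 2 * math.pi   # an executable statement

import os  # noqa: E402  -- flagged without the noqa; moving it up is the clean fix

print(os.getcwd(), TWO_PI)
```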
4 changes: 2 additions & 2 deletions diffeopt/group/representation.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
- from typing import Callable
+ from typing import Callable, Union
from .deformation import Deformation
import torch
from torch.nn.parameter import Parameter
@@ -39,5 +39,5 @@ def forward(self, I: torch.Tensor) -> torch.Tensor:
return self.representation(self.representation(I, self.perturbation.base.deformation), self.perturbation)

@abstractmethod
- def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]:
+ def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]:
pass
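
The same `Union` rewrite lands in the abstract signature, so every concrete `Representation` keeps an identical annotation. A self-contained sketch of the ABC pattern with simplified stand-in types (not diffeopt's real classes):

```python
# Simplified sketch of the Representation ABC pattern: the abstract method
# fixes the Union-typed callable that each subclass must return. Names here
# are stand-ins, not the repository's classes.
from abc import ABC, abstractmethod
from typing import Callable, Union

class Group:
    pass

class Rep(ABC):
    @abstractmethod
    def get_representation(self, group: Group) -> Callable[[list, Union[list, Group]], list]:
        ...

class FunctionRep(Rep):
    def get_representation(self, group: Group) -> Callable[[list, Union[list, Group]], list]:
        def act(q: list, g: Union[list, Group]) -> list:
            # Acting by the bare group object is treated as the identity here.
            return q if isinstance(g, Group) else [a + b for a, b in zip(q, g)]
        return act

rep = FunctionRep().get_representation(Group())
print(rep([1, 2], [3, 4]))  # [4, 6]
```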
6 changes: 3 additions & 3 deletions diffeopt/optim.py
@@ -1,4 +1,4 @@
- from typing import Callable, Any
+ from typing import Callable, Any, Optional, Union

Check failure on line 1 in diffeopt/optim.py
GitHub Actions / build (3.9, 3.10, 3.11, 3.12): Ruff (F401)
diffeopt/optim.py:1:45: F401 `typing.Union` imported but unused
from abc import ABC, abstractmethod
import torch
from .group.representation import Perturbation
@@ -8,8 +8,8 @@
class DiffeoOptimizer(Optimizer, ABC):

@torch.no_grad()
- def step(self, closure:Callable[[], float] | None = None) -> float | None: # type: ignore[override]
- loss: float | None = None
+ def step(self, closure:Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override]
+ loss: Optional[float] = None
if closure is not None:
with torch.enable_grad():
loss = closure()
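
The F401 failure above points at the one slip in this commit: `Union` was added to the import in `optim.py` but the file only uses `Optional`, so the name can simply be dropped. A sketch of the corrected signature shape (without the torch machinery), plus a note on the alternative the commit did not take:

```python
# Sketch of the F401 fix for diffeopt/optim.py: only Optional is needed.
from typing import Callable, Optional

def step(closure: Optional[Callable[[], float]] = None) -> Optional[float]:
    # Mimics the shape of DiffeoOptimizer.step: evaluate the closure if given,
    # otherwise report no loss.
    return closure() if closure is not None else None

print(step(lambda: 0.5))  # 0.5
print(step())             # None

# Alternative not taken here: `from __future__ import annotations` (PEP 563)
# would have let the `float | None` spellings stay, since annotations are then
# stored as strings and never evaluated at runtime on 3.9.
```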
