diff --git a/diffeopt/group/ddmatch/action/density.py b/diffeopt/group/ddmatch/action/density.py index d170ac6..46667b5 100644 --- a/diffeopt/group/ddmatch/action/density.py +++ b/diffeopt/group/ddmatch/action/density.py @@ -1,4 +1,4 @@ -from typing import Callable, Any +from typing import Callable, Any, Union import torch import numpy as np from ddmatch.core import ( # type: ignore @@ -10,7 +10,7 @@ # TODO: implement forward or backward in numba?? -def get_density_action(shape: torch.Size, compute_id: bool=False) -> Callable[[torch.Tensor, torch.Tensor | ForwardDiffeo], torch.Tensor]: +def get_density_action(shape: torch.Size, compute_id: bool=False) -> Callable[[torch.Tensor, Union[torch.Tensor | ForwardDiffeo]], torch.Tensor]: image = np.zeros(shape) compute_grad = generate_optimized_image_gradient(image) compute_pullback = generate_optimized_image_composition(image) diff --git a/diffeopt/group/ddmatch/action/function.py b/diffeopt/group/ddmatch/action/function.py index a891456..0504496 100644 --- a/diffeopt/group/ddmatch/action/function.py +++ b/diffeopt/group/ddmatch/action/function.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any, Callable, Union, Optional import torch import numpy as np @@ -8,7 +8,7 @@ from ..group import DiffeoGroup -def get_composition_action(shape: torch.Size, compute_id:bool=False) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]: +def get_composition_action(shape: torch.Size, compute_id:bool=False) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]: """ compute_id: right composition with the identity is the identity map, but it takes time to compute. @@ -32,7 +32,7 @@ class CompositionAction(torch.autograd.Function): and its derivative wrt g. """ @staticmethod - def forward(ctx: Any, q: torch.Tensor, g: torch.Tensor | Diffeo) -> torch.Tensor: + def forward(ctx: Any, q: torch.Tensor, g: Union[torch.Tensor, Diffeo]) -> torch.Tensor: if isinstance(g, Diffeo): torch_data = g.forward to_save = g.backward @@ -52,7 +52,7 @@ def forward(ctx: Any, q: torch.Tensor, g: torch.Tensor | Diffeo) -> torch.Tensor return res @staticmethod - def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: # type: ignore[override] + def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: # type: ignore[override] """ This is the adjoint of the derivative only if g was the identity. diff --git a/diffeopt/group/ddmatch/representation.py b/diffeopt/group/ddmatch/representation.py index 3316cf3..e490c81 100644 --- a/diffeopt/group/ddmatch/representation.py +++ b/diffeopt/group/ddmatch/representation.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Union import torch from ..representation import Representation @@ -7,11 +7,11 @@ from .action.function import get_composition_action class FunctionRepresentation(Representation): - def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]: + def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]: return get_composition_action(group.shape) from .action.density import get_density_action class DensityRepresentation(Representation): - def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]: + def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]: return get_density_action(group.shape) diff --git a/diffeopt/group/representation.py b/diffeopt/group/representation.py index f2917a5..c145434 100644 --- a/diffeopt/group/representation.py +++ b/diffeopt/group/representation.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Union from .deformation import Deformation import torch from torch.nn.parameter import Parameter @@ -39,5 +39,5 @@ def forward(self, I: torch.Tensor) -> torch.Tensor: return self.representation(self.representation(I, self.perturbation.base.deformation), self.perturbation) @abstractmethod - def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, torch.Tensor | Diffeo], torch.Tensor]: + def get_representation(self, group: BaseDiffeoGroup) -> Callable[[torch.Tensor, Union[torch.Tensor, Diffeo]], torch.Tensor]: pass diff --git a/diffeopt/optim.py b/diffeopt/optim.py index fed853d..ac56f7f 100644 --- a/diffeopt/optim.py +++ b/diffeopt/optim.py @@ -1,4 +1,4 @@ -from typing import Callable, Any +from typing import Callable, Any, Optional, Union from abc import ABC, abstractmethod import torch from .group.representation import Perturbation @@ -8,8 +8,8 @@ class DiffeoOptimizer(Optimizer, ABC): @torch.no_grad() - def step(self, closure:Callable[[], float] | None = None) -> float | None: # type: ignore[override] - loss: float | None = None + def step(self, closure:Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override] + loss: Optional[float] = None if closure is not None: with torch.enable_grad(): loss = closure()