feat: vmap, use setup_ctx syntax
yoyolicoris committed Jul 10, 2024
1 parent 7ae569b commit bc0fc5f
Showing 1 changed file with 39 additions and 16 deletions.
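Context for the diff below: PyTorch's functorch transforms (torch.vmap, torch.func.jvp) require an autograd.Function written in the "new style", where forward no longer receives ctx and all saving happens in a separate setup_context staticmethod; a vmap staticmethod then tells torch.vmap how to handle batched inputs. A minimal sketch of that pattern, independent of this repo (the Square class is illustrative only, not from torchlpc):

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(x):                      # new style: no ctx argument here
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)         # saving moves into setup_context

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_y

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()         # x.grad == 2 * x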
torchlpc/core.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 from torch.autograd import Function
 from typing import Any, Tuple, Optional, Callable
+from itertools import starmap
 from numba import jit, njit, prange, cuda, float32, float64, complex64, complex128


@@ -156,9 +157,7 @@ def lpc_np(x: np.ndarray, A: np.ndarray, zi: np.ndarray) -> np.ndarray:

 class LPC(Function):
     @staticmethod
-    def forward(
-        ctx: Any, x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
         if x.is_cuda:
             y = lpc_cuda(x.detach(), A.detach(), zi.detach())
         else:
@@ -168,14 +167,21 @@ def forward(
                 zi.detach().cpu().numpy(),
             )
             y = torch.from_numpy(y).to(x.device, x.dtype)
-        ctx.save_for_backward(A, zi, y)
+        # ctx.save_for_backward(A, zi, y)
 
-        # for jvp
-        ctx.y = y
-        ctx.A = A
-        ctx.zi = zi
+        # # for jvp
+        # ctx.y = y
+        # ctx.A = A
+        # ctx.zi = zi
         return y
 
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
+        _, A, zi = inputs
+        y = output
+        ctx.save_for_backward(A, zi, y)
+        ctx.save_for_forward(A, zi, y)
+
     @staticmethod
     def backward(
         ctx, grad_y: torch.Tensor
@@ -219,19 +225,19 @@ def backward(
         unfolded_y = padded_y.unfold(1, order, 1).flip(2)
         grad_A = unfolded_y.conj_physical() * -flipped_grad_x.flip(1).unsqueeze(2)
 
-        if hasattr(ctx, "y"):
-            del ctx.y
-        if hasattr(ctx, "A"):
-            del ctx.A
-        if hasattr(ctx, "zi"):
-            del ctx.zi
+        # if hasattr(ctx, "y"):
+        #     del ctx.y
+        # if hasattr(ctx, "A"):
+        #     del ctx.A
+        # if hasattr(ctx, "zi"):
+        #     del ctx.zi
         return grad_x, grad_A, grad_zi

     @staticmethod
     def jvp(
         ctx: Any, grad_x: torch.Tensor, grad_A: torch.Tensor, grad_zi: torch.Tensor
     ) -> torch.Tensor:
-        A, y, zi = ctx.A, ctx.y, ctx.zi
+        A, zi, y = ctx.saved_tensors
         *_, order = A.shape
 
         fwd_zi = grad_zi if grad_zi is not None else torch.zeros_like(zi)
@@ -244,5 +250,22 @@ def backward(
         fwd_A = -torch.sum(unfolded_y * grad_A, dim=2)
         fwd_x = fwd_x + fwd_A
 
-        del ctx.y, ctx.A, ctx.zi
+        # del ctx.y, ctx.A, ctx.zi
         return LPC.apply(fwd_x, A, fwd_zi)

+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x, A, zi = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                starmap(maybe_expand_bdim_at_front, zip(args, in_dims)),
+            )
+        )
+
+        y = LPC.apply(x, A, zi)
+        return y.reshape(info.batch_size, -1, *y.shape[1:]), 0
