draft: jvp and vmap for recurrence
yoyolicoris committed Sep 9, 2024
1 parent 3a0fe52 commit bdb72be
Showing 1 changed file with 56 additions and 3 deletions.
59 changes: 56 additions & 3 deletions torchlpc/recurrence.py
@@ -2,14 +2,18 @@
import torch.nn.functional as F
from torch.autograd import Function
from numba import cuda
from typing import Tuple, Optional, Any, List

from .parallel_scan import compute_linear_recurrence


class RecurrenceCUDA(Function):
@staticmethod
def forward(
ctx, decay: torch.Tensor, impulse: torch.Tensor, initial_state: torch.Tensor
decay: torch.Tensor,
impulse: torch.Tensor,
initial_state: torch.Tensor,
) -> torch.Tensor:
n_dims, n_steps = decay.shape
out = torch.empty_like(impulse)
@@ -21,11 +25,18 @@ def forward(
n_dims,
n_steps,
)
ctx.save_for_backward(decay, initial_state, out)
return out

@staticmethod
def backward(ctx: torch.Any, grad_out: torch.Tensor) -> torch.Tensor:
def setup_context(ctx: Any, inputs: List[Any], output: Any) -> Any:
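# Saving tensors here (rather than inside forward) keeps the Function usable with
# torch.func transforms; the same tensors are stashed for the backward pass
# (save_for_backward) and for the forward-mode jvp below (save_for_forward).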
decay, _, initial_state = inputs
ctx.save_for_backward(decay, initial_state, output)
ctx.save_for_forward(decay, initial_state, output)

@staticmethod
def backward(
ctx: Any, grad_out: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
decay, initial_state, out = ctx.saved_tensors
grad_decay = grad_impulse = grad_initial_state = None
n_dims, _ = decay.shape
@@ -57,3 +68,45 @@ def backward(ctx: torch.Any, grad_out: torch.Tensor) -> torch.Tensor:
grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)

return grad_decay, grad_impulse, grad_initial_state

@staticmethod
def jvp(
ctx: Any,
grad_decay: torch.Tensor,
grad_impulse: torch.Tensor,
grad_initial_state: torch.Tensor,
) -> torch.Tensor:
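# Forward-mode rule: because the recurrence is linear, the output tangent obeys
# the same recurrence, driven by the impulse tangent plus a correction term built
# from the decay tangent and the previous outputs, so it can be evaluated with
# one more RecurrenceCUDA pass at the end of this method.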
decay, initial_state, out = ctx.saved_tensors

fwd_initial_state = (
grad_initial_state
if grad_initial_state is not None
else torch.zeros_like(initial_state)
)
fwd_impulse = (
grad_impulse if grad_impulse is not None else torch.zeros_like(out)
)

if grad_decay is not None:
concat_out = torch.cat([initial_state.unsqueeze(1), out[:, :-1]], dim=1)
fwd_decay = -concat_out * grad_decay
fwd_impulse = fwd_impulse + fwd_decay

return RecurrenceCUDA.apply(decay, fwd_impulse, fwd_initial_state)

@staticmethod
def vmap(info, in_dims, *args):
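# Batching rule for torch.vmap: broadcast or move each input's batch dimension
# to the front, fold it into the existing per-channel axis so one flat recurrence
# call handles the whole batch, then unflatten the output and report out_dim 0.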
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)

decay, impulse, initial_state = tuple(
map(
lambda x: x.reshape(-1, *x.shape[2:]),
map(maybe_expand_bdim_at_front, args, in_dims),
)
)

out = RecurrenceCUDA.apply(decay, impulse, initial_state)
return out.reshape(info.batch_size, -1, *out.shape[1:]), 0
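A minimal usage sketch (not part of this commit): assuming RecurrenceCUDA implements a per-channel linear recurrence over tensors of shape (n_dims, n_steps) on a CUDA device, and that torch.func.jvp / torch.vmap dispatch to the jvp and vmap rules added above, forward-mode AD and batching might be exercised roughly as follows. The sizes below are illustrative assumptions, not something this diff establishes.

import torch
from torchlpc.recurrence import RecurrenceCUDA

# Hypothetical sizes: 4 channels, 16 time steps (for illustration only).
decay = torch.rand(4, 16, device="cuda") * 0.9
impulse = torch.randn(4, 16, device="cuda")
zi = torch.randn(4, device="cuda")

# Forward-mode AD: torch.func.jvp should route these tangents through the jvp rule above.
tangents = (torch.zeros_like(decay), torch.randn_like(impulse), torch.zeros_like(zi))
out, out_tangent = torch.func.jvp(RecurrenceCUDA.apply, (decay, impulse, zi), tangents)

# Batching: the custom vmap rule folds the leading batch dimension into the channel axis.
batched_impulse = torch.randn(8, 4, 16, device="cuda")
batched_out = torch.vmap(RecurrenceCUDA.apply, in_dims=(None, 0, None))(
    decay, batched_impulse, zi
)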
