From bdb72bee854f1dc0110a70aa6299165fcd9e67ed Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Mon, 9 Sep 2024 18:02:19 +0800
Subject: [PATCH] draft: jvp and vmap for recurrence

---
 torchlpc/recurrence.py | 59 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 56 insertions(+), 3 deletions(-)

diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py
index 45bd4c6..cd9292e 100644
--- a/torchlpc/recurrence.py
+++ b/torchlpc/recurrence.py
@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 from torch.autograd import Function
 from numba import cuda
+from typing import Tuple, Optional, Any, List
 
 from .parallel_scan import compute_linear_recurrence
 
@@ -9,7 +10,10 @@
 class RecurrenceCUDA(Function):
     @staticmethod
     def forward(
-        ctx, decay: torch.Tensor, impulse: torch.Tensor, initial_state: torch.Tensor
+        ctx: Any,
+        decay: torch.Tensor,
+        impulse: torch.Tensor,
+        initial_state: torch.Tensor,
     ) -> torch.Tensor:
         n_dims, n_steps = decay.shape
         out = torch.empty_like(impulse)
@@ -21,11 +25,18 @@ def forward(
             n_dims,
             n_steps,
         )
-        ctx.save_for_backward(decay, initial_state, out)
         return out
 
     @staticmethod
-    def backward(ctx: torch.Any, grad_out: torch.Tensor) -> torch.Tensor:
+    def setup_context(ctx: Any, inputs: List[Any], output: Any) -> Any:
+        decay, _, initial_state = inputs
+        ctx.save_for_backward(decay, initial_state, output)
+        ctx.save_for_forward(decay, initial_state, output)
+
+    @staticmethod
+    def backward(
+        ctx: Any, grad_out: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
         decay, initial_state, out = ctx.saved_tensors
         grad_decay = grad_impulse = grad_initial_state = None
         n_dims, _ = decay.shape
@@ -57,3 +68,45 @@ def backward(ctx: torch.Any, grad_out: torch.Tensor) -> torch.Tensor:
             grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)
 
         return grad_decay, grad_impulse, grad_initial_state
+
+    @staticmethod
+    def jvp(
+        ctx: Any,
+        grad_decay: torch.Tensor,
+        grad_impulse: torch.Tensor,
+        grad_initial_state: torch.Tensor,
+    ) -> torch.Tensor:
+        decay, initial_state, out = ctx.saved_tensors
+
+        fwd_initial_state = (
+            grad_initial_state
+            if grad_initial_state is not None
+            else torch.zeros_like(initial_state)
+        )
+        fwd_impulse = (
+            grad_impulse if grad_impulse is not None else torch.zeros_like(out)
+        )
+
+        if grad_decay is not None:
+            concat_out = torch.cat([initial_state.unsqueeze(1), out[:, :-1]], dim=1)
+            fwd_decay = concat_out * grad_decay
+            fwd_impulse = fwd_impulse + fwd_decay
+
+        return RecurrenceCUDA.apply(decay, fwd_impulse, fwd_initial_state)
+
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        def maybe_expand_bdim_at_front(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        decay, impulse, initial_state = tuple(
+            map(
+                lambda x: x.reshape(-1, *x.shape[2:]),
+                map(maybe_expand_bdim_at_front, args, in_dims),
+            )
+        )
+
+        out = RecurrenceCUDA.apply(decay, impulse, initial_state)
+        return out.reshape(info.batch_size, -1, *out.shape[1:]), 0
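
Note (commentary, not part of the patch): a rough smoke-test sketch of how the new
jvp/vmap hooks are meant to be driven, assuming the patch is applied on top of
torchlpc, a CUDA device is available for the Numba kernel, and the forward /
setup_context split ends up wired the way PyTorch's torch.func extension of
autograd.Function expects. Shapes follow the (n_dims, n_steps) layout used by
compute_linear_recurrence; the variable names below (n_dims, n_steps, zi, batch)
are illustrative only.

    import torch
    from torchlpc.recurrence import RecurrenceCUDA

    n_dims, n_steps = 4, 64
    decay = torch.rand(n_dims, n_steps, device="cuda") * 0.9  # keep |decay| < 1 for stability
    impulse = torch.randn(n_dims, n_steps, device="cuda")
    zi = torch.randn(n_dims, device="cuda")                   # initial state y[-1]

    # Forward-mode AD: torch.func.jvp routes tangents through RecurrenceCUDA.jvp,
    # which runs the same recurrence on the impulse tangent plus y[t-1] * d_decay[t].
    out, tangent = torch.func.jvp(
        RecurrenceCUDA.apply,
        (decay, impulse, zi),
        (torch.randn_like(decay), torch.randn_like(impulse), torch.randn_like(zi)),
    )

    # Batching: torch.vmap dispatches to the custom vmap staticmethod, which folds
    # the batch dimension into n_dims, runs one big recurrence, and reshapes back.
    batch = 8
    out_b = torch.vmap(RecurrenceCUDA.apply)(
        decay.expand(batch, -1, -1),
        impulse.expand(batch, -1, -1),
        zi.expand(batch, -1),
    )
    assert out_b.shape == (batch, n_dims, n_steps)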