feat: forward mode autograd
* draft: jvp for fwd mode

* fix: make sure layout is contiguous

* test: fwd mode gradients
yoyolicoris authored Dec 31, 2023
1 parent 62f2243 commit e7ceba1
Showing 2 changed files with 31 additions and 3 deletions.
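A note on what the new forward-mode support looks like from the caller's side: with the jvp added below, LPC.apply can be driven through torch.autograd.forward_ad dual tensors. The sketch below is illustrative only — the import path mirrors the tests, and the shapes (batch, time, order) and coefficient scaling are assumptions, not fixtures from this repository.

    import torch
    import torch.autograd.forward_ad as fwAD
    from torchlpc.core import LPC  # assumed import path, matching the tests

    B, T, order = 2, 64, 2
    x = torch.randn(B, T, dtype=torch.double)                # input signal
    A = 0.1 * torch.randn(B, T, order, dtype=torch.double)   # time-varying all-pole coefficients
    zi = torch.randn(B, order, dtype=torch.double)           # initial conditions

    # Tangents: the directions along which directional derivatives are requested.
    tx, tA, tzi = torch.randn_like(x), torch.randn_like(A), torch.randn_like(zi)

    with fwAD.dual_level():
        y_dual = LPC.apply(
            fwAD.make_dual(x, tx),
            fwAD.make_dual(A, tA),
            fwAD.make_dual(zi, tzi),
        )
        y, y_tangent = fwAD.unpack_dual(y_dual)  # y_tangent is produced by the new jvp()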
4 changes: 2 additions & 2 deletions tests/test_grad.py
@@ -49,7 +49,7 @@ def test_low_order_cpu(
x.requires_grad = x_requires_grad
zi.requires_grad = zi_requires_grad

-    assert gradcheck(LPC.apply, (x, A, zi))
+    assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)
assert gradgradcheck(LPC.apply, (x, A, zi))


@@ -86,7 +86,7 @@ def test_low_order_cuda(
x.requires_grad = x_requires_grad
zi.requires_grad = zi_requires_grad

-    assert gradcheck(LPC.apply, (x, A, zi))
+    assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)
assert gradgradcheck(LPC.apply, (x, A, zi))
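With check_forward_ad=True, gradcheck also compares the Function's jvp() against numerically estimated Jacobian-vector products, on top of the usual reverse-mode checks. A standalone sketch of the same check, with illustrative shapes rather than the fixtures used in these tests:

    import torch
    from torch.autograd import gradcheck
    from torchlpc.core import LPC  # assumed import path

    # Double precision keeps the finite-difference comparison tight.
    x = torch.randn(2, 32, dtype=torch.double, requires_grad=True)
    A = (0.1 * torch.randn(2, 32, 2, dtype=torch.double)).requires_grad_()
    zi = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

    assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)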


30 changes: 29 additions & 1 deletion torchlpc/core.py
@@ -96,7 +96,7 @@ def lpc_cuda(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor
else:
raise NotImplementedError

-    return padded_y[:, order:]
+    return padded_y[:, order:].contiguous()
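Background on the .contiguous() change above: slicing off the first `order` columns of padded_y returns a view whose rows are no longer adjacent in memory when the batch size is greater than one, and downstream consumers may assume a contiguous layout (the commit message flags exactly this). A quick standalone illustration, not taken from the repository:

    import torch

    padded_y = torch.zeros(4, 2 + 100)                       # (batch, order + T)
    print(padded_y[:, 2:].is_contiguous())                   # False: the slice is a strided view
    print(padded_y[:, 2:].contiguous().is_contiguous())      # True: an explicit contiguous copy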


@njit(parallel=True)
@@ -132,6 +132,11 @@ def forward(
)
y = torch.from_numpy(y).to(x.device, x.dtype)
ctx.save_for_backward(A, zi, y)

# stash the tensors jvp() needs as plain ctx attributes (read back in jvp below)
ctx.y = y
ctx.A = A
ctx.zi = zi
return y

@staticmethod
@@ -176,3 +181,26 @@ def backward(
grad_A = unfolded_y * -flipped_grad_x.flip(1).unsqueeze(2)

return grad_x, grad_A, grad_zi

@staticmethod
def jvp(
ctx: Any, grad_x: torch.Tensor, grad_A: torch.Tensor, grad_zi: torch.Tensor
) -> torch.Tensor:
A, y, zi = ctx.A, ctx.y, ctx.zi
*_, order = A.shape

grad_y = 0

if grad_x is not None:
grad_y_from_x_zi = LPC.apply(grad_x, A, grad_zi)
grad_y = grad_y_from_x_zi

if grad_A is not None:
unfolded_y = (
torch.cat([zi.flip(1), y[:, :-1]], dim=1).unfold(1, order, 1).flip(2)
)
grad_A_input = -torch.sum(unfolded_y * grad_A, dim=2)
grad_y_from_A = LPC.apply(grad_A_input, A, torch.zeros_like(zi))
grad_y = grad_y + grad_y_from_A

return grad_y
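For context on why the jvp takes this form: torchlpc evaluates (assuming the project's all-pole convention) the time-varying recurrence

    y_t = x_t - \sum_{k=1}^{p} a_{t,k} \, y_{t-k},    with p = order,

and differentiating along a perturbation direction gives

    \dot{y}_t = \Big( \dot{x}_t - \sum_{k=1}^{p} \dot{a}_{t,k} \, y_{t-k} \Big) - \sum_{k=1}^{p} a_{t,k} \, \dot{y}_{t-k},

so the output tangent obeys the same recurrence, driven by the bracketed term and started from the initial-state tangent. That is what the code above computes: the first LPC.apply call handles the \dot{x} and \dot{z} contributions, the second handles the -\sum_k \dot{a}_{t,k} y_{t-k} term built from unfolded_y, and linearity of the recurrence lets the two passes be summed.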
