Skip to content

Commit

Permalink
fix: explicitly free ctx tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris authored Apr 25, 2024
1 parent 52784cc commit 7da81c9
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchlpc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def backward(
unfolded_y = padded_y.unfold(1, order, 1).flip(2)
grad_A = unfolded_y * -flipped_grad_x.flip(1).unsqueeze(2)

del ctx.y, ctx.A, ctx.zi
return grad_x, grad_A, grad_zi

@staticmethod
Expand All @@ -199,4 +200,5 @@ def jvp(
fwd_A = -torch.sum(unfolded_y * grad_A, dim=2)
fwd_x = fwd_x + fwd_A

del ctx.y, ctx.A, ctx.zi
return LPC.apply(fwd_x, A, fwd_zi)

0 comments on commit 7da81c9

Please sign in to comment.