Skip to content

Commit

Permalink
test: vmap for parallel scan
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Sep 9, 2024
1 parent 7113cf8 commit a9a8c58
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,29 @@ def func(x, A, zi):
loss.backward()
for jac, arg in zip(jacs, args):
assert torch.allclose(jac, arg.grad)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan_vmap():
batch_size = 3
samples = 255
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")
y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")

A.requires_grad = True
x.requires_grad = True
zi.requires_grad = True

args = (x, A, zi)

def func(x, A, zi):
return F.mse_loss(LPC.apply(x, A, zi), y)

jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

loss = func(*args)
loss.backward()
for jac, arg in zip(jacs, args):
assert torch.allclose(jac, arg.grad)

0 comments on commit a9a8c58

Please sign in to comment.