From a9a8c588598503ad6f6f8fba3cbb761e23c58dbd Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Tue, 10 Sep 2024 01:22:07 +0800
Subject: [PATCH] test: vmap for parallel scan

---
 tests/test_vmap.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tests/test_vmap.py b/tests/test_vmap.py
index a5f2f6d..75c4496 100644
--- a/tests/test_vmap.py
+++ b/tests/test_vmap.py
@@ -45,3 +45,29 @@ def func(x, A, zi):
     loss.backward()
     for jac, arg in zip(jacs, args):
         assert torch.allclose(jac, arg.grad)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_cuda_parallel_scan_vmap():
+    batch_size = 3
+    samples = 255
+    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
+    A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
+    zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")
+    y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
+
+    A.requires_grad = True
+    x.requires_grad = True
+    zi.requires_grad = True
+
+    args = (x, A, zi)
+
+    def func(x, A, zi):
+        return F.mse_loss(LPC.apply(x, A, zi), y)
+
+    jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)
+
+    loss = func(*args)
+    loss.backward()
+    for jac, arg in zip(jacs, args):
+        assert torch.allclose(jac, arg.grad)
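
Note (not part of the patch): the new test cross-checks forward-mode gradients from torch.func.jacfwd, which is built on vmap over forward-mode JVPs, against reverse-mode gradients from loss.backward(). Below is a minimal CPU sketch of the same checking pattern; it uses a stand-in differentiable scan (torch.cumsum) instead of LPC.apply, since LPC and its CUDA parallel-scan path are specific to this repo.

import torch
import torch.nn.functional as F
from torch.func import jacfwd

torch.manual_seed(0)
batch_size, samples = 3, 255

# Leaf tensors so .grad is populated by the backward pass below.
x = torch.randn(batch_size, samples, dtype=torch.double).requires_grad_()
a = (torch.rand(batch_size, 1, dtype=torch.double) * 2 - 1).requires_grad_()
y = torch.randn(batch_size, samples, dtype=torch.double)

args = (x, a)

def func(x, a):
    # Stand-in for LPC.apply: cumsum is a simple prefix-sum scan that keeps
    # the example self-contained and differentiable.
    return F.mse_loss(torch.cumsum(x, dim=-1) * a, y)

# Forward-mode gradients w.r.t. every argument (the output is a scalar, so
# each returned "Jacobian" has the shape of the corresponding argument).
jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

# Reverse-mode gradients from an ordinary backward pass.
loss = func(*args)
loss.backward()

for jac, arg in zip(jacs, args):
    assert torch.allclose(jac, arg.grad)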