test: gradient using parallel scan
yoyolicoris committed Jul 10, 2024
1 parent 58e5334 commit 7102ac2
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_grad.py
@@ -108,3 +108,35 @@ def test_float64_vs_32_cuda():
assert torch.allclose(y64, y32.double(), atol=1e-6), torch.max(
torch.abs(y64 - y32.double())
)


@pytest.mark.parametrize(
"x_requires_grad",
[True],
)
@pytest.mark.parametrize(
"a_requires_grad",
[True, False],
)
@pytest.mark.parametrize(
"zi_requires_grad",
[True, False],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan(
x_requires_grad: bool,
a_requires_grad: bool,
zi_requires_grad: bool,
):
batch_size = 2
samples = 123
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")

A.requires_grad = a_requires_grad
x.requires_grad = x_requires_grad
zi.requires_grad = zi_requires_grad

assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)
assert gradgradcheck(LPC.apply, (x, A, zi))
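For reference, here is a minimal sketch of the recursion this test exercises, assuming LPC.apply evaluates the time-varying first-order all-pole filter y[t] = x[t] - A[t] * y[t-1] with zi supplying y[-1]; the lpc_reference_order1 helper below is illustrative only and is not part of the commit. gradcheck and gradgradcheck compare the analytic (and, with check_forward_ad=True, forward-mode) gradients of the CUDA parallel-scan path against finite differences on double-precision inputs like the ones constructed above.

import torch

def lpc_reference_order1(x, a, zi):
    # Hypothetical reference loop (assumed semantics of LPC.apply for order 1):
    # y[b, t] = x[b, t] - a[b, t, 0] * y[b, t - 1], with y[b, -1] = zi[b, 0].
    y_prev = zi[:, 0]
    ys = []
    for t in range(x.shape[1]):
        y_t = x[:, t] - a[:, t, 0] * y_prev
        ys.append(y_t)
        y_prev = y_t
    return torch.stack(ys, dim=1)

Drawing A from rand * 2 - 1 keeps every coefficient inside (-1, 1), so a first-order recursion of this form stays stable over the 123-sample sequence and the finite-difference checks remain well behaved.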
