feat: parallel-scan algorithm for first-order filter #11

Merged
merged 3 commits on Jul 10, 2024

32 changes: 32 additions & 0 deletions tests/test_grad.py
@@ -108,3 +108,35 @@ def test_float64_vs_32_cuda():
    assert torch.allclose(y64, y32.double(), atol=1e-6), torch.max(
        torch.abs(y64 - y32.double())
    )


@pytest.mark.parametrize(
    "x_requires_grad",
    [True],
)
@pytest.mark.parametrize(
    "a_requires_grad",
    [True, False],
)
@pytest.mark.parametrize(
    "zi_requires_grad",
    [True, False],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan(
    x_requires_grad: bool,
    a_requires_grad: bool,
    zi_requires_grad: bool,
):
    batch_size = 2
    samples = 123
    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
    A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
    zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")

    A.requires_grad = a_requires_grad
    x.requires_grad = x_requires_grad
    zi.requires_grad = zi_requires_grad

    assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)
    assert gradgradcheck(LPC.apply, (x, A, zi))
5 changes: 5 additions & 0 deletions torchlpc/__init__.py
@@ -2,6 +2,8 @@
from typing import Optional

from .core import LPC
from .parallel_scan import WARPSIZE
from .recurrence import RecurrenceCUDA

__all__ = ["sample_wise_lpc"]

@@ -35,4 +37,7 @@ def sample_wise_lpc(
    else:
        assert zi.shape == (B, order)

    if order == 1 and x.is_cuda and B * WARPSIZE < T:
        # First-order filters on long CUDA sequences take the parallel-scan path;
        # note the sign flip from LPC coefficients to recurrence decays.
        return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))

    return LPC.apply(x, a, zi)
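
For context (not part of the diff): a minimal usage sketch of when the new branch is taken. Shapes follow the assertions in sample_wise_lpc above; the sizes and the CUDA device below are illustrative assumptions.

import torch
from torchlpc import sample_wise_lpc

B, T = 2, 4096  # B * WARPSIZE (64) < T, so the parallel-scan path is eligible
x = torch.randn(B, T, device="cuda")
a = torch.rand(B, T, 1, device="cuda") * 2 - 1  # first-order coefficients (order == 1)
zi = torch.randn(B, 1, device="cuda")

y = sample_wise_lpc(x, a, zi)  # dispatches to RecurrenceCUDA instead of LPC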
160 changes: 160 additions & 0 deletions torchlpc/parallel_scan.py
@@ -0,0 +1,160 @@
from numba import cuda

WARPSIZE = 32

# Implementation translated from
# https://github.com/eamartin/parallelizing_linear_rnns/blob/master/linear_recurrent_net/linear_recurrence.cu


@cuda.jit(device=True)
def divide_work(n_jobs, n_workers, worker_idx) -> tuple:
    # Split n_jobs as evenly as possible: the first `doing_cd` workers get
    # ceil(n_jobs / n_workers) jobs each, the rest get floor(n_jobs / n_workers).
    cd = (n_jobs + n_workers - 1) // n_workers
    d, doing_cd = divmod(n_jobs, n_workers)
    if worker_idx < doing_cd:
        x = cd * worker_idx
        y = x + cd
    else:
        x = cd * doing_cd + d * (worker_idx - doing_cd)
        y = x + d
    return x, y


@cuda.jit(device=True)
def compute_warp_start_stop(blockIdx, warp_idx, n_blocks, n_steps):
    block_start, block_stop = divide_work(n_steps, n_blocks, blockIdx)
    block_jobs = block_stop - block_start

    warp_start, warp_stop = divide_work(block_jobs, WARPSIZE, warp_idx)
    warp_start += block_start
    warp_stop += block_start

    return warp_start, warp_stop


@cuda.jit
def reduction_kernel(
    decay, impulses, initial_state, decay_storage, h_storage, n_dims, n_steps
):
    warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)

    storage_offset = cuda.blockIdx.x * (WARPSIZE + 1)

    warp_start, warp_stop = compute_warp_start_stop(
        cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
    )

    # reduce within warp
    for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
        cum_decay = 1.0
        h = 0.0
        if (cuda.blockIdx.x == 0) and (lane == 0):
            h = initial_state[i]

        for t in range(warp_start, warp_stop):
            cum_decay *= decay[i, t]
            h = decay[i, t] * h + impulses[i, t]

        decay_storage[lane + storage_offset, i] = cum_decay
        h_storage[lane + storage_offset, i] = h

    cuda.syncthreads()

    # reduce within block
    for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
        cum_decay = 1.0
        h = 0.0
        for t in range(storage_offset, storage_offset + WARPSIZE):
            cum_decay *= decay_storage[t, i]
            h = decay_storage[t, i] * h + h_storage[t, i]

        decay_storage[WARPSIZE + storage_offset, i] = cum_decay
        h_storage[WARPSIZE + storage_offset, i] = h


@cuda.jit
def block_scan_kernel(decay_storage, h_storage, n_dims, n_blocks):
    for i in range(
        cuda.grid(1),
        n_dims,
        cuda.gridsize(1),
    ):
        for t in range(1, n_blocks):
            cur_idx = t * (WARPSIZE + 1) + WARPSIZE
            prev_idx = (t - 1) * (WARPSIZE + 1) + WARPSIZE
            h_storage[cur_idx, i] += h_storage[prev_idx, i] * decay_storage[cur_idx, i]
            decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]


@cuda.jit
def warp_scan_kernel(
    decay, impulses, initial_state, out, decay_storage, h_storage, n_dims, n_steps
):
    warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)

    # scan the per-warp partials within this block, chaining from the previous
    # block's running total produced by block_scan_kernel
    for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
        offset = cuda.blockIdx.x * (WARPSIZE + 1)
        for cur_idx in range(offset, offset + WARPSIZE):
            if cur_idx == 0:
                continue
            prev_idx = cur_idx - 1
            h_storage[cur_idx, i] = (
                h_storage[prev_idx, i] * decay_storage[cur_idx, i]
                + h_storage[cur_idx, i]
            )
            decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]

    cuda.syncthreads()

    warp_start, warp_stop = compute_warp_start_stop(
        cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
    )

    # scan within warp
    for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
        if (cuda.blockIdx.x == 0) and (lane == 0):
            h = initial_state[i]
        else:
            h = h_storage[lane - 1 + cuda.blockIdx.x * (WARPSIZE + 1), i]

        for t in range(warp_start, warp_stop):
            h = decay[i, t] * h + impulses[i, t]
            out[i, t] = h


def compute_linear_recurrence(
    decays, impulses, init_states, out, n_dims: int, n_steps: int
):
    n_blocks = min((n_steps + WARPSIZE - 1) // WARPSIZE, 128)

    reduction_mem_shape = (n_blocks * (WARPSIZE + 1), n_dims)
    decay_storage = cuda.device_array(reduction_mem_shape, dtype=decays.dtype)
    h_storage = cuda.device_array(reduction_mem_shape, dtype=impulses.dtype)

    # phase 1: each warp reduces its chunk of steps; each block then reduces its
    # warps' partial decays/states into one per-block total
    reduction_kernel[n_blocks, 512](
        decays, impulses, init_states, decay_storage, h_storage, n_dims, n_steps
    )

    # phase 2: sequentially combine the per-block totals into running prefixes
    block_scan_kernel[n_blocks, 512](decay_storage, h_storage, n_dims, n_blocks)

    # phase 3: propagate the carried-in state through each warp's chunk and
    # write the final outputs
    warp_scan_kernel[n_blocks, 512](
        decays, impulses, init_states, out, decay_storage, h_storage, n_dims, n_steps
    )


if __name__ == "__main__":
    import numpy as np

    n_dims = 16
    n_steps = 20480
    decays = np.full((n_dims, n_steps), 0.9, dtype=np.float32)
    impulses = np.full((n_dims, n_steps), 0.0, dtype=np.float32)
    impulses[:, 0] = 1.0
    init_states = np.full(n_dims, 0.0, dtype=np.float32)

    decays = cuda.to_device(decays)
    impulses = cuda.to_device(impulses)
    init_states = cuda.to_device(init_states)
    out = cuda.device_array((n_dims, n_steps), dtype=np.float32)

    compute_linear_recurrence(decays, impulses, init_states, out, n_dims, n_steps)

    print(out.copy_to_host())
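
A sanity check one could add (not part of the diff): a sequential NumPy reference of the recurrence the kernels evaluate, out[i, t] = decays[i, t] * out[i, t - 1] + impulses[i, t] with out[i, -1] = init_states[i]. Feeding it host copies of the demo inputs above should match out.copy_to_host() to float32 precision.

import numpy as np


def linear_recurrence_ref(decays, impulses, init_states):
    # Naive sequential evaluation of the same first-order recurrence.
    out = np.empty_like(impulses)
    h = init_states.astype(impulses.dtype)
    for t in range(impulses.shape[1]):
        h = decays[:, t] * h + impulses[:, t]
        out[:, t] = h
    return out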
59 changes: 59 additions & 0 deletions torchlpc/recurrence.py
@@ -0,0 +1,59 @@
from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.autograd import Function
from numba import cuda

from .parallel_scan import compute_linear_recurrence


class RecurrenceCUDA(Function):
    @staticmethod
    def forward(
        ctx, decay: torch.Tensor, impulse: torch.Tensor, initial_state: torch.Tensor
    ) -> torch.Tensor:
        n_dims, n_steps = decay.shape
        out = torch.empty_like(impulse)
        compute_linear_recurrence(
            cuda.as_cuda_array(decay.detach()),
            cuda.as_cuda_array(impulse.detach()),
            cuda.as_cuda_array(initial_state.detach()),
            cuda.as_cuda_array(out),
            n_dims,
            n_steps,
        )
        ctx.save_for_backward(decay, initial_state, out)
        return out

    @staticmethod
    def backward(
        ctx: Any, grad_out: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        decay, initial_state, out = ctx.saved_tensors
        grad_decay = grad_impulse = grad_initial_state = None
        n_dims, _ = decay.shape

        # The adjoint of the recurrence is the same recurrence run backwards in
        # time on the conjugated decays, shifted by one step. Prepending a zero
        # to grad_out adds one extra step whose result is the initial-state grad.
        padded_decay = F.pad(decay.unsqueeze(1), (0, 1)).squeeze(1)
        if ctx.needs_input_grad[2]:
            padded_grad_out = F.pad(grad_out.unsqueeze(1), (1, 0)).squeeze(1)
        else:
            padded_grad_out = grad_out
            padded_decay = padded_decay[:, 1:]

        init = padded_grad_out.new_zeros(n_dims)
        flipped_grad_impulse = RecurrenceCUDA.apply(
            padded_decay.flip(1).conj_physical(),
            padded_grad_out.flip(1),
            init,
        )

        if ctx.needs_input_grad[2]:
            grad_initial_state = flipped_grad_impulse[:, -1]
            flipped_grad_impulse = flipped_grad_impulse[:, :-1]

        if ctx.needs_input_grad[1]:
            grad_impulse = flipped_grad_impulse.flip(1)

        if ctx.needs_input_grad[0]:
            # d(out[t]) / d(decay[t]) = h[t - 1], with h[-1] = initial_state
            valid_out = out[:, :-1]
            padded_out = torch.cat([initial_state.unsqueeze(1), valid_out], dim=1)
            grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)

        return grad_decay, grad_impulse, grad_initial_state
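
Not part of the diff: one way to exercise RecurrenceCUDA end to end is to compare it, and its gradients via gradcheck, against a naive sequential PyTorch loop over the same recurrence; the reverse-time backward above should then agree with autograd through the loop. The sizes below are arbitrary assumptions.

import torch
from torch.autograd import gradcheck

from torchlpc.recurrence import RecurrenceCUDA


def recurrence_ref(decay, impulse, initial_state):
    # h[t] = decay[:, t] * h[t - 1] + impulse[:, t], with h[-1] = initial_state
    h, outs = initial_state, []
    for t in range(impulse.shape[1]):
        h = decay[:, t] * h + impulse[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)


if torch.cuda.is_available():
    B, T = 2, 100
    decay = (torch.rand(B, T, dtype=torch.double, device="cuda") * 2 - 1).requires_grad_()
    impulse = torch.randn(B, T, dtype=torch.double, device="cuda", requires_grad=True)
    zi = torch.randn(B, dtype=torch.double, device="cuda", requires_grad=True)

    assert torch.allclose(
        RecurrenceCUDA.apply(decay, impulse, zi), recurrence_ref(decay, impulse, zi)
    )
    assert gradcheck(RecurrenceCUDA.apply, (decay, impulse, zi))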