Update CUDA kernel functions for different data types in torchlpc/core.py
yoyolicoris committed Apr 17, 2024
1 parent 52784cc commit 366d619
Showing 1 changed file with 98 additions and 40 deletions.
138 changes: 98 additions & 40 deletions torchlpc/core.py
@@ -3,13 +3,14 @@
import torch.nn.functional as F
from torch.autograd import Function
from typing import Any, Tuple, Optional
from numba import jit, njit, prange, cuda, float32, float64
from numba import jit, njit, prange, cuda, float32, float64, complex64, complex128


@cuda.jit
def lpc_cuda_kernel_float64(padded_y, A, B, T, order) -> None:
sm = cuda.shared.array(shape=(1024,), dtype=float64)

for t in ["float32", "float64", "complex64", "complex128"]:
exec(
f"""@cuda.jit
def lpc_cuda_kernel_{t}(padded_y, A, B, T, order) -> None:
sm = cuda.shared.array(shape=(1024,), dtype={t})
batch_idx = cuda.blockIdx.x
tid = cuda.threadIdx.x
@@ -38,41 +39,90 @@ def lpc_cuda_kernel_float64(padded_y, A, B, T, order) -> None:
cuda.syncthreads()
if i == (order - 1):
padded_y[b, t + order] = sm[circular_idx]


@cuda.jit
def lpc_cuda_kernel_float32(padded_y, A, B, T, order) -> None:
sm = cuda.shared.array(shape=(1024,), dtype=float32)

batch_idx = cuda.blockIdx.x
tid = cuda.threadIdx.x
i = tid
b = batch_idx

if b >= B or i >= order:
return

circular_idx = 0
sm[i] = padded_y[b, i]

for t in range(T):
circular_idx = t % order
if i == (order - 1):
sm[circular_idx] *= -A[b, t, i]
cuda.syncthreads()

if i == (order - 1):
v = padded_y[b, t + order]
elif i > circular_idx - 1:
v = -A[b, t, i] * sm[circular_idx - i - 1 + order]
else:
v = -A[b, t, i] * sm[circular_idx - i - 1]
cuda.atomic.add(sm, circular_idx, v)
cuda.syncthreads()

if i == (order - 1):
padded_y[b, t + order] = sm[circular_idx]
padded_y[b, t + order] = sm[circular_idx]"""
)

# lpc_cuda_kernel_float32 = cuda.jit(
# eval(cuda_kernel_string_1 + "float32" + cuda_kernel_string_2)
# )
# lpc_cuda_kernel_float64 = cuda.jit(
# eval(cuda_kernel_string_1 + "float64" + cuda_kernel_string_2)
# )
# lpc_cuda_kernel_complex64 = cuda.jit(
# eval(cuda_kernel_string_1 + "complex64" + cuda_kernel_string_2)
# )
# lpc_cuda_kernel_complex128 = cuda.jit(
# eval(cuda_kernel_string_1 + "complex128" + cuda_kernel_string_2)
# )


# @cuda.jit
# def lpc_cuda_kernel_float64(padded_y, A, B, T, order) -> None:
# sm = cuda.shared.array(shape=(1024,), dtype=float64)

# batch_idx = cuda.blockIdx.x
# tid = cuda.threadIdx.x

# i = tid
# b = batch_idx

# if b >= B or i >= order:
# return

# circular_idx = 0
# sm[i] = padded_y[b, i]

# for t in range(T):
# circular_idx = t % order
# if i == (order - 1):
# sm[circular_idx] *= -A[b, t, i]
# cuda.syncthreads()

# if i == (order - 1):
# v = padded_y[b, t + order]
# elif i > circular_idx - 1:
# v = -A[b, t, i] * sm[circular_idx - i - 1 + order]
# else:
# v = -A[b, t, i] * sm[circular_idx - i - 1]
# cuda.atomic.add(sm, circular_idx, v)
# cuda.syncthreads()

# if i == (order - 1):
# padded_y[b, t + order] = sm[circular_idx]


# @cuda.jit
# def lpc_cuda_kernel_float32(padded_y, A, B, T, order) -> None:
# sm = cuda.shared.array(shape=(1024,), dtype=float32)

# batch_idx = cuda.blockIdx.x
# tid = cuda.threadIdx.x
# i = tid
# b = batch_idx

# if b >= B or i >= order:
# return

# circular_idx = 0
# sm[i] = padded_y[b, i]

# for t in range(T):
# circular_idx = t % order
# if i == (order - 1):
# sm[circular_idx] *= -A[b, t, i]
# cuda.syncthreads()

# if i == (order - 1):
# v = padded_y[b, t + order]
# elif i > circular_idx - 1:
# v = -A[b, t, i] * sm[circular_idx - i - 1 + order]
# else:
# v = -A[b, t, i] * sm[circular_idx - i - 1]
# cuda.atomic.add(sm, circular_idx, v)
# cuda.syncthreads()

# if i == (order - 1):
# padded_y[b, t + order] = sm[circular_idx]


def lpc_cuda(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
@@ -93,6 +143,14 @@ def lpc_cuda(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor
lpc_cuda_kernel_float32[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
elif x.dtype == torch.complex64:
lpc_cuda_kernel_complex64[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
elif x.dtype == torch.complex128:
lpc_cuda_kernel_complex128[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
else:
raise NotImplementedError

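The core of this commit replaces the two hand-written float32/float64 kernels with a single f-string template that is exec'd once per dtype, which is how the complex64 and complex128 variants are added without duplicating the kernel body. Below is a minimal sketch of the same pattern applied to a much simpler element-wise kernel; the scale_kernel_<dtype> names and the operation are illustrative only (not part of torchlpc), and running it requires a CUDA-capable GPU.

from numba import cuda, float32, float64, complex64, complex128

# Sketch: compile one @cuda.jit kernel per dtype from a shared f-string
# template. cuda.shared.array needs its dtype at compile time, which is
# presumably why one specialized kernel is generated per dtype instead of
# branching on the dtype inside a single kernel.
for _t in ["float32", "float64", "complex64", "complex128"]:
    exec(
        f"""@cuda.jit
def scale_kernel_{_t}(out, x, gain):
    buf = cuda.shared.array(shape=(256,), dtype={_t})
    tid = cuda.threadIdx.x
    i = cuda.grid(1)
    if i < x.shape[0]:
        buf[tid] = x[i] * gain
        out[i] = buf[tid]
"""
    )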

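The new branches in lpc_cuda extend the existing if/elif dispatch on x.dtype so that complex inputs reach the matching generated kernel. Continuing the sketch above, an equivalent table-driven dispatch could look like the following; scale_on_gpu and _KERNELS are hypothetical names, not torchlpc API.

import torch
from numba import cuda

# Sketch: pick the generated kernel by tensor dtype, mirroring the
# if/elif chain in lpc_cuda. Assumes the scale_kernel_* functions from
# the previous sketch are defined in this module.
_KERNELS = {
    torch.float32: scale_kernel_float32,
    torch.float64: scale_kernel_float64,
    torch.complex64: scale_kernel_complex64,
    torch.complex128: scale_kernel_complex128,
}

def scale_on_gpu(x: torch.Tensor, gain: float) -> torch.Tensor:
    # 1-D CUDA tensors only, to keep the example short.
    assert x.is_cuda and x.ndim == 1
    if x.dtype not in _KERNELS:
        raise NotImplementedError(f"unsupported dtype: {x.dtype}")
    out = torch.empty_like(x)
    threads = 256  # must not exceed the shared-array size in the kernel
    blocks = (x.shape[0] + threads - 1) // threads
    # Tensors are handed to numba through the CUDA array interface, as the
    # commit does with cuda.as_cuda_array(padded_y) and cuda.as_cuda_array(A).
    _KERNELS[x.dtype][blocks, threads](
        cuda.as_cuda_array(out), cuda.as_cuda_array(x), gain
    )
    return out

The commit keeps the explicit if/elif form; in either style the essential point is that the launched kernel's compiled shared-memory dtype matches the tensor dtype.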