Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Apr 17, 2024
1 parent 366d619 commit 924800e
Showing 1 changed file with 6 additions and 94 deletions.
100 changes: 6 additions & 94 deletions torchlpc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,88 +42,6 @@ def lpc_cuda_kernel_{t}(padded_y, A, B, T, order) -> None:
padded_y[b, t + order] = sm[circular_idx]"""
)

# lpc_cuda_kernel_float32 = cuda.jit(
# eval(cuda_kernel_string_1 + "float32" + cuda_kernel_string_2)
# )
# lpc_cuda_kernel_float64 = cuda.jit(
# eval(cuda_kernel_string_1 + "float64" + cuda_kernel_string_2)
# )
# lpc_cuda_kernel_complex64 = cuda.jit(
# eval(cuda_kernel_string_1 + "complex64" + cuda_kernel_string_2)
# )
# lpc_cuda_kernel_complex128 = cuda.jit(
# eval(cuda_kernel_string_1 + "complex128" + cuda_kernel_string_2)
# )


# @cuda.jit
# def lpc_cuda_kernel_float64(padded_y, A, B, T, order) -> None:
# sm = cuda.shared.array(shape=(1024,), dtype=float64)

# batch_idx = cuda.blockIdx.x
# tid = cuda.threadIdx.x

# i = tid
# b = batch_idx

# if b >= B or i >= order:
# return

# circular_idx = 0
# sm[i] = padded_y[b, i]

# for t in range(T):
# circular_idx = t % order
# if i == (order - 1):
# sm[circular_idx] *= -A[b, t, i]
# cuda.syncthreads()

# if i == (order - 1):
# v = padded_y[b, t + order]
# elif i > circular_idx - 1:
# v = -A[b, t, i] * sm[circular_idx - i - 1 + order]
# else:
# v = -A[b, t, i] * sm[circular_idx - i - 1]
# cuda.atomic.add(sm, circular_idx, v)
# cuda.syncthreads()

# if i == (order - 1):
# padded_y[b, t + order] = sm[circular_idx]


# @cuda.jit
# def lpc_cuda_kernel_float32(padded_y, A, B, T, order) -> None:
# sm = cuda.shared.array(shape=(1024,), dtype=float32)

# batch_idx = cuda.blockIdx.x
# tid = cuda.threadIdx.x
# i = tid
# b = batch_idx

# if b >= B or i >= order:
# return

# circular_idx = 0
# sm[i] = padded_y[b, i]

# for t in range(T):
# circular_idx = t % order
# if i == (order - 1):
# sm[circular_idx] *= -A[b, t, i]
# cuda.syncthreads()

# if i == (order - 1):
# v = padded_y[b, t + order]
# elif i > circular_idx - 1:
# v = -A[b, t, i] * sm[circular_idx - i - 1 + order]
# else:
# v = -A[b, t, i] * sm[circular_idx - i - 1]
# cuda.atomic.add(sm, circular_idx, v)
# cuda.syncthreads()

# if i == (order - 1):
# padded_y[b, t + order] = sm[circular_idx]


def lpc_cuda(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
B, T, order = A.shape
Expand All @@ -136,24 +54,18 @@ def lpc_cuda(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor
blocks_per_grid = B

if x.dtype == torch.float64:
lpc_cuda_kernel_float64[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
runner = lpc_cuda_kernel_float64[blocks_per_grid, threads_per_block]
elif x.dtype == torch.float32:
lpc_cuda_kernel_float32[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
runner = lpc_cuda_kernel_float32[blocks_per_grid, threads_per_block]
elif x.dtype == torch.complex64:
lpc_cuda_kernel_complex64[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
runner = lpc_cuda_kernel_complex64[blocks_per_grid, threads_per_block]
elif x.dtype == torch.complex128:
lpc_cuda_kernel_complex128[blocks_per_grid, threads_per_block](
cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order
)
runner = lpc_cuda_kernel_complex128[blocks_per_grid, threads_per_block]
else:
raise NotImplementedError

runner(cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order)

return padded_y[:, order:].contiguous()


Expand Down

0 comments on commit 924800e

Please sign in to comment.