Skip to content

Commit

Permalink
feat: optimize voxel indexing in preprocess_kernel.cu
Browse files Browse the repository at this point in the history
Signed-off-by: Taekjin LEE <[email protected]>
  • Loading branch information
technolojin committed Dec 11, 2024
1 parent 01944ef commit 8069241
Showing 1 changed file with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,12 @@ __global__ void generateBaseFeatures_kernel(
unsigned int * mask, float * voxels, int grid_y_size, int grid_x_size, int max_voxel_size,
unsigned int * pillar_num, float * voxel_features, float * voxel_num, int * voxel_idxs)
{
unsigned int voxel_idx = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int voxel_idy = blockIdx.y * blockDim.y + threadIdx.y;
// exchange x and y to process in a row-major order
// flip x axis direction to process front to back
unsigned int voxel_idx_inverted = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int voxel_idy = blockIdx.x * blockDim.x + threadIdx.x;
if (voxel_idx_inverted >= grid_x_size || voxel_idy >= grid_y_size) return;
unsigned int voxel_idx = grid_x_size - 1 - voxel_idx_inverted;

if (voxel_idx >= grid_x_size || voxel_idy >= grid_y_size) return;

Expand Down Expand Up @@ -220,9 +224,10 @@ cudaError_t generateBaseFeatures_launch(
unsigned int * pillar_num, float * voxel_features, float * voxel_num, int * voxel_idxs,
cudaStream_t stream)
{
// exchange x and y to process in a row-major order
dim3 threads = {32, 32};
dim3 blocks = {
(grid_x_size + threads.x - 1) / threads.x, (grid_y_size + threads.y - 1) / threads.y};
(grid_y_size + threads.x - 1) / threads.x, (grid_x_size + threads.y - 1) / threads.y};

generateBaseFeatures_kernel<<<blocks, threads, 0, stream>>>(
mask, voxels, grid_y_size, grid_x_size, max_voxel_size, pillar_num, voxel_features, voxel_num,
Expand Down

0 comments on commit 8069241

Please sign in to comment.