Skip to content

Commit

Permalink
Rebase latest updated kernel to align on kernel specific max work group size
Browse files
Browse the repository at this point in the history
  • Loading branch information
fengyuan14 committed Jul 7, 2024
1 parent fa94c71 commit c127887
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/ATen/native/xpu/sycl/BucketizationKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,23 +125,6 @@ void searchsorted_template(
const bool& right,
const Tensor& sorter) {
int64_t numel_in = input.numel();
int64_t rng, grng, tile_size;
tile_size = syclMaxWorkGroupSize();
rng = numel_in;
if (rng == 0) {
rng = static_cast<int64_t>(1);
}

grng = rng;
if (tile_size > grng) {
tile_size = grng;
} else if (grng > tile_size) {
int64_t xMode = static_cast<int64_t>(grng % tile_size);
if (xMode != 0) {
grng += static_cast<int64_t>(tile_size - xMode);
}
}

bool is_scalar_input = input.dim() == 0 && numel_in == 1;
// inner most dim size of input and boundaries
int64_t idim_in = is_scalar_input ? 1 : input.sizes().back();
Expand All @@ -167,6 +150,23 @@ void searchsorted_template(
data_bd_data,
data_out_data);

int64_t rng, grng, tile_size;
tile_size = syclMaxWorkGroupSize(kfn);
rng = numel_in;
if (rng == 0) {
rng = static_cast<int64_t>(1);
}

grng = rng;
if (tile_size > grng) {
tile_size = grng;
} else if (grng > tile_size) {
int64_t xMode = static_cast<int64_t>(grng % tile_size);
if (xMode != 0) {
grng += static_cast<int64_t>(tile_size - xMode);
}
}

sycl_kernel_submit(grng, tile_size, getCurrentSYCLQueue(), kfn);
}

Expand Down Expand Up @@ -243,4 +243,4 @@ void searchsorted_kernel(
result.copy_(out);
}
}
} // namespace at::native::xpu
} // namespace at::native::xpu

0 comments on commit c127887

Please sign in to comment.