Skip to content

Commit

Permalink
Fix batch norm vectorize path accuracy issue by enforcing shape align…
Browse files Browse the repository at this point in the history
…ment (#1238)

When implementing a kernel for all shapes with vectorized LD/ST, we have
to handle a non-aligned head (base address) and a short tail (tail < vector
size). Before this commit, head handling was missing, which led to
vectorized STs at addresses that were not aligned to the vector size.
This fix enforces shape requirements for the vectorize path: the vectorized
kernel may only be called when the feature dimension is divisible by the
vector size, so the head is always aligned and the tail is always exact.

---------

Co-authored-by: mengfei25 <[email protected]>
  • Loading branch information
xytintel and mengfei25 authored Jan 3, 2025
1 parent 06f6339 commit f634c3c
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
} else {
invstd =
static_cast<stat_accscalar_t>(1) /
device_sqrt(
std::sqrt(
static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
}

Expand All @@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
feature_vec_begin < fs;
feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
auto remaining = fs - feature_vec_begin;
if (remaining < VEC_SIZE) {
for (index_t idx = 0; idx < remaining; ++idx) {
index_t feature = feature_vec_begin + idx;
o[feature] = static_cast<input_scalar_t>(
gamma * (i[feature] - mean) * invstd + beta);
}
} else {
using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
vec_t vec;
using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
vec_t vec;
#pragma unroll
for (int vt = 0; vt < VEC_SIZE; ++vt) {
index_t feature = feature_vec_begin + vt;
vec[vt] = static_cast<input_scalar_t>(
gamma * (i[feature] - mean) * invstd + beta);
}
input_scalar_t* write_ptr = &o[feature_vec_begin];
*(reinterpret_cast<vec_t*>(write_ptr)) = vec;
for (int vt = 0; vt < VEC_SIZE; ++vt) {
index_t feature = feature_vec_begin + vt;
vec[vt] = static_cast<input_scalar_t>(
gamma * (i[feature] - mean) * invstd + beta);
}
input_scalar_t* write_ptr = &o[feature_vec_begin];
*(reinterpret_cast<vec_t*>(write_ptr)) = vec;
}
}
}
Expand Down Expand Up @@ -1459,7 +1450,7 @@ void batch_norm_elemt_template(
auto output_ptr = (char*)output_reshaped.data_ptr();
if (output_reshaped.is_contiguous() &&
memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
sizeof(input_scalar_t) < sizeof(float)) {
sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
4,
input_scalar_t,
Expand Down

0 comments on commit f634c3c

Please sign in to comment.