Skip to content

Commit

Permalink
[GPU] Fix a bug of resample_onnx kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
jade-cho committed Nov 28, 2024
1 parent 79493c2 commit 6aa8ba4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ KERNEL (resample_onnx)(__global INPUT0_TYPE* input,
OUT_VEC_TYPE out = TO_OUT_VEC_TYPE(ACTIVATION(res, ACTIVATION_PARAMS));
#endif // #if HAS_FUSED_OPS

#if defined (NOT_ALIGNED_TO_FEATURE)
output[OUTPUT_GET_INDEX(b, feature_num, z, y, (x + out_x))] = out;
#else
WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_block, z, y, (x + out_x)), out);
#endif // #if defined(NOT_ALIGNED_TO_FEATURE)
}
#else // #if defined (THREE_SPATIAL_RESAMPLE)

Expand Down Expand Up @@ -220,11 +224,19 @@ KERNEL (resample_onnx)(__global INPUT0_TYPE* input,
OUT_VEC_TYPE out = TO_OUT_VEC_TYPE(ACTIVATION(res, ACTIVATION_PARAMS));
#endif

#if OUTPUT_DIMS == 5
WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_block, z, y, (x + out_x)), out);
#if defined (NOT_ALIGNED_TO_FEATURE)
#if OUTPUT_DIMS == 5
output[OUTPUT_GET_INDEX(b, feature_num, z, y, (x + out_x))] = out;
#else
output[OUTPUT_GET_INDEX(b, feature_num, y, (x + out_x))] = out;
#endif
#else
#if OUTPUT_DIMS == 5
WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_block, z, y, (x + out_x)), out);
#else
WRITE_FUNC(output, OUTPUT_GET_INDEX(b, feature_block, y, (x + out_x)), out);
#endif
#endif
#endif // #if defined (NOT_ALIGNED_TO_FEATURE)
}
#endif // #if defined (THREE_SPATIAL_RESAMPLE)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ JitConstants ResampleKernelOnnx::GetJitConstants(const resample_params& params)

jit.AddConstant(MakeJitConstant("VEC_SIZE", vec_size));

if (params.outputs[0].Feature().v % sub_group_size != 0) {
jit.AddConstant(MakeJitConstant("NOT_ALIGNED_TO_FEATURE", 1));
}

if (!params.fused_ops.empty()) {
std::vector<std::string> idx_order;
if (params.inputs[0].Dimentions() == 5)
Expand Down

0 comments on commit 6aa8ba4

Please sign in to comment.