inductor(cpu): make variable numbering used in the masked vectorization path align with the scalar path (pytorch#96510)

Fixes pytorch#96484. The CPP reduction vectorization path assumes that the variable numbering used in the vectorized path is aligned with that of the scalar path, but the masked path currently does not meet this requirement, so the generated code fails with a variable-not-defined error. In the "before" code below, the vectorized path declares only `tmp7`/`tmp7_vec`, yet the reduction pragma, the scalar loop, and the final store all reference `tmp8`, which is never declared.

before:

```
{
    {
        {
            #pragma omp declare reduction(min:at::vec::Vectorized<float>:omp_out = at::vec::minimum(omp_out, omp_in)) initializer(omp_priv={{std::numeric_limits<float>::infinity()}})
            float tmp7 = std::numeric_limits<float>::infinity();
            auto tmp7_vec = at::vec::Vectorized<float>(tmp7);
            for(long i0=0; i0<0; i0+=1)
            {
                auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr1 + 16*i0);
                auto tmp0 = at::vec::Vectorized<int>(static_cast<int>(0));
                auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(2));
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = at::vec::Vectorized<float>(0.0);
                {
                    auto tmp4 = at::vec::Vectorized<float>(in_ptr0[0]);
                    tmp3 = decltype(tmp4)::blendv(tmp3, tmp4, to_float_mask(tmp2) != at::vec::Vectorized<float>(0));
                }
                auto tmp6 = tmp3 + tmp5;
                tmp7_vec = at::vec::minimum(tmp7_vec, tmp6);
            }
            #pragma omp simd simdlen(8) reduction(min:tmp8)
            for(long i0=0; i0<2; i0+=1)
            {
                auto tmp6 = in_ptr1[i0];
                auto tmp0 = static_cast<long>(0);
                auto tmp1 = static_cast<long>(2);
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = in_ptr0[0];
                    return tmp4;
                }
                ;
                auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                auto tmp7 = tmp5 + tmp6;
                tmp8 = std::min(tmp8, tmp7);
            }
            tmp7 = std::min(tmp7, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return at::vec::minimum(x, y);}, tmp7_vec));
            out_ptr0[0] = tmp8;
        }
    }
}
```

after:

```
{
    {
        {
            #pragma omp declare reduction(min:at::vec::Vectorized<float>:omp_out = at::vec::minimum(omp_out, omp_in)) initializer(omp_priv={{std::numeric_limits<float>::infinity()}})
            float tmp8 = std::numeric_limits<float>::infinity();
            auto tmp8_vec = at::vec::Vectorized<float>(tmp8);
            for(long i0=0; i0<0; i0+=1)
            {
                auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + 16*i0);
                auto tmp0 = at::vec::Vectorized<int>(static_cast<int>(0));
                auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(2));
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = at::vec::Vectorized<float>(in_ptr0[0]);
                    return tmp4;
                }
                ;
                auto tmp5 = decltype(tmp3())::blendv(at::vec::Vectorized<float>(0.0), tmp3(), to_float_mask(tmp2) != at::vec::Vectorized<float>(0));
                auto tmp7 = tmp5 + tmp6;
                tmp8_vec = at::vec::minimum(tmp8_vec, tmp7);
            }
            #pragma omp simd simdlen(8) reduction(min:tmp8)
            for(long i0=0; i0<2; i0+=1)
            {
                auto tmp6 = in_ptr1[i0];
                auto tmp0 = static_cast<long>(0);
                auto tmp1 = static_cast<long>(2);
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = in_ptr0[0];
                    return tmp4;
                }
                ;
                auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                auto tmp7 = tmp5 + tmp6;
                tmp8 = std::min(tmp8, tmp7);
            }
            tmp8 = std::min(tmp8, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return at::vec::minimum(x, y);}, tmp8_vec));
            out_ptr0[0] = tmp8;
        }
    }
}
```

Pull Request resolved: pytorch#96510

Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
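For context, a minimal sketch of the kind of user code that hits this path (a hypothetical repro with assumed shapes and ops, not the exact script from pytorch#96484): padding introduces a masked load in the inductor CPP backend, and `.min()` lowers to a reduction, so together they exercise the masked vectorization path fixed here.

```
# Hypothetical repro sketch; shapes and op choice are assumptions.
import torch

def fn(a, b):
    # pad() produces a bounds-checked (masked) load in the generated
    # kernel; min() compiles to a reduction over the result.
    return (torch.nn.functional.pad(a, (0, 1)) + b).min()

a = torch.randn(1)
b = torch.randn(2)

compiled = torch.compile(fn)
# Before this fix, compiling the generated C++ could fail with a
# variable-not-defined error (tmp8 above); afterwards it matches eager.
torch.testing.assert_close(compiled(a, b), fn(a, b))
```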