inductor(cpu): make the variable numbering used by the masked vectorization path align with the scalar path (pytorch#96510)

Fixes pytorch#96484. The CPP reduction vectorization path assumes that the vectorized path allocates variable numbers in step with the scalar path. `masked` currently breaks this assumption: the two paths emit different numbers of temporaries, so the generated kernel references a variable that was never defined and fails with a "var not defined" error.
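For context, a minimal standalone reproducer distilled from the test added below; running it as a script is my framing, and `same` is assumed importable from `torch._dynamo.testing` as in the test suite:

```python
import torch
import torch._dynamo
from torch._dynamo.testing import same

def fn(a, b):
    # The negative pad trims the last element; inductor lowers the padded
    # read as a masked load inside the min reduction.
    a = torch.nn.functional.pad(a, (0, -1))
    c = a + b
    return c.min(0).values

a = torch.randn([2])
b = torch.randn([2])
opt_fn = torch._dynamo.optimize("inductor")(fn)
# Before this fix, compiling the generated C++ kernel failed with a
# "var not defined" error (see the "before" listing below).
assert same(fn(a, b), opt_fn(a, b))
```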

before (the vectorized path declares `tmp7`/`tmp7_vec`, while the scalar pragma, tail loop, and final store all reference `tmp8`, which is never declared):
```
{
    {
        {
            #pragma omp declare reduction(min:at::vec::Vectorized<float>:omp_out = at::vec::minimum(omp_out, omp_in)) initializer(omp_priv={{std::numeric_limits<float>::infinity()}})
            float tmp7 = std::numeric_limits<float>::infinity();
            auto tmp7_vec = at::vec::Vectorized<float>(tmp7);
            for(long i0=0; i0<0; i0+=1)
            {
                auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr1 + 16*i0);
                auto tmp0 = at::vec::Vectorized<int>(static_cast<int>(0));
                auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(2));
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = at::vec::Vectorized<float>(0.0);
                {
                    auto tmp4 = at::vec::Vectorized<float>(in_ptr0[0]);
                    tmp3 = decltype(tmp4)::blendv(tmp3, tmp4, to_float_mask(tmp2) != at::vec::Vectorized<float>(0));
                }
                auto tmp6 = tmp3 + tmp5;
                tmp7_vec = at::vec::minimum(tmp7_vec, tmp6);
            }
            #pragma omp simd simdlen(8)  reduction(min:tmp8)
            for(long i0=0; i0<2; i0+=1)
            {
                auto tmp6 = in_ptr1[i0];
                auto tmp0 = static_cast<long>(0);
                auto tmp1 = static_cast<long>(2);
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = in_ptr0[0];
                    return tmp4;
                }
                ;
                auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                auto tmp7 = tmp5 + tmp6;
                tmp8 = std::min(tmp8, tmp7);
            }
            tmp7 = std::min(tmp7, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return at::vec::minimum(x, y);}, tmp7_vec));
            out_ptr0[0] = tmp8;
        }
    }
}
```
after (the vectorized path now wraps the masked load in a lambda, mirroring the scalar path, so both paths allocate the same variable numbers and the accumulator is consistently `tmp8`):
```
{
    {
        {
            #pragma omp declare reduction(min:at::vec::Vectorized<float>:omp_out = at::vec::minimum(omp_out, omp_in)) initializer(omp_priv={{std::numeric_limits<float>::infinity()}})
            float tmp8 = std::numeric_limits<float>::infinity();
            auto tmp8_vec = at::vec::Vectorized<float>(tmp8);
            for(long i0=0; i0<0; i0+=1)
            {
                auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + 16*i0);
                auto tmp0 = at::vec::Vectorized<int>(static_cast<int>(0));
                auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(2));
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = at::vec::Vectorized<float>(in_ptr0[0]);
                    return tmp4;
                }
                ;
                auto tmp5 = decltype(tmp3())::blendv(at::vec::Vectorized<float>(0.0), tmp3(), to_float_mask(tmp2) != at::vec::Vectorized<float>(0));
                auto tmp7 = tmp5 + tmp6;
                tmp8_vec = at::vec::minimum(tmp8_vec, tmp7);
            }
            #pragma omp simd simdlen(8)  reduction(min:tmp8)
            for(long i0=0; i0<2; i0+=1)
            {
                auto tmp6 = in_ptr1[i0];
                auto tmp0 = static_cast<long>(0);
                auto tmp1 = static_cast<long>(2);
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = in_ptr0[0];
                    return tmp4;
                }
                ;
                auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                auto tmp7 = tmp5 + tmp6;
                tmp8 = std::min(tmp8, tmp7);
            }
            tmp8 = std::min(tmp8, at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return at::vec::minimum(x, y);}, tmp8_vec));
            out_ptr0[0] = tmp8;
        }
    }
}

```
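The root cause is easiest to see at the level of variable allocation. Inductor hands out `tmpN` names from a counter (`V.kernel.cse.newvar()` in the diff below), and the vectorized loop, the scalar tail loop, and the shared reduction epilogue are generated from the same op sequence. The toy model below is my own illustration, not inductor's actual CSE; it shows how one op emitting a different number of temporaries per path shifts every later name:

```python
import itertools

def allocate_names(vars_per_op):
    # Toy stand-in for inductor's CSE: draw tmpN names from one counter.
    counter = itertools.count()
    return [f"tmp{next(counter)}" for n in vars_per_op for _ in range(n)]

# Ops in the reduction above: index compare (tmp0-tmp2), masked load,
# input load, add, accumulator.  The scalar masked() emits three names
# (lambda, body value, selected value); the old vectorized masked()
# emitted only two, inlining the blendv.
scalar     = allocate_names([3, 3, 1, 1, 1])  # ends at 'tmp8'
vector_old = allocate_names([3, 2, 1, 1, 1])  # ends at 'tmp7'
assert scalar[-1] == "tmp8" and vector_old[-1] == "tmp7"
# The epilogue assumes both paths agree on the accumulator name and
# stores `out_ptr0[0] = tmp8;`, but the old vectorized kernel never
# declared tmp8 -- hence the compile error.
```

After the fix, the vectorized `masked` emits the same lambda-plus-select trio as the scalar one, and the two allocations coincide.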

Pull Request resolved: pytorch#96510
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
XiaobingSuper authored and pytorchmergebot committed Mar 13, 2023
1 parent 2cbce06 commit 279ada5
Showing 2 changed files with 32 additions and 15 deletions.
test/inductor/test_torchinductor.py (15 additions, 0 deletions):

```diff
@@ -5858,6 +5858,21 @@ def fn(p0, p1):
         compiled_out = opt_fn(p0, p1)
         assert same(real_out, compiled_out)
 
+    def test_reduce_with_masked(self):
+        # https://github.com/pytorch/pytorch/issues/96484
+        def fn(a, b):
+            a = torch.nn.functional.pad(a, (0, -1))
+            c = a + b
+            return c.min(0).values
+
+        a = torch.randn([2])
+        b = torch.randn([2])
+        opt_fn = torch._dynamo.optimize("inductor")(fn)
+        opt_fn(a, b)
+        real_out = fn(a, b)
+        compiled_out = opt_fn(a, b)
+        assert same(real_out, compiled_out)
+
     @unittest.skipIf(
         not codecache.valid_vec_isa_list(), "Does not support vectorization"
     )
```
torch/_inductor/codegen/cpp.py (17 additions, 15 deletions):

```diff
@@ -553,27 +553,29 @@ def masked(mask, body, other):
         assert opt_ctx.is_masked_load
 
         code = BracesBuffer()
 
         var = V.kernel.cse.newvar()
+        code.writeline(f"auto {var} = [&]")
+        with V.kernel.swap_buffers(code), code.indent():
+            result = body()
+            code.writeline(f"return {result};")
+        code.writeline(";")
+        V.kernel.compute.splice(code)
+
         if other == float("-inf"):
-            code.writeline(
-                f"auto {var} = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());"
+            other_code = (
+                "at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())"
             )
         elif other == float("inf"):
-            code.writeline(
-                f"auto {var} = at::vec::Vectorized<float>(std::numeric_limits<float>::infinity());"
+            other_code = (
+                "at::vec::Vectorized<float>(std::numeric_limits<float>::infinity())"
             )
         else:
-            code.writeline(f"auto {var} = at::vec::Vectorized<float>({other!r});")
-
-        with V.kernel.swap_buffers(code), code.indent():
-            result = body()
-            zero_val = "at::vec::Vectorized<float>(0)"
-            float_mask = f"to_float_mask({mask})"
-            blendv = f"decltype({result})::blendv({var}, {result}, {float_mask} != {zero_val})"
-            code.writeline(f"{var} = {blendv};")
-        V.kernel.compute.splice(code)
-        return var
+            other_code = f"at::vec::Vectorized<float>({other!r})"
+        type = f"decltype({var}())"
+        zero_val = "at::vec::Vectorized<float>(0)"
+        float_mask = f"to_float_mask({mask})"
+        return f"{type}::blendv({other_code}, {var}(), {float_mask} != {zero_val})"
 
     @staticmethod
     def index_expr(expr, dtype):
@@ -818,7 +820,7 @@ def masked(mask, body, other):
         if other == float("-inf"):
             other_code = f"-std::numeric_limits<{type}>::infinity()"
         elif other == float("inf"):
-            other_code = "std::numeric_limits<{type}>::infinity()"
+            other_code = f"std::numeric_limits<{type}>::infinity()"
         elif isinstance(other, bool):
             other_code = f"static_cast<{type}>({str(other).lower()})"
         else:
```
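The second hunk in cpp.py fixes a separate one-character bug on the scalar side: the `inf` branch was missing its f-string prefix, so the `{type}` placeholder would have been emitted verbatim into the generated C++. A quick illustration:

```python
type = "float"  # shadows the builtin, matching the diff's own variable name

# Without the f-prefix the placeholder is literal:
print("std::numeric_limits<{type}>::infinity()")
# -> std::numeric_limits<{type}>::infinity()

# With the f-prefix it is interpolated as intended:
print(f"std::numeric_limits<{type}>::infinity()")
# -> std::numeric_limits<float>::infinity()
```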
