[inductor] make multi-kernel work with cpp-wrapper (pytorch#117813)
Make multi-kernel work with cpp-wrapper. Multi-kernel generates two equivalent variants for a reduction and, at runtime, the faster one is picked. But cpp-wrapper needs to save the cubin file during codegen, so initially the two features did not work with each other.

Thanks to Jason for suggesting a neat way to integrate the two. cpp-wrapper currently does codegen in two passes. In the first pass, we still generate multi-kernel code and run it; in the second pass, we load the cubin file for the faster kernel directly. Multi-kernel Python code is not generated for the second pass since it should not be needed.
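
A minimal sketch of exercising both features together (the config knobs are the ones used by the tests in this PR; the exact values and the softmax example are illustrative, and a CUDA device is assumed):

    import torch
    from torch._inductor import config

    # Enable multi-kernel together with kernel benchmarking and the cpp wrapper.
    with config.patch(
        {"triton.multi_kernel": 1, "benchmark_kernel": True, "cpp_wrapper": True}
    ):
        compiled_softmax = torch.compile(torch.softmax)
        x = torch.rand(2, 1024, device="cuda")
        # Compilation does two codegen passes: the first generates and benchmarks
        # both reduction variants, the second loads the cubin of the faster one.
        out = compiled_softmax(x, -1)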

Pull Request resolved: pytorch#117813
Approved by: https://github.com/jansel
shunting314 authored and pytorchmergebot committed Feb 1, 2024
1 parent 54668ad commit 20484a1
Showing 5 changed files with 234 additions and 41 deletions.
93 changes: 84 additions & 9 deletions test/inductor/test_multi_kernel.py
@@ -8,7 +8,7 @@
from torch import nn
from torch._dynamo.testing import reset_rng_state

from torch._inductor import config
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
@@ -43,16 +43,51 @@ def _contains_multi_kernel_code(wrapper_code: str):
)


@config.patch({"triton.multi_kernel": 1, "benchmark_kernel": True})
def make_cpp_wrapper_test(orig_test, **extra_args):
"""
Wrap an existing test into a new test with cpp-wrapper enabled.
Make this a free function rather than a staticmethod in MultiKernelTest.
Otherwise we get a 'TypeError: 'staticmethod' object is not callable'
error in py3.8 (py3.10 works).
"""

@config.patch("cpp_wrapper", True)
def fn(self):
# The same kernel may have been compiled by previous tests with
# cpp_wrapper disabled. Clear the cache so we re-compile the kernel
# with cpp_wrapper enabled.
from torch._inductor import codecache

codecache.PyCodeCache.clear()
return orig_test(self, **extra_args)

return fn


@config.patch(
{
"triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
"benchmark_kernel": True,
}
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
def test_softmax(self):
def test_softmax(self, expect_multi_kernel=True):
x = torch.rand(2, 1024).cuda()
ref = torch.softmax(x, -1)
compiled_fn = torch.compile(torch.softmax)
act, (wrapper_code,) = run_and_get_code(compiled_fn, x, -1)
act, wrapper_code = run_and_get_code(compiled_fn, x, -1)

# wrapper_code will contain 2 entries if cpp_wrapper=True:
# one for the first pass and one for the second pass.
# We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertTrue(torch.allclose(ref, act))
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
if expect_multi_kernel:
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
else:
self.assertFalse(_contains_multi_kernel_code(wrapper_code))

@parametrize("force_kernel", (0, 1))
@unittest.mock.patch.dict(
@@ -90,6 +125,10 @@ def mock_run(self, kernel_calls):
def test_softmax_warn_mixed_layout(self):
self.test_softmax()

test_softmax_cpp_wrapper = make_cpp_wrapper_test(
test_softmax, expect_multi_kernel=False
)

def test_layernorm(self):
ln = nn.LayerNorm(1024).cuda()
x = torch.rand(2, 1024).cuda()
@@ -114,10 +153,6 @@ def f(x, y):
self.assertTrue(torch.allclose(ref, act))

def test_transformer_snippet(self):
"""
Test a snippet of transformer that will cause different arglist for
the persistent and non-persistent flavor of reductions.
"""
model = TransformerSnippet().cuda()
x = model.example_inputs()

@@ -195,6 +230,46 @@ def f(x, y):
act = torch.compile(f)(x, y)
self.assertTrue(torch.allclose(y_ref, y))

def test_reduction_scratch_buffer(self, force_multi_kernel=1):
"""
The explicitly realized buffer in the test function will be passed in
as a scratch buffer for the non-persistent reduction kernel but
can be skipped for the persistent reduction kernel.
This causes different argument lists for the non-persistent and
persistent reduction kernels.
Check the documentation around torch._inductor.config.triton.multi_kernel
for how to interpret the force_multi_kernel argument.
"""

def f(x):
x = x.sum(dim=-1, keepdim=True) + x
x = test_operators.realize(x)
x = x.sum(dim=-1, keepdim=True) + x
return x

x = torch.rand(16, 16, device="cuda")
ref = f(x)
with config.patch("triton.multi_kernel", force_multi_kernel):
act = torch.compile(f)(x)
self.assertTrue(torch.allclose(ref, act))

# Use benchmarking to pick the faster kernel
test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
test_reduction_scratch_buffer, force_multi_kernel=1
)
# force pick persistent reduction. This is a good test since this persistent
# reduction uses fewer call arguments than the corresponding non-persistent
# reduction.
test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
)
# force pick non-persistent reduction
test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
78 changes: 66 additions & 12 deletions test/inductor/test_torchinductor.py
@@ -582,6 +582,11 @@ def _run_and_assert_no_indirect_indexing(test_case, func, *args, **kwargs):


def assertGeneratedKernelCountEqual(self: TestCase, expected: int):
if config.triton.multi_kernel:
# when multi_kernel is enabled, we generate both persistent and
# non-persistent reduction kernels for the same node schedule.
# That will mess up the kernel count. Just don't check it.
return
if config.cpp_wrapper:
expected *= 2
self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected)
@@ -1200,7 +1205,20 @@ def fn(x, y):
z = x * y
return z.sum((0, 1))

self.common(fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256)))
atol = None
rtol = None

# By default, inductor generates non-persistent reduction kernels in this
# case. But when multi-kernel is enabled, inductor will pick the faster
# of the persistent and non-persistent reduction kernels.
# In this case, inductor picked the persistent-reduction kernel, which
# happens to need looser tolerance.
if config.triton.multi_kernel:
atol = 1e-5
rtol = 1e-5
self.common(
fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256)), atol=atol, rtol=rtol
)

def test_min_max_reduction(self):
def fn(a, b):
@@ -8600,7 +8618,13 @@ def fn(a: torch.Tensor) -> torch.Tensor:
return torch.sum(a)

kernels = self.get_kernels(fn, [torch.randn([256, 256], device=GPU_TYPE)])
self.assertTrue(len(kernels) == 2, "SUM should result in two kernels")
if config.triton.multi_kernel:
self.assertTrue(
len(kernels) == 4,
"SUM should result in four kernels when multi-kernel is enabled",
)
else:
self.assertTrue(len(kernels) == 2, "SUM should result in two kernels")

# kernel0 reduces from 256 to (xnumel=8, rnumel=8192), which means it reduces 256 by 256 into an array of
# size 8 by accumulating 8192 elements at once note that rnumel is equal to 512 * 16, so rnumel which is
@@ -8611,8 +8635,11 @@ def fn(a: torch.Tensor) -> torch.Tensor:
self.assertEqual(arguments_that_are_divisible_by_16_in_kernel0, (0, 1, 3))

# kernel1 reduces from 8 elements to a single scalar.
# Since multi-kernel generates 2 variants for each kernel, the second
# persistent-reduction kernel has index 2.
kernel1_index = 2 if config.triton.multi_kernel else 1
arguments_that_are_divisible_by_16_in_kernel1 = (
kernels[1].triton_meta["configs"][0].divisible_by_16
kernels[kernel1_index].triton_meta["configs"][0].divisible_by_16
)
self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
torch._dynamo.reset()
@@ -8808,7 +8835,9 @@ def fn(a, b):
torch.randn(1, N, K, device=GPU_TYPE),
]
code = run_and_get_triton_code(fn_opt, *inps)
self.assertEqual(code.count("tl.store"), 1)
self.assertEqual(
code.count("tl.store"), 2 if config.triton.multi_kernel else 1
)
self.assertTrue("out_ptr1" in code)
self.assertFalse("out_ptr0" in code)
self.assertEqual(fn_opt(*inps), fn(*inps))
@@ -8936,12 +8965,24 @@ def f(a, b):
)
code = run_and_get_triton_code(f, *inps)
lines = [line for line in code.split("\n") if "tl.load" in line]
self.assertExpectedInline(
"\n".join(lines),
"""\
if config.triton.multi_kernel:
# the first 2 lines are generated for the persistent reduction
# variant.
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, other=0.0)
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
)
)
else:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
)

@skipIfRocm
@config.patch("triton.use_block_ptr", True)
@@ -8957,12 +8998,25 @@ def f(a, b):
)
code = run_and_get_triton_code(f, *inps)
lines = [line for line in code.split("\n") if "tl.load" in line]
self.assertExpectedInline(
"\n".join(lines),
"""\

if config.triton.multi_kernel:
# the first 2 lines are generated for the persistent reduction
# variant.
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, RBLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero')
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
)
else:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""",
)
)

# Disable index propagation, so the indirect indexing isn't optimized away
@patch.object(config, "constant_and_index_propagation", False)