diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py
index 8610fb5708b51..7eb84284d0dc1 100644
--- a/test/inductor/test_kernel_benchmark.py
+++ b/test/inductor/test_kernel_benchmark.py
@@ -56,6 +56,20 @@ def verify_compiled_kernels(self, GB_count=1):
             exactly=1,
         ).run(bench_out)
 
+    def check_bandwidth(self, compiled_module, num_gb):
+        # now run the compiled module in subprocess and check its output
+        bench_out = subprocess.check_output(
+            f"{sys.executable} {compiled_module.__file__} -k".split(),
+            stderr=subprocess.STDOUT,
+        ).decode()
+
+        # make sure we have the bandwidth information in the output
+        FileCheck().check_count(
+            f"{num_gb} GB ",
+            1,
+            exactly=1,
+        ).run(bench_out)
+
     def test_pw_kernel_benchmark(self):
         @torch.compile
         def f(x):
@@ -117,7 +131,7 @@ def f(a, b):
             exactly=1,
         ).run(source_code)
 
-    def test_bandwidth_computation(self):
+    def test_matmul_bandwidth_computation(self):
         """
         The test does a matmul and then mul. Without max-autotune, we use
         the matmul in aten. So there is a single triton kernel for mul.
@@ -146,18 +160,189 @@ def f(x, y):
 
         compiled_module = self.get_compiled_module()
 
-        # now run the compiled module in subprocess and check its output
-        bench_out = subprocess.check_output(
-            f"{sys.executable} {compiled_module.__file__} -k".split(),
-            stderr=subprocess.STDOUT,
-        ).decode()
+        self.check_bandwidth(compiled_module, 0.008)
 
-        # make sure we have the bandwidth information in the output
-        FileCheck().check_count(
-            "0.008 GB ",
-            1,
-            exactly=1,
-        ).run(bench_out)
+    def test_unused_input_bandwidth_computation(self):
+        M, N = 5, 1000000
+
+        @torch.compile
+        def f(a, b, c):
+            return a + c
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        torch._dynamo.mark_dynamic(b, 0)
+        torch._dynamo.mark_dynamic(c, 0)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_c + size_out
+        # num_gb = (5 * 1000000 + 5 * 1000000 + 5 * 1000000) * 2 / 1e9
+        # = 0.030
+        self.check_bandwidth(compiled_module, "0.030")
+
+    def test_reduction_bandwidth_computation(self):
+        @torch.compile
+        def f(a):
+            return torch.sum(a, dim=1)
+
+        a = torch.rand(1000, 20, 1000, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a,)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_out
+        # num_gb = (1000 * 20 * 1000 + 1000 * 1000) * 2 / 1e9
+        # = 0.042
+        self.check_bandwidth(compiled_module, "0.042")
+
+    @config.patch(max_autotune=True)
+    def test_fused_layernorm_bandwidth_computation(self):
+        M, N = 10, 1000000
+
+        @torch.compile
+        def f(a, b, c, d):
+            x0 = a + b
+            x1 = torch.nn.functional.layer_norm(
+                x0, normalized_shape=(N,), weight=c, bias=d, eps=1e-05
+            )
+            x2 = torch.sigmoid(x1)
+            return x0 * x2
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        d = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a, b, c, d)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_b + size_c + size_d + size_out
+        # num_gb = (10 * 1000000 + 1000000 + 1000000 + 1000000 + 10 * 1000000) * 2 / 1e9
+        # = 0.046
+        self.check_bandwidth(compiled_module, "0.046")
+
+    def test_slice_add_cat_bandwidth_computation(self):
+        M, N = 5, 1000000
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.narrow(b, 1, N, N)
+            # broadcasting
+            x1 = x0 + c
+            return torch.cat([a, x1], dim=1)
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        torch._dynamo.mark_dynamic(b, 0)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # we overestimate the size of "slice_b" due to torch.cat
+        # num_gb = size_a + size_slice_b + size_c + size_out
+        # num_gb = (5 * 1000000 + 5 * 2000000 + 1000000 + 5 * 2000000) * 2 / 1e9
+        # = 0.052
+        self.check_bandwidth(compiled_module, "0.052")
+
+    def test_slice_add_bandwidth_computation(self):
+        M, N = 5, 1000000
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.narrow(b, 1, N, N)
+            return a + x0 + c
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        torch._dynamo.mark_dynamic(b, 0)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_slice_b + size_c + size_out
+        # num_gb = (5 * 1000000 + 5 * 1000000 + 1000000 + 5 * 1000000) * 2 / 1e9
+        # = 0.032
+        self.check_bandwidth(compiled_module, "0.032")
+
+    def test_mm_slice_add_bandwidth_computation(self):
+        M, N, K = 1000, 1000, 30
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.mm(a, b)
+            x1 = torch.narrow(c, 1, 20 * N, N)
+            x2 = torch.narrow(c, 1, 21 * N, N)
+            return x0 + x1 + x2
+
+        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # torch.mm becomes an extern kernel, so we measure the nbytes
+        # for the pointwise add kernel:
+        # num_gb = size_x0 + 2 * size_slice_c + size_out
+        # num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2 / 1e9
+        # = 0.008
+        self.check_bandwidth(compiled_module, "0.008")
+
+    def test_mm_slice_add_bandwidth_computation_2(self):
+        M, N, K = 1000, 1000, 30
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.mm(a, b)
+            x1 = torch.narrow(c, 1, 20 * N, N)
+            x2 = torch.narrow(c, 1, 20 * N, N)
+            return x0 + x1 + x2
+
+        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # torch.mm becomes an extern kernel, so we measure the nbytes
+        # for the pointwise add kernel:
+        # num_gb = size_x0 + size_slice_c + size_out
+        # num_gb = (1000 * 1000 + 1000 * 1000 + 1000 * 1000) * 2 / 1e9
+        # = 0.006
+        # note that we only count one size_slice_c because the two accesses
+        # have the same index.
+        self.check_bandwidth(compiled_module, "0.006")
+
+    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
+    def test_slice_mm_bandwidth_computation(self):
+        M, N, K = 1000, 2000, 3000
+
+        @torch.compile
+        def f(a, b):
+            x = torch.narrow(a, 1, K, K)
+            return torch.mm(x, b)
+
+        a = torch.rand(M, 3 * K, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        inputs = (a, b)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+
+        # c[1000, 2000] = x[1000, 3000] @ b[3000, 2000]
+        # num_gb = (1000 * 2000 + 1000 * 3000 + 3000 * 2000) * 2 / 1e9
+        # = 0.022
+        self.check_bandwidth(compiled_module, "0.022")
 
 
 if __name__ == "__main__":
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index a71e9568846bc..f893b8d19931c 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -15,6 +15,7 @@
     Any,
     Callable,
     Counter,
+    DefaultDict,
     Dict,
     Iterable,
     List,
@@ -37,7 +38,7 @@
 from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
 from ..codecache import code_hash, get_path, PyCodeCache
-from ..dependencies import MemoryDep, StarDep
+from ..dependencies import Dep, MemoryDep, StarDep
 from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..scheduler import BaseScheduling, WhyNoFuse
@@ -45,6 +46,7 @@
 from ..utils import (
     cache_on_self,
     do_bench,
+    get_dtype_size,
     get_fused_kernel_name,
     get_kernel_metadata,
     green_text,
@@ -1221,6 +1223,8 @@ def __init__(
         self.min_elem_per_thread = min_elem_per_thread
         self.last_usage: Set[str] = set()
         self.block_ptr_id = itertools.count()
+        # buffer accesses in the kernel
+        self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
 
         self.persistent_reduction: bool = (
             not disable_persistent_reduction
@@ -2395,7 +2399,7 @@ def codegen_body(self):
         self.stores.clear()
         self.suffix.clear()
 
-    def codegen_kernel_benchmark(self, grid=None):
+    def codegen_kernel_benchmark(self, num_gb, grid=None):
         result = IndentedBuffer()
         argdefs, call_args, signature = self.args.python_argdefs()
 
@@ -2474,10 +2478,8 @@ def codegen_kernel_benchmark(self, grid=None):
                 f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
             )
 
-        ninplace_args = len(unique(self.args.inplace_buffers.values()))
         result.writelines(["\n", "\n", "if __name__ == '__main__':"])
         with result.indent():
-            result.writeline("from torch._inductor.utils import get_num_bytes")
             result.writeline("from triton.testing import do_bench")
             result.writeline("")
 
@@ -2485,9 +2487,7 @@ def codegen_kernel_benchmark(self, grid=None):
             result.writeline(
                 "ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
             )
-            result.writeline(
-                f"num_gb = get_num_bytes(*args, num_in_out_args={ninplace_args}) / 1e9"
-            )
+            result.writeline(f"num_gb = {num_gb}")
             result.writeline("gb_per_s = num_gb / (ms / 1e3)")
             result.writeline(
                 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
@@ -2521,6 +2521,63 @@ class defined.
         else:
             return ""
 
+    def estimate_kernel_num_bytes(self):
+        """
+        Try our best to estimate the total size (in bytes) of the
+        kernel's inputs and outputs, which is used for estimating
+        the memory throughput of this kernel, i.e. how far we are
+        from the peak memory bandwidth.
+        It is important not to overestimate the sizes of the inputs and
+        outputs, because that can produce an unrealistically large memory
+        traffic value, possibly even larger than the theoretical bandwidth,
+        which would be very misleading. This is particularly problematic
+        when we slice some of the inputs: in that case we should only count
+        the sizes of the "slices" rather than of the original inputs, since
+        only the slices contribute to the real memory traffic.
+        """
+        nbytes = []
+        ninplace_args = len(unique(self.args.inplace_buffers.values()))
+        _, call_args, _ = self.args.python_argdefs()
+
+        # For pointwise and reduction kernels, this is an upper bound on
+        # the numel of the output buffer.
+        # FIXME: This is not exactly right for cases like the one below:
+        #    def foo(tensor0, tensor1):
+        #        x0 = narrow(tensor0)
+        #        return cat(x0, tensor1)
+        # For this example, we end up overestimating the size of the
+        # slice x0. We could potentially have precise input information
+        # if we maintained the original inputs of the Pointwise kernel
+        # created for the "cat". However, adding such complexity only to
+        # handle a few particular benchmarking cases seems a bit
+        # overwhelming.
+        out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels))
+        for i, arg in enumerate(call_args):
+            # "buf" may be narrowed. In this case, the number of memory accesses
+            # should be estimated based on the reinterpreted layout.
+            # On the other hand, buf may be broadcasted. In this case,
+            # counting the size of the underlying storage gives us
+            # a better estimate of the number of memory accesses.
+            if arg not in self.buf_accesses:
+                nbytes.append(0)
+                continue
+            arg_numel = V.graph.get_numel(arg)
+            buf_size = V.graph.sizevars.size_hint(arg_numel)
+            if buf_size > out_numel:
+                # This arg points to a buf that has been sliced.
+                # We need to count each individual slice to get
+                # a better estimate.
+                indices = set()
+                for dep in self.buf_accesses[arg]:
+                    indices.add(dep.index)
+                numel = len(indices) * out_numel
+            else:
+                numel = buf_size
+            dtype = V.graph.get_dtype(arg)
+            dtype_size = get_dtype_size(dtype)
+            nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
+        return sum(nbytes)
+
     def codegen_kernel(self, name=None):
         from triton import next_power_of_2
 
@@ -2613,6 +2670,10 @@ def codegen_kernel(self, name=None):
             "mutated_arg_names": mutated_args,
             "no_x_dim": self.no_x_dim,
         }
+        num_gb = None
+        if config.benchmark_kernel or config.profile_bandwidth:
+            num_gb = self.estimate_kernel_num_bytes() / 1e9
+            inductor_meta["kernel_num_gb"] = num_gb
 
         for tree in self.active_range_trees():
             sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
@@ -2684,7 +2745,7 @@ def codegen_kernel(self, name=None):
             code.splice(self.body)
 
         if config.benchmark_kernel:
-            code.splice(self.codegen_kernel_benchmark())
+            code.splice(self.codegen_kernel_benchmark(num_gb))
 
         return code.getvalue()
 
@@ -3084,10 +3145,14 @@ def codegen_nodes(self, nodes):
         _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
 
         node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+        buf_accesses = collections.defaultdict(list)
+        for node in nodes:
+            for access in node.read_writes.reads | node.read_writes.writes:
+                buf_accesses[access.name].append(access)
 
         schedule_log.debug("Schedule:\n %s", node_schedule)
 
-        return self.codegen_node_schedule(node_schedule, numel, rnumel)
+        return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
 
     @staticmethod
     def reduction_hint(node):
@@ -3227,7 +3292,9 @@ def codegen_comment(self, node_schedule):
                 f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
             )
 
-    def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
+    def codegen_node_schedule(
+        self, node_schedule, buf_accesses, numel, reduction_numel
+    ):
         tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
         reduction_hint_val, mutations, index_dtype = self.get_kernel_args(
             node_schedule, numel, reduction_numel
@@ -3243,6 +3310,7 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
             *kernel_args,
             **kernel_kwargs,
         )
+        kernel.buf_accesses = buf_accesses
 
         self.codegen_node_schedule_with_kernel(node_schedule, kernel)
 
@@ -3400,13 +3468,14 @@ def codegen_template(self, template_node, epilogue_nodes):
         node_schedule = [template_node, *epilogue_nodes]
 
         if config.benchmark_kernel:
+            num_gb = kernel.estimate_kernel_num_bytes() / 1e9
             grid_args = V.graph.sizevars.size_hints(kernel.call_sizes)
             assert kernel.meta is not None, "meta is None"
             grid = kernel.grid_fn(*grid_args, kernel.meta)
             src_code = (
                 f"{kernel.imports_for_benchmark_kernel()}\n"
                 f"{src_code}\n"
-                f"{kernel.codegen_kernel_benchmark(grid).getvalue()}"
+                f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}"
             )
 
         kernel_name = self.define_kernel(src_code, node_schedule)
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 20b88d6524c41..a02cd56e9e5c8 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -3,6 +3,7 @@
 import inspect
 import itertools
 import logging
+import operator
 import sys
 import textwrap
 import time
@@ -24,7 +25,14 @@
 from .codegen.triton import texpr, TritonKernel, TritonPrinter, TritonScheduling
 from .codegen.triton_utils import config_of, signature_to_meta
 from .exc import CUDACompileError
-from .utils import do_bench, Placeholder, sympy_dot, sympy_product, unique
+from .utils import (
+    do_bench,
+    get_dtype_size,
+    Placeholder,
+    sympy_dot,
+    sympy_product,
+    unique,
+)
 from .virtualized import V
 
 log = logging.getLogger(__name__)
@@ -109,6 +117,21 @@ def __init__(
     def need_numel_args(self):
         return False
 
+    def estimate_kernel_num_bytes(self):
+        """
+        Estimate the total number of bytes that this kernel reads and writes.
+        For in/out nodes, the size is counted twice: once for the read and
+        once for the write.
+        """
+        ninplace_args = len(unique(self.args.inplace_buffers.values()))
+        num_bytes = []
+        for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
+            size = V.graph.sizevars.size_hints(inp.get_size())
+            numel = functools.reduce(operator.mul, size)
+            dtype_size = get_dtype_size(inp.get_dtype())
+            num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
+        return sum(num_bytes)
+
     def jit_line(self):
         if self.use_jit:
             return "@triton.jit"
@@ -122,7 +145,12 @@ def jit_line(self):
         }
         triton_meta["configs"] = [config_of(signature)]
 
-        inductor_meta = {"kernel_name": str(Placeholder.DESCRIPTIVE_NAME)}
+        inductor_meta = {
+            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
+        }
+        if config.profile_bandwidth or config.benchmark_kernel:
+            num_gb = self.estimate_kernel_num_bytes() / 1e9
+            inductor_meta["kernel_num_gb"] = num_gb
         return textwrap.dedent(
             f"""
             @template(
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index efc0ab40e5a5a..83306154401c0 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -730,7 +730,9 @@ def run(self, *args, grid, stream):
                     if arg_name.startswith("in_out_ptr")
                 ]
             )
-            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
+            num_gb = self.inductor_meta.get("kernel_num_gb", None)
+            if num_gb is None:
+                num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
             gb_per_s = num_gb / (ms / 1e3)
             self.cached = (ms, num_gb, gb_per_s, kernel_name)
         else:
diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py
index 66b2b71458f70..d033dfa6bb904 100644
--- a/torch/_inductor/wrapper_benchmark.py
+++ b/torch/_inductor/wrapper_benchmark.py
@@ -83,7 +83,9 @@ def get_triton_kernel(mod):
                 if arg_name.startswith("in_out_ptr")
             ]
         )
-        num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
+        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
+        if num_gb is None:
+            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
 
         def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
             if not any(x is None for x in [n_regs, n_spills, shared]):
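
Illustrative sketch (not part of the patch): the benchmark stub emitted by codegen_kernel_benchmark now hard-codes the precomputed num_gb and derives bandwidth as gb_per_s = num_gb / (ms / 1e3). The standalone snippet below reproduces that arithmetic end to end for a hypothetical fp16 elementwise add; the function name elementwise_add, the shapes, and the hard-coded "cuda" device are assumptions made only for this example.

import torch
from triton.testing import do_bench

M, N = 5, 1000000

@torch.compile
def elementwise_add(a, b):
    # hypothetical memory-bound kernel used only for this sketch
    return a + b

a = torch.rand(M, N, dtype=torch.float16, device="cuda")
b = torch.rand(M, N, dtype=torch.float16, device="cuda")

# bytes read (a, b) plus bytes written (out); fp16 is 2 bytes per element
num_gb = (M * N + M * N + M * N) * 2 / 1e9

# same timing call and formula as the generated benchmark stub
ms = do_bench(lambda: elementwise_add(a, b), rep=40, fast_flush=True)
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")

For these shapes num_gb works out to 0.030, the same figure test_unused_input_bandwidth_computation expects, since that test also moves three 5x1000000 fp16 buffers.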