[inductor] more accurate throughput calculations for kernel benchmarks (pytorch#118858)

Our current throughput calculations for kernel benchmarks have some issues,
particularly when we slice inputs in the kernel. In such cases, we count the
original, full-size inputs as part of the memory traffic moved by the kernel.
This is incorrect because it inflates the computed throughput, which can even
exceed the theoretical bandwidth.

Instead, we should count only the size of the slices that actually contribute
to the kernel's memory traffic.
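
As a rough, standalone sketch of the intended accounting (not the actual
inductor code; the shapes and the bytes_gb helper below are made up for
illustration), only the bytes of the slice the kernel actually reads should be
counted, not the full tensor it was narrowed from:

import torch

def bytes_gb(*tensors):
    # hypothetical helper: bytes touched by the kernel, in GB (1e9 bytes)
    return sum(t.numel() * t.element_size() for t in tensors) / 1e9

M, N = 5, 1000000
a = torch.rand(M, N, dtype=torch.float16)
b = torch.rand(M, 5 * N, dtype=torch.float16)  # the kernel reads only one N-wide slice of b
x0 = torch.narrow(b, 1, N, N)                  # the (M, N) slice actually loaded
out = a + x0                                   # the pointwise add we want to benchmark

# counting all of b (old behavior) vs. only the slice (this change)
overestimate = bytes_gb(a, b, out)  # (5e6 + 25e6 + 5e6) elements * 2 bytes / 1e9 = 0.070 GB
accurate = bytes_gb(a, x0, out)     # (5e6 + 5e6 + 5e6) elements * 2 bytes / 1e9 = 0.030 GB

Dividing the first number by the kernel's measured runtime more than doubles
the reported GB/s for this kernel, which is how the computed throughput could
end up above the hardware's peak bandwidth.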

Pull Request resolved: pytorch#118858
Approved by: https://github.com/jansel
chenyang78 authored and pytorchmergebot committed Feb 1, 2024
1 parent 20484a1 commit 61b572e
Showing 5 changed files with 313 additions and 27 deletions.
209 changes: 197 additions & 12 deletions test/inductor/test_kernel_benchmark.py
@@ -56,6 +56,20 @@ def verify_compiled_kernels(self, GB_count=1):
             exactly=1,
         ).run(bench_out)
 
+    def check_bandwidth(self, compiled_module, num_gb):
+        # now run the compiled module in subprocess and check its output
+        bench_out = subprocess.check_output(
+            f"{sys.executable} {compiled_module.__file__} -k".split(),
+            stderr=subprocess.STDOUT,
+        ).decode()
+
+        # make sure we have the bandwidth information in the output
+        FileCheck().check_count(
+            f"{num_gb} GB ",
+            1,
+            exactly=1,
+        ).run(bench_out)
+
     def test_pw_kernel_benchmark(self):
         @torch.compile
         def f(x):
@@ -117,7 +131,7 @@ def f(a, b):
             exactly=1,
         ).run(source_code)
 
-    def test_bandwidth_computation(self):
+    def test_matmul_bandwidth_computation(self):
         """
         The test does a matmul and then mul. Without max-autotune, we use
         the matmul in aten. So there is a single triton kernel for mul.
@@ -146,18 +160,189 @@ def f(x, y):
 
         compiled_module = self.get_compiled_module()
 
-        # now run the compiled module in subprocess and check its output
-        bench_out = subprocess.check_output(
-            f"{sys.executable} {compiled_module.__file__} -k".split(),
-            stderr=subprocess.STDOUT,
-        ).decode()
+        self.check_bandwidth(compiled_module, 0.008)
 
-        # make sure we have the bandwidth information in the output
-        FileCheck().check_count(
-            "0.008 GB ",
-            1,
-            exactly=1,
-        ).run(bench_out)
+    def test_unused_input_bandwidth_computation(self):
+        M, N = 5, 1000000
+
+        @torch.compile
+        def f(a, b, c):
+            return a + c
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        torch._dynamo.mark_dynamic(b, 0)
+        torch._dynamo.mark_dynamic(c, 0)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_c + size_out
+        # num_gb = (5 * 1000000 + 5 * 1000000 + 5 * 1000000) * 2 / 1e9
+        # = 0.030
+        self.check_bandwidth(compiled_module, "0.030")
+
+    def test_reduction_bandwidth_computation(self):
+        @torch.compile
+        def f(a):
+            return torch.sum(a, dim=1)
+
+        a = torch.rand(1000, 20, 1000, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a,)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_out
+        # num_gb = (1000 * 20 * 1000 + 1000 * 1000) * 2 / 1e9
+        # = 0.042
+        self.check_bandwidth(compiled_module, "0.042")
+
+    @config.patch(max_autotune=True)
+    def test_fused_layernorm_bandwidth_computation(self):
+        M, N = 10, 1000000
+
+        @torch.compile
+        def f(a, b, c, d):
+            x0 = a + b
+            x1 = torch.nn.functional.layer_norm(
+                x0, normalized_shape=(N,), weight=c, bias=d, eps=1e-05
+            )
+            x2 = torch.sigmoid(x1)
+            return x0 * x2
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        d = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a, b, c, d)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_b + size_c + size_d + size_out
+        # num_gb = (10 * 1000000 + 1000000 + 1000000 + 1000000 + 10 * 1000000) * 2 / 1e9
+        # = 0.046
+        self.check_bandwidth(compiled_module, "0.046")
+
+    def test_slice_add_cat_bandwidth_computation(self):
+        M, N = 5, 1000000
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.narrow(b, 1, N, N)
+            # broadcasting
+            x1 = x0 + c
+            return torch.cat([a, x1], dim=1)
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        torch._dynamo.mark_dynamic(b, 0)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # we overestimate the size of "slice_b" due to torch.cat
+        # num_gb = size_a + size_slice_b + size_c + size_out
+        # num_gb = (5 * 1000000 + 5 * 2000000 + 1000000 + 5 * 2000000) * 2 / 1e9
+        # = 0.052
+        self.check_bandwidth(compiled_module, "0.052")
+
+    def test_slice_add_bandwidth_computation(self):
+        M, N = 5, 1000000
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.narrow(b, 1, N, N)
+            return a + x0 + c
+
+        a = torch.rand(M, N, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(M, N * 5, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        torch._dynamo.mark_dynamic(b, 0)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # num_gb = size_a + size_slice_b + size_c + out_size
+        # num_gb = (5 * 1000000 + 5 * 1000000 + 1000000 + 5 * 1000000) * 2 / 1e9
+        # = 0.032
+        self.check_bandwidth(compiled_module, "0.032")
+
+    def test_mm_slice_add_bandwidth_computation(self):
+        M, N, K = 1000, 1000, 30
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.mm(a, b)
+            x1 = torch.narrow(c, 1, 20 * N, N)
+            x2 = torch.narrow(c, 1, 21 * N, N)
+            return x0 + x1 + x2
+
+        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # torch.mm becomes an extern kernel, so we measure the nbytes
+        # for the pointwise add kernel:
+        # num_gb = x0 + 2 * size_slice_c + size_out
+        # num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2 / 1e9
+        # = 0.008
+        self.check_bandwidth(compiled_module, "0.008")
+
+    def test_mm_slice_add_bandwidth_computation_2(self):
+        M, N, K = 1000, 1000, 30
+
+        @torch.compile
+        def f(a, b, c):
+            x0 = torch.mm(a, b)
+            x1 = torch.narrow(c, 1, 20 * N, N)
+            x2 = torch.narrow(c, 1, 20 * N, N)
+            return x0 + x1 + x2
+
+        a = torch.rand(M, K, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
+        c = torch.rand(N, N * 100, dtype=torch.float16, device=GPU_TYPE)
+        inputs = (a, b, c)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+        # torch.mm becomes an extern kernel, so we measure the nbytes
+        # for the pointwise add kernel:
+        # num_gb = x0 + size_slice_c + size_out
+        # num_gb = (1000 * 1000 + 1000 * 1000 + 1000 * 1000) * 2 / 1e9
+        # = 0.006
+        # note that we only count one size_slice_c because two accesses
+        # have the same index.
+        self.check_bandwidth(compiled_module, "0.006")
+
+    @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
+    def test_slice_mm_bandwidth_computation(self):
+        M, N, K = 1000, 2000, 3000
+
+        @torch.compile
+        def f(a, b):
+            x = torch.narrow(a, 1, K, K)
+            return torch.mm(x, b)
+
+        a = torch.rand(M, 3 * K, dtype=torch.float16, device=GPU_TYPE)
+        b = torch.rand(K, N, dtype=torch.float16, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(a, 0)
+        inputs = (a, b)
+        out = f(*inputs)
+
+        compiled_module = self.get_compiled_module()
+
+        # c[1000, 2000] = x[1000, 3000] @ b[3000, 2000]
+        # num_gb = (1000 * 2000 + 1000 * 3000 + 3000 * 2000) * 2 / 1e9
+        # = 0.022
+        self.check_bandwidth(compiled_module, "0.022")
 
 
 if __name__ == "__main__":