From 9238b1489aec04fabcc6d187522548f61de74a2f Mon Sep 17 00:00:00 2001
From: "steve.an"
Date: Tue, 10 Oct 2023 10:39:19 +0900
Subject: [PATCH] Use power of 2 size tensors in Attention benchmark

---
 benchmarks/benchmark_attention.py |  6 +++---
 trident/kernel/attention.py       |  9 +++++++++
 trident/operation/attention.py    | 14 +++++++++++---
 trident/util/util.py              | 10 ++++++++++
 4 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/benchmarks/benchmark_attention.py b/benchmarks/benchmark_attention.py
index 264c5d7d..c8a29a6d 100644
--- a/benchmarks/benchmark_attention.py
+++ b/benchmarks/benchmark_attention.py
@@ -20,7 +20,7 @@
 
 
 @util.report(
-    "attention forward", ["y_size"], [32 * i for i in range(1, 21)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
+    "attention forward", ["y_size"], [2**i for i in range(5, 10)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
 )
 def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend):
     factory_kwargs = {"device": "cuda", "dtype": dtype}
@@ -39,8 +39,8 @@ def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backe
 @util.report(
     "attention backward",
     ["y_size"],
-    [64 * i for i in range(1, 21)],
-    {"num_batches": 64, "num_heads": 8, "x_size": 64},
+    [2**i for i in range(5, 10)],
+    {"num_batches": 32, "num_heads": 8, "x_size": 64},
 )
 def bench_attention_backward(num_batches, num_heads, y_size, x_size, dtype, backend):
     factory_kwargs = {"device": "cuda", "dtype": dtype}
diff --git a/trident/kernel/attention.py b/trident/kernel/attention.py
index b509fb60..18dce895 100644
--- a/trident/kernel/attention.py
+++ b/trident/kernel/attention.py
@@ -26,6 +26,14 @@ def attention_configs():
     return configs
 
 
+def attention_backward_configs():
+    configs = []
+    for y_block_size in [64]:
+        for num_warps in [2, 4, 8]:
+            configs.append(triton.Config({"y_block_size": y_block_size}, num_warps))
+    return configs
+
+
 class Attention:
     @staticmethod
     @util.autotune(attention_configs(), ["y_size"])
@@ -165,6 +173,7 @@ def forward(
         tl.store(log2sum_block_ptr, log2sum.to(dtype))
 
     @staticmethod
+    @util.autotune(attention_backward_configs(), ["y_size"])
     @triton.jit
     def backward(
         grad_query_ptr: tl.tensor,
diff --git a/trident/operation/attention.py b/trident/operation/attention.py
index fa02549e..5793fa85 100644
--- a/trident/operation/attention.py
+++ b/trident/operation/attention.py
@@ -28,6 +28,16 @@ def forward(ctx: Any, *args: Any, **kwargs: Any):
         if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
             raise ValueError("The dimension of query, key and value should be 4.")
 
+        if (
+            not util.is_power_of_two(query.shape[-2])
+            or not util.is_power_of_two(query.shape[-1])
+            or not util.is_power_of_two(key.shape[-2])
+            or not util.is_power_of_two(key.shape[-1])
+            or not util.is_power_of_two(value.shape[-2])
+            or not util.is_power_of_two(value.shape[-1])
+        ):
+            raise ValueError("Attention supports only power-of-2 sized tensors.")
+
         if mask is not None:
             if is_causal:
                 raise ValueError("Error because both attn_mask and is_causal are set.")
@@ -183,9 +193,7 @@ def grid(meta):
             softmax_scale,
             use_accelerator,
             util.dtype(grad_query.dtype),
-            64,
-            triton.next_power_of_2(x_size),
-            num_warps=4 if x_size <= 64 else 8,
+            x_block_size=triton.next_power_of_2(x_size),
         )
         util.pop_trace()
 
diff --git a/trident/util/util.py b/trident/util/util.py
index 048ec74c..0bbc08da 100644
--- a/trident/util/util.py
+++ b/trident/util/util.py
@@ -58,6 +58,16 @@ def dtype(input):
     raise ValueError(f"Unable to convert the given input: '{input}'.")
 
 
+def is_power_of_two(n):
+    if n <= 0:
+        return False
+    while n > 1:
+        if n % 2 != 0:
+            return False
+        n = n // 2
+    return True
+
+
 def size_and_stride(input: torch.Tensor, dim: int):
     if input.dim() == 2:
         if dim == 0:
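
For reviewers who want to exercise the new constraint outside the kernel, here is a minimal standalone sketch. It assumes nothing beyond PyTorch, re-declares a local is_power_of_two that mirrors the helper this patch adds to trident/util/util.py, and uses illustrative shapes rather than Trident's actual entry points:

import torch


def is_power_of_two(n: int) -> bool:
    # Local mirror of the helper added to trident/util/util.py:
    # repeatedly halve n and reject any odd intermediate value.
    if n <= 0:
        return False
    while n > 1:
        if n % 2 != 0:
            return False
        n = n // 2
    return True


# Benchmark shapes follow (num_batches, num_heads, y_size, x_size);
# y_size now sweeps powers of 2: 32, 64, 128, 256, 512.
for y_size in [2**i for i in range(5, 10)]:
    query = torch.empty(64, 8, y_size, 64)
    # The check added to Attention.forward rejects tensors whose last two
    # dimensions are not powers of 2; these benchmark shapes all pass.
    assert is_power_of_two(query.shape[-2]) and is_power_of_two(query.shape[-1])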