This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Use power of 2 size tensors in Attention benchmark
steve.an committed Oct 10, 2023
1 parent 96b0328 commit 9238b14
Showing 4 changed files with 33 additions and 6 deletions.
benchmarks/benchmark_attention.py (6 changes: 3 additions & 3 deletions)
@@ -20,7 +20,7 @@


@util.report(
"attention forward", ["y_size"], [32 * i for i in range(1, 21)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
"attention forward", ["y_size"], [2**i for i in range(5, 10)], {"num_batches": 64, "num_heads": 8, "x_size": 64}
)
def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend):
    factory_kwargs = {"device": "cuda", "dtype": dtype}
@@ -39,8 +39,8 @@ def bench_attention_forward(num_batches, num_heads, y_size, x_size, dtype, backend):
@util.report(
"attention backward",
["y_size"],
[64 * i for i in range(1, 21)],
{"num_batches": 64, "num_heads": 8, "x_size": 64},
[2**i for i in range(5, 10)],
{"num_batches": 32, "num_heads": 8, "x_size": 64},
)
def bench_attention_backward(num_batches, num_heads, y_size, x_size, dtype, backend):
    factory_kwargs = {"device": "cuda", "dtype": dtype}
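For reference, the replaced forward sweep covered every multiple of 32 up to 640 and the backward sweep every multiple of 64 up to 1280; the new list restricts both benchmarks to five power-of-two sizes (a quick check in Python):

    >>> [2**i for i in range(5, 10)]
    [32, 64, 128, 256, 512]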
trident/kernel/attention.py (9 changes: 9 additions & 0 deletions)
@@ -26,6 +26,14 @@ def attention_configs():
    return configs


+def attention_backward_configs():
+    configs = []
+    for y_block_size in [64]:
+        for num_warps in [2, 4, 8]:
+            configs.append(triton.Config({"y_block_size": y_block_size}, num_warps))
+    return configs


class Attention:
    @staticmethod
    @util.autotune(attention_configs(), ["y_size"])
@@ -165,6 +173,7 @@ def forward(
        tl.store(log2sum_block_ptr, log2sum.to(dtype))

    @staticmethod
+    @util.autotune(attention_backward_configs(), ["y_size"])
    @triton.jit
    def backward(
        grad_query_ptr: tl.tensor,
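For context, attention_backward_configs() yields three candidate launch configurations that differ only in num_warps (2, 4, 8) with a fixed y_block_size of 64. The sketch below shows how such a config list is attached to a kernel with stock triton.autotune; it assumes trident's util.autotune is a thin wrapper around the same mechanism, and the dummy kernel is purely illustrative:

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[triton.Config({"y_block_size": 64}, num_warps=w) for w in (2, 4, 8)],
        key=["y_size"],  # re-tune whenever y_size changes
    )
    @triton.jit
    def dummy_kernel(out_ptr, y_size, y_block_size: tl.constexpr):
        # y_block_size is supplied by the selected config, not by the caller
        offsets = tl.arange(0, y_block_size)
        tl.store(out_ptr + offsets, offsets, mask=offsets < y_size)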
trident/operation/attention.py (14 changes: 11 additions & 3 deletions)
@@ -28,6 +28,16 @@ def forward(ctx: Any, *args: Any, **kwargs: Any):
        if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
            raise ValueError("The dimension of query, key and value should be 4.")

+        if (
+            not util.is_power_of_two(query.shape[-2])
+            or not util.is_power_of_two(query.shape[-1])
+            or not util.is_power_of_two(key.shape[-2])
+            or not util.is_power_of_two(key.shape[-1])
+            or not util.is_power_of_two(value.shape[-2])
+            or not util.is_power_of_two(value.shape[-1])
+        ):
+            raise ValueError("Attention supports only for power of 2 size tensors.")

        if mask is not None:
            if is_causal:
                raise ValueError("Error because both attn_mask and is_causal are set.")
@@ -183,9 +193,7 @@ def grid(meta):
            softmax_scale,
            use_accelerator,
            util.dtype(grad_query.dtype),
-            64,
-            triton.next_power_of_2(x_size),
-            num_warps=4 if x_size <= 64 else 8,
+            x_block_size=triton.next_power_of_2(x_size),
        )
        util.pop_trace()

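The new guard in Attention.forward rejects any query/key/value whose last two dimensions are not powers of two. A minimal sketch of the effect on input shapes (the tensors and the helper below are illustrative, not part of the commit):

    import torch

    def trailing_dims_are_pow2(t: torch.Tensor) -> bool:
        # Mirrors the new check: both of the last two sizes must be powers of two.
        return all(n > 0 and (n & (n - 1)) == 0 for n in t.shape[-2:])

    ok = torch.empty(64, 8, 128, 64)   # 128 and 64 pass -> forward proceeds
    bad = torch.empty(64, 8, 100, 64)  # 100 fails -> ValueError is raised
    print(trailing_dims_are_pow2(ok), trailing_dims_are_pow2(bad))  # True False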
trident/util/util.py (10 changes: 10 additions & 0 deletions)
@@ -58,6 +58,16 @@ def dtype(input):
    raise ValueError(f"Unable to convert the given input: '{input}'.")


+def is_power_of_two(n):
+    if n <= 0:
+        return False
+    while n > 1:
+        if n % 2 != 0:
+            return False
+        n = n // 2
+    return True


def size_and_stride(input: torch.Tensor, dim: int):
    if input.dim() == 2:
        if dim == 0:
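The loop-based is_power_of_two above runs in O(log n); an equivalent constant-time variant uses the standard bit trick (a sketch, not part of this commit):

    def is_power_of_two_bitwise(n: int) -> bool:
        # A positive integer is a power of two iff exactly one bit is set,
        # so clearing the lowest set bit with n & (n - 1) must leave zero.
        return n > 0 and (n & (n - 1)) == 0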

0 comments on commit 9238b14
