changing fp8 attention input logic
Summary: The inputs did not match some of the comments; this also adds some more functionality. Any thoughts?

Reviewed By: xuzhao9

Differential Revision: D66982898

fbshipit-source-id: 6d054de381de6d5749c4413712c8c17519dd202c
adamomainz authored and facebook-github-bot committed Dec 16, 2024
1 parent e7074bf commit 158b2bf
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tritonbench/operators/fp8_attention/operator.py
@@ -56,6 +56,7 @@ def __init__(
         self.embedding_dim = args.embedding_dim
         self.D_HEAD = args.d_head
         self.causal = args.causal
+        self.requires_grad = not self.tb_args.mode == "fwd_no_grad"
         self.sm_scale = 1.3

     def colfax_preprocess(self, q, k, v):
@@ -123,7 +124,6 @@ def get_input_iter(self) -> Generator:
         head_dims = [64, 128, 256]
         BATCH = self.BATCH
         D_HEAD = self.D_HEAD
-        requires_grad = True
         for N_CTX in [2**i for i in range(7, 15)]:
             self.N_CTX = N_CTX
             H = self.embedding_dim // D_HEAD
@@ -133,26 +133,27 @@ def get_input_iter(self) -> Generator:
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=torch.float16,
                 device=self.device,
-                requires_grad=True,
+                requires_grad=self.requires_grad,
             )
             k = torch.randn(
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=torch.float16,
                 device=self.device,
-                requires_grad=True,
+                requires_grad=self.requires_grad,
             )
             v = torch.randn(
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=torch.float16,
                 device=self.device,
-                requires_grad=True,
+                requires_grad=self.requires_grad,
             )
             yield (q, k, v)

     @register_metric()
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        H = self.embedding_dim // self.D_HEAD
-        flops_per_matmul = 2.0 * self.BATCH * H * self.N_CTX * self.N_CTX * self.D_HEAD
+        q, _, _ = example_inputs
+        BATCH, H, N_CTX, D_HEAD = q.shape
+        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
         return 2 * flops_per_matmul
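
For context, here is a minimal standalone sketch of the two behavior changes in this diff, not the tritonbench operator itself; the helper names make_qkv and attention_flops are hypothetical. It shows requires_grad on the generated q/k/v following the benchmark mode instead of being hard-coded to True, and the flops metric being derived from the shape of the example input rather than from operator attributes.

import torch

def make_qkv(batch, heads, n_ctx, d_head, mode="fwd", device="cpu"):
    # requires_grad is derived from the benchmark mode rather than hard-coded:
    # "fwd_no_grad" disables gradient tracking on the generated inputs.
    requires_grad = mode != "fwd_no_grad"
    shape = (batch, heads, n_ctx, d_head)
    return tuple(
        torch.randn(shape, dtype=torch.float16, device=device, requires_grad=requires_grad)
        for _ in range(3)
    )

def attention_flops(q):
    # FLOPS is computed from the actual input shape, so it stays correct for
    # every (N_CTX, D_HEAD) combination yielded by the input iterator.
    # Two matmuls per attention call (QK^T and PV), 2*B*H*N*N*D FLOPs each.
    batch, heads, n_ctx, d_head = q.shape
    flops_per_matmul = 2.0 * batch * heads * n_ctx * n_ctx * d_head
    return 2 * flops_per_matmul

q, k, v = make_qkv(4, 16, 1024, 64, mode="fwd_no_grad")
assert not q.requires_grad
print(f"{attention_flops(q):.3e} FLOPs per forward pass")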
