From 158b2bf0a61fbaef0ebbf7ce7cb876145e517baf Mon Sep 17 00:00:00 2001
From: Adam Mainz
Date: Mon, 16 Dec 2024 14:57:34 -0800
Subject: [PATCH] changing fp8 attention input logic

Summary:
Inputs did not match some of the comments; this also adds some more
functionality. Any thoughts?

Reviewed By: xuzhao9

Differential Revision: D66982898

fbshipit-source-id: 6d054de381de6d5749c4413712c8c17519dd202c
---
 tritonbench/operators/fp8_attention/operator.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py
index 1426465..af31e67 100644
--- a/tritonbench/operators/fp8_attention/operator.py
+++ b/tritonbench/operators/fp8_attention/operator.py
@@ -56,6 +56,7 @@ def __init__(
         self.embedding_dim = args.embedding_dim
         self.D_HEAD = args.d_head
         self.causal = args.causal
+        self.requires_grad = not self.tb_args.mode == "fwd_no_grad"
         self.sm_scale = 1.3

     def colfax_preprocess(self, q, k, v):
@@ -123,7 +124,6 @@ def get_input_iter(self) -> Generator:
         head_dims = [64, 128, 256]
         BATCH = self.BATCH
         D_HEAD = self.D_HEAD
-        requires_grad = True
         for N_CTX in [2**i for i in range(7, 15)]:
             self.N_CTX = N_CTX
             H = self.embedding_dim // D_HEAD
@@ -133,19 +133,19 @@ def get_input_iter(self) -> Generator:
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=torch.float16,
                 device=self.device,
-                requires_grad=True,
+                requires_grad=self.requires_grad,
             )
             k = torch.randn(
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=torch.float16,
                 device=self.device,
-                requires_grad=True,
+                requires_grad=self.requires_grad,
             )
             v = torch.randn(
                 (BATCH, H, N_CTX, D_HEAD),
                 dtype=torch.float16,
                 device=self.device,
-                requires_grad=True,
+                requires_grad=self.requires_grad,
             )
             yield (q, k, v)

@@ -153,6 +153,7 @@ def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        H = self.embedding_dim // self.D_HEAD
-        flops_per_matmul = 2.0 * self.BATCH * H * self.N_CTX * self.N_CTX * self.D_HEAD
+        q, _, _ = example_inputs
+        BATCH, H, N_CTX, D_HEAD = q.shape
+        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
         return 2 * flops_per_matmul
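
Note for reviewers (not part of the commit): a minimal standalone sketch of
the two pieces this patch changes, runnable outside tritonbench. The variable
`mode` below is an illustrative stand-in for `self.tb_args.mode`, and
`attention_fwd_flops` is a hypothetical helper, not a tritonbench API.

    import torch

    # Gate requires_grad on the benchmark mode, as the new __init__ line does:
    # gradients are only needed when a mode other than "fwd_no_grad" will
    # exercise the backward pass.
    mode = "fwd_no_grad"  # stand-in for self.tb_args.mode
    requires_grad = not mode == "fwd_no_grad"  # equivalent to: mode != "fwd_no_grad"

    # One input tensor shaped like the ones get_input_iter yields.
    q = torch.randn(
        (4, 16, 128, 64),  # (BATCH, H, N_CTX, D_HEAD)
        dtype=torch.float16,
        requires_grad=requires_grad,
    )

    # The flops count: an attention forward pass is two matmuls, Q @ K^T and
    # P @ V, each costing 2 * BATCH * H * N_CTX * N_CTX * D_HEAD flops.
    # Reading the shape off the live tensor, as the patched flops() now does,
    # keeps the count in sync with the inputs actually benchmarked, rather
    # than relying on self.N_CTX, which get_input_iter mutates while sweeping
    # sequence lengths.
    def attention_fwd_flops(q: torch.Tensor) -> float:
        BATCH, H, N_CTX, D_HEAD = q.shape
        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
        return 2 * flops_per_matmul

    print(attention_fwd_flops(q))  # 4 * 4 * 16 * 128 * 128 * 64 = 268435456.0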