diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index abb29955..fc5ba43e 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -377,10 +377,10 @@ def colfax_cutlass(self, q, k, v):
     @register_benchmark(enabled=bool(tk_fwd is not None))
     def tk(self, q, k, v):
         o = torch.zeros_like(v)
-        l = torch.zeros_like(o).to(torch.float32)
+        l_tensor = torch.zeros_like(o).to(torch.float32)
 
         def tk_dispatcher():
-            tk_fwd.attention_forward(q, k, v, o, l, causal=self.causal)
+            tk_fwd.attention_forward(q, k, v, o, l_tensor, causal=self.causal)
             return o
 
         return tk_dispatcher