Skip to content

Commit

Permalink
Rebase and fix lints
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 20, 2024
1 parent 390699f commit f54c49d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tritonbench/operators/flash_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ def colfax_cutlass(self, q, k, v):
@register_benchmark(enabled=bool(tk_fwd is not None))
def tk(self, q, k, v):
    """ThunderKittens flash-attention forward benchmark entry.

    Builds the output buffers and returns a zero-argument dispatcher that
    runs the tk_fwd attention kernel; the benchmark harness times the
    returned callable.

    Args:
        q, k, v: query/key/value tensors (shapes/dtypes defined by the
            harness — not visible here).

    Returns:
        A callable that invokes ``tk_fwd.attention_forward`` and returns
        the output tensor ``o``.
    """
    o = torch.zeros_like(v)
    # Named "l_tensor" rather than "l": single-letter "l" is flagged as an
    # ambiguous variable name by linters (E741).
    l_tensor = torch.zeros_like(o).to(torch.float32)

    def tk_dispatcher():
        # attention_forward appears to write its results into o / l_tensor
        # in place — assumed from the fill-buffers-then-return pattern;
        # TODO confirm against the tk_fwd API.
        tk_fwd.attention_forward(q, k, v, o, l_tensor, causal=self.causal)
        return o

    return tk_dispatcher
Expand Down

0 comments on commit f54c49d

Please sign in to comment.