Skip to content

Commit

Permalink
AOTriton 0.7.1 compile fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pruthvistony committed Oct 10, 2024
1 parent a1e8b0e commit 7ac294f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ _efficient_attention_backward(
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::cast_dtype;
using sdp::aotriton_adapter::mk_aoscalartensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::cast_dtype;
using sdp::aotriton_adapter::mk_aoscalartensor;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
Expand Down

0 comments on commit 7ac294f

Please sign in to comment.