Add back has_contextual_seq_len, add TritonCC support for UC mask
Summary:
X-link: facebookresearch/generative-recommenders#168

Without `has_contextual_seq_len`, we would have to list every possible `contextual_seq_len` value in the TritonCC named spec. That would be painful, because `contextual_seq_len` can change over time and also varies across product surfaces.
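
The reasoning above can be illustrated with a toy model of kernel specialization. This is only a sketch: it does not use the TritonCC API, and the helper names `specs_with_constexpr_value` and `specs_with_flag` as well as the sample lengths are hypothetical. It assumes each distinct compile-time value needs its own spec entry, while a runtime argument needs none.

```python
from typing import Iterable, Set, Tuple


def specs_with_constexpr_value(lengths: Iterable[int]) -> Set[Tuple[str, int]]:
    # Assumption: a compile-time constant needs one spec entry per distinct value.
    return {("CONTEXTUAL_SEQ_LEN", n) for n in lengths}


def specs_with_flag(lengths: Iterable[int]) -> Set[Tuple[str, bool]]:
    # With a boolean flag, only "feature on" and "feature off" are specialized;
    # the actual length is passed as a runtime argument.
    return {("HAS_CONTEXTUAL_SEQ_LEN", n > 0) for n in lengths}


# Hypothetical lengths used by different product surfaces.
observed_lengths = [0, 4, 8, 16, 32]

value_specs = specs_with_constexpr_value(observed_lengths)
flag_specs = specs_with_flag(observed_lengths)

print(len(value_specs), len(flag_specs))
```

Under these assumptions, the number of value-based specs grows with every new length, while the flag-based variant stays at two entries at most.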

Reviewed By: hanli0612, LinjianMa

Differential Revision: D66990536

fbshipit-source-id: 750617c49b185e21fb371a544ca93aaf3bf0a381
ruochen99 authored and facebook-github-bot committed Dec 16, 2024
1 parent 172eace commit e7074bf
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -150,7 +150,8 @@ def forward(
"BLOCK_D_Q": DimQ,
"BLOCK_D_V": DimV,
"MAX_ATTN_LEN": None,
- "CONTEXTUAL_SEQ_LEN": 0,
+ "HAS_CONTEXTUAL_SEQ_LEN": False,
+ "contextual_seq_len": 0,
"HAS_SORT_BY_LENGTH_INDICES": False,
"sort_by_length_indices": None,
}
@@ -180,7 +181,7 @@ def forward(
kwargs["num_targets"],
kwargs["ATTN_BIAS_TYPE"], # relative_bias_type
kwargs["MAX_ATTN_LEN"], # max_attn_len
- kwargs["CONTEXTUAL_SEQ_LEN"], # contextual_seq_len
+ kwargs["contextual_seq_len"], # contextual_seq_len
self.sort_by_length,
)
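
The diff hard-codes the flag to `False` alongside a length of `0`. As a minimal sketch of how the two entries could be kept consistent when a caller supplies a non-zero length, the helper below derives the flag from the value. The function `with_contextual_seq_len` is hypothetical and not part of `hstu.py`; only the two dictionary keys come from the diff.

```python
from typing import Any, Dict


def with_contextual_seq_len(
    kwargs: Dict[str, Any], contextual_seq_len: int
) -> Dict[str, Any]:
    """Return a copy of kwargs with the flag derived from the runtime length."""
    if contextual_seq_len < 0:
        raise ValueError("contextual_seq_len must be non-negative")
    updated = dict(kwargs)  # shallow copy so the caller's dict is untouched
    updated["contextual_seq_len"] = contextual_seq_len
    updated["HAS_CONTEXTUAL_SEQ_LEN"] = contextual_seq_len > 0
    return updated


# Defaults mirroring the two keys added in the diff.
base_kwargs = {"HAS_CONTEXTUAL_SEQ_LEN": False, "contextual_seq_len": 0}
enabled_kwargs = with_contextual_seq_len(base_kwargs, 4)
```

Deriving the flag in one place avoids a mismatch such as a non-zero length paired with a disabled flag.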
