diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py index 327c283..a3c1608 100644 --- a/tritonbench/operators/ragged_attention/hstu.py +++ b/tritonbench/operators/ragged_attention/hstu.py @@ -150,7 +150,8 @@ def forward( "BLOCK_D_Q": DimQ, "BLOCK_D_V": DimV, "MAX_ATTN_LEN": None, - "CONTEXTUAL_SEQ_LEN": 0, + "HAS_CONTEXTUAL_SEQ_LEN": False, + "contextual_seq_len": 0, "HAS_SORT_BY_LENGTH_INDICES": False, "sort_by_length_indices": None, } @@ -180,7 +181,7 @@ def forward( kwargs["num_targets"], kwargs["ATTN_BIAS_TYPE"], # relative_bias_type kwargs["MAX_ATTN_LEN"], # max_attn_len - kwargs["CONTEXTUAL_SEQ_LEN"], # contextual_seq_len + kwargs["contextual_seq_len"], # contextual_seq_len self.sort_by_length, )