diff --git a/horde_worker_regen/amd_go_fast/amd_go_fast.py b/horde_worker_regen/amd_go_fast/amd_go_fast.py
index 1d2404d7..ba8190ca 100644
--- a/horde_worker_regen/amd_go_fast/amd_go_fast.py
+++ b/horde_worker_regen/amd_go_fast/amd_go_fast.py
@@ -5,15 +5,15 @@
 def _patch_sdpa(
-    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
+    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None, bool], Tensor],
 ):
-    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
+    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False)"""
     torch_sdpa = torch.nn.functional.scaled_dot_product_attention

-    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
         try:
-            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
+            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
         except Exception:
             hidden_states = torch_sdpa(
                 query=query,
@@ -23,6 +23,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
                 dropout_p=dropout_p,
                 is_causal=is_causal,
                 scale=scale,
+                enable_gqa=enable_gqa,
             )
             return hidden_states

@@ -32,7 +33,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
     try:
         from flash_attn import flash_attn_func

-        def sdpa_hijack_flash(q, k, v, m, p, c, s):
+        def sdpa_hijack_flash(q, k, v, m, p, c, s, g):
             assert m is None
             result = flash_attn_func(
                 q=q.transpose(1, 2),