diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py
index 085eb39d..901f9bf5 100644
--- a/axlearn/common/flash_attention/gpu_attention_test.py
+++ b/axlearn/common/flash_attention/gpu_attention_test.py
@@ -41,7 +41,7 @@
 )
 @pytest.mark.parametrize("block_size", [64, 128])
 @pytest.mark.parametrize("causal", [True, False])
-@pytest.mark.parametrize("sm_scale", [1.0, 0.123])
+@pytest.mark.parametrize("softmax_scale", [1.0, 0.123])
 @pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"])
 @pytest.mark.parametrize("use_segment_ids", [True, False])
 @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
@@ -52,7 +52,7 @@ def test_triton_fwd_only_against_ref(
     per_head_dim: int,
     block_size: int,
     causal: bool,
-    sm_scale: float,
+    softmax_scale: float,
     attention_bias_type: Literal["2d", "4d", None],
     use_segment_ids: bool,
     input_dtype: jnp.dtype,
@@ -82,13 +82,13 @@ def impl(q, k, v, bias, segment_ids):
             block_q=block_size,
             block_k=block_size,
             causal=causal,
-            softmax_scale=sm_scale,
+            softmax_scale=softmax_scale,
         )
         out, _ = jax.vjp(fn, q, k, v, bias, segment_ids)
         return out
 
     o = impl(q, k, v, bias, segment_ids)
-    o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
+    o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale)
     chex.assert_trees_all_close(o, o_ref, atol=0.07)
 
 
@@ -147,7 +147,7 @@ def test_triton_against_xla_ref(
         jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None
     )
 
-    sm_scale = q.shape[-1] ** -0.5
+    softmax_scale = q.shape[-1] ** -0.5
 
     # Compare outputs.
     jax_out = flash_attention(
@@ -157,11 +157,13 @@
         bias,
         segment_ids,
         causal=causal,
-        softmax_scale=sm_scale,
+        softmax_scale=softmax_scale,
         block_q=block_size,
         block_k=block_size,
     )
-    jax_ref_out = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
+    jax_ref_out = mha_reference(
+        q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale
+    )
     if input_dtype == jnp.float16:
         chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.005)
     elif input_dtype == jnp.float32:
@@ -177,14 +179,14 @@ def fn(q, k, v, bias, segment_ids):
             bias,
             segment_ids,
             causal=causal,
-            softmax_scale=sm_scale,
+            softmax_scale=softmax_scale,
             block_q=block_size,
             block_k=block_size,
         ).sum()
 
     def ref_fn(q, k, v, bias, segment_ids):
         return mha_reference(
-            q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale
+            q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale
         ).sum()
 
     # Compare gradients.
@@ -224,11 +226,13 @@ def test_cudnn_against_triton_ref(
         jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype
     )
 
-    sm_scale = q.shape[-1] ** -0.5
+    softmax_scale = q.shape[-1] ** -0.5
 
     # Compare outputs.
-    jax_out = cudnn_dot_product_attention(q, k, v, bias=None, causal=causal, softmax_scale=sm_scale)
-    jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=sm_scale)
+    jax_out = cudnn_dot_product_attention(
+        q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
+    )
+    jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
     if dtype == jnp.bfloat16:
         # We relax the atol to support bf16 in the unit test.
         chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.02, rtol=1e-5)
@@ -239,11 +243,11 @@ def test_cudnn_against_triton_ref(
 
     def fn(q, k, v):
         return cudnn_dot_product_attention(
-            q, k, v, bias=None, causal=causal, softmax_scale=sm_scale
+            q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
         ).sum()
 
     def ref_fn(q, k, v):
-        return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=sm_scale).sum()
+        return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale).sum()
 
     # Compare gradients.
     jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v)
diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index 2335498e..7b44266b 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -212,8 +212,8 @@ def _legacy_tpu_flash_attention(
         ab=bias,
         segment_ids=SegmentIds(q=segment_ids, kv=segment_ids) if segment_ids is not None else None,
         causal=causal,
-        # If sm_scale==1.0, the kernel skips applying it.
-        sm_scale=1.0,
+        # If softmax_scale==1.0, the kernel skips applying it.
+        softmax_scale=1.0,
         block_sizes=block_sizes,
         debug=False,
     )
@@ -332,7 +332,7 @@ def _tpu_splash_attention(
     jax.jit,
     static_argnames=[
         "causal",
-        "sm_scale",
+        "softmax_scale",
        "block_sizes",
         "debug",
     ],
@@ -345,7 +345,7 @@ def pallas_tpu_flash_attention(
     segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
     *,
     causal: bool = False,
-    sm_scale: float = 1.0,
+    softmax_scale: float = 1.0,
     block_sizes: Optional[LegacyBlockSizes] = None,
     debug: bool = False,
 ):
@@ -396,7 +396,9 @@ def pallas_tpu_flash_attention(
         block_sizes = LegacyBlockSizes.get_default(
             batch_size, num_heads, q_seq_len, kv_seq_len, d_model
         )
-    return _flash_attention(q, k, v, ab, segment_ids, False, causal, sm_scale, block_sizes, debug)
+    return _flash_attention(
+        q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug
+    )
 
 
 @functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
@@ -408,7 +410,7 @@ def _flash_attention(
     segment_ids,
     save_residuals,
     causal,
-    sm_scale,
+    softmax_scale,
     block_sizes,
     debug,
 ):
@@ -420,7 +422,7 @@ def _flash_attention(
         segment_ids,
         save_residuals,
         causal,
-        sm_scale,
+        softmax_scale,
         block_sizes.block_b,
         block_sizes.block_q,
         block_sizes.block_k_major,
@@ -437,20 +439,22 @@ def _flash_attention_fwd(
     segment_ids,
     save_residuals,
     causal,
-    sm_scale,
+    softmax_scale,
     block_sizes,
     debug,
 ):
     if save_residuals:
         raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(q, k, v, ab, segment_ids, True, causal, sm_scale, block_sizes, debug)
+    o, l, m = _flash_attention(
+        q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug
+    )
     return o, (q, k, v, ab, segment_ids, o, l, m)
 
 
 def _flash_attention_bwd(
     save_residuals: bool,
     causal: bool,
-    sm_scale: float,
+    softmax_scale: float,
     block_sizes: LegacyBlockSizes,
     debug: bool,
     residuals,
@@ -483,7 +487,7 @@ def _flash_attention_bwd(
         block_k_major=block_sizes.block_k_major_dkv,
         block_k=block_sizes.block_k_dkv,
         block_q=block_sizes.block_q_dkv,
-        sm_scale=sm_scale,
+        softmax_scale=softmax_scale,
         causal=causal,
         mask_value=DEFAULT_MASK_VALUE,
         debug=debug,
@@ -502,7 +506,7 @@ def _flash_attention_bwd(
         block_q_major=block_sizes.block_q_dq,
         block_k_major=block_sizes.block_k_major_dq,
         block_k=block_sizes.block_k_dq,
-        sm_scale=sm_scale,
+        softmax_scale=softmax_scale,
         causal=causal,
         mask_value=DEFAULT_MASK_VALUE,
         debug=debug,
@@ -521,7 +525,7 @@ def _flash_attention_impl(
     segment_ids,
     save_residuals,
     causal,
-    sm_scale,
+    softmax_scale,
     block_b,
     block_q,
     block_k_major,
@@ -590,7 +594,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
         _flash_attention_kernel,
         causal=causal,
         mask_value=DEFAULT_MASK_VALUE,
-        sm_scale=sm_scale,
+        softmax_scale=softmax_scale,
         block_k=block_k,
         kv_seq_len=kv_seq_len,
     )
@@ -722,7 +726,7 @@ def _flash_attention_bwd_dkv(
     block_q: Optional[int],
     block_k_major: Optional[int],
     block_k: Optional[int],
-    sm_scale: float,
+    softmax_scale: float,
     causal: bool = False,
     mask_value: float = DEFAULT_MASK_VALUE,
     debug: bool = False,
@@ -874,7 +878,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
         _flash_attention_dkv_kernel,
         block_q=block_q,
         block_k=block_k,
-        sm_scale=sm_scale,
+        softmax_scale=softmax_scale,
         causal=causal,
         mask_value=mask_value,
         q_seq_len=q_seq_len,
@@ -922,7 +926,7 @@ def _flash_attention_bwd_dq(
     block_q_major: Optional[int],
     block_k_major: Optional[int],
     block_k: Optional[int],
-    sm_scale: float,
+    softmax_scale: float,
     causal: bool,
     mask_value: float,
     debug: bool,
@@ -1064,7 +1068,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
 
     kernel = functools.partial(
         _flash_attention_dq_kernel,
-        sm_scale=sm_scale,
+        softmax_scale=softmax_scale,
         causal=causal,
         mask_value=mask_value,
         block_k=block_k,