Renames sm_scale to softmax_scale for consistency. (apple#894)
* Renames `sm_scale` to `softmax_scale` for consistency.

* black
ruomingp authored Dec 17, 2024
1 parent a7e2a95 commit a15a3bc
Showing 2 changed files with 40 additions and 32 deletions.
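
For context, a minimal sketch of how a caller passes the renamed keyword (not part of the commit). The import path, tensor layout, and dtype are assumptions inferred from the GPU test code in this diff, so treat this as illustrative rather than the project's documented API:

import jax
import jax.numpy as jnp

# Assumed import path; the commit itself only touches the test and TPU kernel files.
from axlearn.common.flash_attention.gpu_attention import flash_attention

batch_size, seq_len, num_heads, per_head_dim = 2, 128, 2, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (batch_size, seq_len, num_heads, per_head_dim)  # layout used by the GPU tests below
q = jax.random.normal(k1, shape, dtype=jnp.float16)
k = jax.random.normal(k2, shape, dtype=jnp.float16)
v = jax.random.normal(k3, shape, dtype=jnp.float16)

# The scaling keyword is now uniformly `softmax_scale` (previously `sm_scale` in the
# TPU kernels and in test-local names). Requires a GPU with Triton support to run.
out = flash_attention(q, k, v, bias=None, causal=True, softmax_scale=per_head_dim**-0.5)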
32 changes: 18 additions & 14 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -41,7 +41,7 @@
)
@pytest.mark.parametrize("block_size", [64, 128])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("sm_scale", [1.0, 0.123])
@pytest.mark.parametrize("softmax_scale", [1.0, 0.123])
@pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"])
@pytest.mark.parametrize("use_segment_ids", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
@@ -52,7 +52,7 @@ def test_triton_fwd_only_against_ref(
per_head_dim: int,
block_size: int,
causal: bool,
sm_scale: float,
softmax_scale: float,
attention_bias_type: Literal["2d", "4d", None],
use_segment_ids: bool,
input_dtype: jnp.dtype,
@@ -82,13 +82,13 @@ def impl(q, k, v, bias, segment_ids):
block_q=block_size,
block_k=block_size,
causal=causal,
softmax_scale=sm_scale,
softmax_scale=softmax_scale,
)
out, _ = jax.vjp(fn, q, k, v, bias, segment_ids)
return out

o = impl(q, k, v, bias, segment_ids)
o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale)
chex.assert_trees_all_close(o, o_ref, atol=0.07)


@@ -147,7 +147,7 @@ def test_triton_against_xla_ref(
jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None
)

sm_scale = q.shape[-1] ** -0.5
softmax_scale = q.shape[-1] ** -0.5

# Compare outputs.
jax_out = flash_attention(
@@ -157,11 +157,13 @@
bias,
segment_ids,
causal=causal,
softmax_scale=sm_scale,
softmax_scale=softmax_scale,
block_q=block_size,
block_k=block_size,
)
jax_ref_out = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale)
jax_ref_out = mha_reference(
q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale
)
if input_dtype == jnp.float16:
chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.005)
elif input_dtype == jnp.float32:
@@ -177,14 +179,14 @@ def fn(q, k, v, bias, segment_ids):
bias,
segment_ids,
causal=causal,
softmax_scale=sm_scale,
softmax_scale=softmax_scale,
block_q=block_size,
block_k=block_size,
).sum()

def ref_fn(q, k, v, bias, segment_ids):
return mha_reference(
q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale
q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale
).sum()

# Compare gradients.
@@ -224,11 +226,13 @@ def test_cudnn_against_triton_ref(
jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype
)

sm_scale = q.shape[-1] ** -0.5
softmax_scale = q.shape[-1] ** -0.5

# Compare outputs.
jax_out = cudnn_dot_product_attention(q, k, v, bias=None, causal=causal, softmax_scale=sm_scale)
jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=sm_scale)
jax_out = cudnn_dot_product_attention(
q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
)
jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
if dtype == jnp.bfloat16:
# We relax the atol to support bf16 in the unit test.
chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.02, rtol=1e-5)
@@ -239,11 +243,11 @@ def test_cudnn_against_triton_ref(

def fn(q, k, v):
return cudnn_dot_product_attention(
q, k, v, bias=None, causal=causal, softmax_scale=sm_scale
q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
).sum()

def ref_fn(q, k, v):
return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=sm_scale).sum()
return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale).sum()

# Compare gradients.
jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v)
40 changes: 22 additions & 18 deletions axlearn/common/flash_attention/tpu_attention.py
@@ -212,8 +212,8 @@ def _legacy_tpu_flash_attention(
ab=bias,
segment_ids=SegmentIds(q=segment_ids, kv=segment_ids) if segment_ids is not None else None,
causal=causal,
# If sm_scale==1.0, the kernel skips applying it.
sm_scale=1.0,
# If softmax_scale==1.0, the kernel skips applying it.
softmax_scale=1.0,
block_sizes=block_sizes,
debug=False,
)
@@ -332,7 +332,7 @@ def _tpu_splash_attention(
jax.jit,
static_argnames=[
"causal",
"sm_scale",
"softmax_scale",
"block_sizes",
"debug",
],
@@ -345,7 +345,7 @@ def pallas_tpu_flash_attention(
segment_ids=None, # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
*,
causal: bool = False,
sm_scale: float = 1.0,
softmax_scale: float = 1.0,
block_sizes: Optional[LegacyBlockSizes] = None,
debug: bool = False,
):
@@ -396,7 +396,9 @@ def pallas_tpu_flash_attention(
block_sizes = LegacyBlockSizes.get_default(
batch_size, num_heads, q_seq_len, kv_seq_len, d_model
)
return _flash_attention(q, k, v, ab, segment_ids, False, causal, sm_scale, block_sizes, debug)
return _flash_attention(
q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug
)


@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
@@ -408,7 +410,7 @@ def _flash_attention(
segment_ids,
save_residuals,
causal,
sm_scale,
softmax_scale,
block_sizes,
debug,
):
@@ -420,7 +422,7 @@
segment_ids,
save_residuals,
causal,
sm_scale,
softmax_scale,
block_sizes.block_b,
block_sizes.block_q,
block_sizes.block_k_major,
@@ -437,20 +439,22 @@ def _flash_attention_fwd(
segment_ids,
save_residuals,
causal,
sm_scale,
softmax_scale,
block_sizes,
debug,
):
if save_residuals:
raise NotImplementedError("Higher-order AD not supported")
o, l, m = _flash_attention(q, k, v, ab, segment_ids, True, causal, sm_scale, block_sizes, debug)
o, l, m = _flash_attention(
q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug
)
return o, (q, k, v, ab, segment_ids, o, l, m)


def _flash_attention_bwd(
save_residuals: bool,
causal: bool,
sm_scale: float,
softmax_scale: float,
block_sizes: LegacyBlockSizes,
debug: bool,
residuals,
@@ -483,7 +487,7 @@ def _flash_attention_bwd(
block_k_major=block_sizes.block_k_major_dkv,
block_k=block_sizes.block_k_dkv,
block_q=block_sizes.block_q_dkv,
sm_scale=sm_scale,
softmax_scale=softmax_scale,
causal=causal,
mask_value=DEFAULT_MASK_VALUE,
debug=debug,
@@ -502,7 +506,7 @@ def _flash_attention_bwd(
block_q_major=block_sizes.block_q_dq,
block_k_major=block_sizes.block_k_major_dq,
block_k=block_sizes.block_k_dq,
sm_scale=sm_scale,
softmax_scale=softmax_scale,
causal=causal,
mask_value=DEFAULT_MASK_VALUE,
debug=debug,
@@ -521,7 +525,7 @@ def _flash_attention_impl(
segment_ids,
save_residuals,
causal,
sm_scale,
softmax_scale,
block_b,
block_q,
block_k_major,
@@ -590,7 +594,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
_flash_attention_kernel,
causal=causal,
mask_value=DEFAULT_MASK_VALUE,
sm_scale=sm_scale,
softmax_scale=softmax_scale,
block_k=block_k,
kv_seq_len=kv_seq_len,
)
@@ -722,7 +726,7 @@ def _flash_attention_bwd_dkv(
block_q: Optional[int],
block_k_major: Optional[int],
block_k: Optional[int],
sm_scale: float,
softmax_scale: float,
causal: bool = False,
mask_value: float = DEFAULT_MASK_VALUE,
debug: bool = False,
@@ -874,7 +878,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
_flash_attention_dkv_kernel,
block_q=block_q,
block_k=block_k,
sm_scale=sm_scale,
softmax_scale=softmax_scale,
causal=causal,
mask_value=mask_value,
q_seq_len=q_seq_len,
@@ -922,7 +926,7 @@ def _flash_attention_bwd_dq(
block_q_major: Optional[int],
block_k_major: Optional[int],
block_k: Optional[int],
sm_scale: float,
softmax_scale: float,
causal: bool,
mask_value: float,
debug: bool,
@@ -1064,7 +1068,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)

kernel = functools.partial(
_flash_attention_dq_kernel,
sm_scale=sm_scale,
softmax_scale=softmax_scale,
causal=causal,
mask_value=mask_value,
block_k=block_k,
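
A similar sketch for the TPU entry point whose keyword-only argument is actually renamed in this file (not part of the commit). The import path, the [batch, num_heads, seq_len, head_dim] layout, the positional None placeholders for ab and segment_ids, and running on a TPU device are all assumptions; only the parameter names come from the hunks above:

import jax
import jax.numpy as jnp

# Assumed import path for the module patched above.
from axlearn.common.flash_attention.tpu_attention import pallas_tpu_flash_attention

batch_size, num_heads, seq_len, d_model = 2, 2, 128, 128  # assumed layout: [batch, heads, seq, head_dim]
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (batch_size, num_heads, seq_len, d_model)
q = jax.random.normal(k1, shape, dtype=jnp.bfloat16)
k = jax.random.normal(k2, shape, dtype=jnp.bfloat16)
v = jax.random.normal(k3, shape, dtype=jnp.bfloat16)

# `softmax_scale` replaces the old `sm_scale` keyword. It appears in the jitted
# wrapper's `static_argnames` (see the hunk near line 332 above), so each distinct
# scale value triggers its own compilation.
out = pallas_tpu_flash_attention(
    q, k, v, None, None,  # ab (bias) and segment_ids left unset
    causal=True,
    softmax_scale=d_model**-0.5,
)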
