Fix RLHF slowdown in attention multi steps extend_step.
jax.lax.dynamic_update_slice_in_dim is generally faster than advanced indexing,
but it caused an unusual slowdown here, with RLHF sampling taking up to 3 hours
per run. This commit switches extend_step to advanced indexing; the root cause
of the slowdown still needs to be investigated.
https://a1350286.slack.com/archives/C03HJAYC7JA/p1731998432387409?thread_ts=1731968765.840839&cid=C03HJAYC7JA

For your information, in
https://github.pie.apple.com/foundation-models/axlearn/pull/894, I experimented
with both dynamic_update_slice and advanced indexing on TPUv4 and chose the
faster option. It's also known that dynamic_update_slice performs better when
copying contiguous memory, so the slowdown seen in RLHF is very surprising.

Advanced Indexing
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         7.16 ms        0.623 ms          492
QkvLinearExtendStepBenchmark/2048/16/4096/1         8.52 ms        0.624 ms          561
QkvLinearExtendStepBenchmark/2048/16/32768/1        34.6 ms         1.64 ms           78
QkvLinearExtendStepBenchmark/2048/16/4096/8         63.6 ms         1.74 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/64         276 ms         2.40 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/512       2541 ms         81.6 ms            1

dynamic_update_slice
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         1.70 ms        0.513 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/1         3.40 ms        0.519 ms         1174
QkvLinearExtendStepBenchmark/2048/16/32768/1        20.1 ms        0.930 ms          404
QkvLinearExtendStepBenchmark/2048/16/4096/8         3.68 ms        0.524 ms         1139
QkvLinearExtendStepBenchmark/2048/16/4096/64        3.74 ms        0.532 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/512       2530 ms         80.4 ms            1
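
The numbers above come from the author's own benchmark, which is not shown on this page. As a rough illustration only, wall-clock timing of a jitted JAX function can be sketched as below. The helper `time_jitted` and the toy function `write_row` are hypothetical names, not part of axlearn, and the sketch does not reproduce the QkvLinearExtendStepBenchmark setup.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, iters=10):
    """Returns mean wall-clock seconds per call of jit(fn), excluding compilation."""
    jitted = jax.jit(fn)
    # Warm-up call so that compilation time is not measured.
    jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously, so block until the result is ready.
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / iters


# Hypothetical toy workload: write one row into a small 2-D cache.
def write_row(cache, row, idx):
    return jax.lax.dynamic_update_slice_in_dim(cache, row, idx, axis=0)


cache = jnp.zeros((128, 16))
row = jnp.ones((1, 16))
seconds = time_jitted(write_row, cache, row, jnp.asarray(3))
print(f"{seconds * 1e3:.3f} ms per call")
```

A micro-benchmark like this measures one op in isolation. As the commit message notes, the slowdown appeared in full RLHF sampling runs, so isolated timings may not predict end-to-end behavior.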
ds-hwang committed Nov 19, 2024
1 parent 420ed7a commit f7ea994
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions axlearn/common/attention.py
@@ -886,16 +886,18 @@ def extend_step(
 k_proj = k_proj.astype(cached_key.dtype)
 v_proj = v_proj.astype(cached_value.dtype)

-# Function to update the cache for a single batch element.
-def update_single(cached_kv_slice, kv_proj_slice, time_idx):
-    return jax.lax.dynamic_update_slice_in_dim(
-        cached_kv_slice, kv_proj_slice, time_idx, axis=0
-    )
+# TODO(dhwang2): jax.lax.dynamic_update_slice_in_dim is generally faster than advanced
+# indexing, but an unusual slowdown was observed, with RLHF sampling taking up to
+# 3 hours per run. Investigate and fix it.
+# Note: All X_idx are small, so generating them on-demand is not costly.
+b, _, n, h = cached_key.shape
+b_idx = jnp.arange(b)[:, None, None, None]
+t_idx = (jnp.arange(k_proj.shape[1])[None] + time_step[:, None])[:, :, None, None]
+n_idx = jnp.arange(n)[None, None, :, None]
+h_idx = jnp.arange(h)[None, None, None, :]
+k_proj = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_proj)
+v_proj = cached_value.at[b_idx, t_idx, n_idx, h_idx].set(v_proj)

-# Use jax.vmap to vectorize over the batch dimension.
-vmap_update = jax.vmap(update_single)
-k_proj = vmap_update(cached_key, k_proj, time_step)
-v_proj = vmap_update(cached_value, v_proj, time_step)
 updated_state.update(key=k_proj, value=v_proj)
 return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
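
To make the change easier to follow, here is a minimal standalone sketch of the two cache-update strategies from this diff, checked against each other on toy data. The wrapper names `update_with_dynamic_slice` and `update_with_advanced_indexing` and the toy shapes are assumptions for illustration; the bodies mirror the removed and added lines, with the cache laid out as [batch, time, heads, per-head dim].

```python
import jax
import jax.numpy as jnp


def update_with_dynamic_slice(cached_kv, kv_proj, time_step):
    """Removed approach: vmap over batch, one dynamic_update_slice per example."""

    def update_single(cached_kv_slice, kv_proj_slice, time_idx):
        return jax.lax.dynamic_update_slice_in_dim(
            cached_kv_slice, kv_proj_slice, time_idx, axis=0
        )

    return jax.vmap(update_single)(cached_kv, kv_proj, time_step)


def update_with_advanced_indexing(cached_kv, kv_proj, time_step):
    """Added approach: a single indexed update with broadcast index arrays."""
    b, _, n, h = cached_kv.shape
    b_idx = jnp.arange(b)[:, None, None, None]
    # Per-example time offsets: [B, S] expanded to [B, S, 1, 1].
    t_idx = (jnp.arange(kv_proj.shape[1])[None] + time_step[:, None])[:, :, None, None]
    n_idx = jnp.arange(n)[None, None, :, None]
    h_idx = jnp.arange(h)[None, None, None, :]
    return cached_kv.at[b_idx, t_idx, n_idx, h_idx].set(kv_proj)


# Toy shapes (assumed): batch=2, cache length=8, heads=4, per-head dim=3, 2 new steps.
cache = jnp.zeros((2, 8, 4, 3))
proj = jax.random.normal(jax.random.PRNGKey(0), (2, 2, 4, 3))
time_step = jnp.array([1, 5])

out_slice = update_with_dynamic_slice(cache, proj, time_step)
out_index = update_with_advanced_indexing(cache, proj, time_step)
assert jnp.array_equal(out_slice, out_index)
```

The two agree whenever every write stays inside the cache, as in this example. They can differ at the boundary: dynamic_update_slice clamps the start index so the update fits, whereas indexed updates skip out-of-bounds positions by default.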
