Commit
Fix RLHF slowdown in attention multi steps extend_step. (apple#849)
jax.lax.dynamic_update_slice_in_dim is generally faster than advanced indexing, but an unusual slowdown was observed, with RLHF sampling taking up to 3 hours per run. Investigate and fix it.
https://a1350286.slack.com/archives/C03HJAYC7JA/p1731998432387409?thread_ts=1731968765.840839&cid=C03HJAYC7JA

For your information, in https://github.pie.apple.com/foundation-models/axlearn/pull/894, I experimented with both dynamic_update_slice and advanced indexing on TPUv4 and chose the faster option. It's also known that dynamic_update_slice performs better when copying contiguous memory. This is a very surprising case.

Advanced Indexing
----------------------------------------------------------------------------------------
Benchmark                                                 Time           CPU  Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1            7.16 ms      0.623 ms         492
QkvLinearExtendStepBenchmark/2048/16/4096/1            8.52 ms      0.624 ms         561
QkvLinearExtendStepBenchmark/2048/16/32768/1           34.6 ms       1.64 ms          78
QkvLinearExtendStepBenchmark/2048/16/4096/8            63.6 ms       1.74 ms          81
QkvLinearExtendStepBenchmark/2048/16/4096/64            276 ms       2.40 ms          81
QkvLinearExtendStepBenchmark/2048/16/4096/512          2541 ms       81.6 ms           1

dynamic_update_slice
----------------------------------------------------------------------------------------
Benchmark                                                 Time           CPU  Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1            1.70 ms      0.513 ms        1125
QkvLinearExtendStepBenchmark/2048/16/4096/1            3.40 ms      0.519 ms        1174
QkvLinearExtendStepBenchmark/2048/16/32768/1           20.1 ms      0.930 ms         404
QkvLinearExtendStepBenchmark/2048/16/4096/8            3.68 ms      0.524 ms        1139
QkvLinearExtendStepBenchmark/2048/16/4096/64           3.74 ms      0.532 ms        1125
QkvLinearExtendStepBenchmark/2048/16/4096/512          2530 ms       80.4 ms           1
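
For readers unfamiliar with the two update strategies being compared, the following is a minimal sketch of each, not the code from this commit. The function names, the [batch, max_len, dim] cache layout, and the tiny shapes are all assumptions made for illustration; the actual extend_step in axlearn may differ.

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_with_dynamic_slice(cache, new, time_step):
    # Hypothetical helper. cache: [batch, max_len, dim], new: [batch, steps, dim].
    # time_step is one scalar shared by the whole batch, so a single
    # contiguous block along axis=1 is overwritten.
    return jax.lax.dynamic_update_slice_in_dim(cache, new, time_step, axis=1)


@jax.jit
def update_with_advanced_indexing(cache, new, time_step):
    # Hypothetical helper. time_step: [batch] integers, so every batch row
    # may write at a different offset; this lowers to a scatter.
    batch, steps, _ = new.shape
    batch_idx = jnp.arange(batch)[:, None]                    # [batch, 1]
    time_idx = time_step[:, None] + jnp.arange(steps)[None]   # [batch, steps]
    return cache.at[batch_idx, time_idx].set(new)


# Illustrative shapes only; they are not the benchmark's parameters.
cache = jnp.zeros((2, 8, 4))
new = jnp.ones((2, 3, 4))

out_slice = update_with_dynamic_slice(cache, new, jnp.asarray(2))
out_index = update_with_advanced_indexing(cache, new, jnp.array([2, 2]))
```

When every row writes at the same offset, as here, the two helpers produce the same result. The advanced-indexing form is the more general one, since it accepts a different offset per batch row. Which is faster depends on shapes and hardware, which is what the benchmark above measures.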