Compatibility Breakage Due to Flash Attention API Update (v2.7.0) #56

TerryXhx opened this issue Nov 14, 2024 · 0 comments

After the recent Flash Attention update to version 2.7.0, in particular commit 83e41b3, compatibility issues arise when it is used with PyTorch 2.4.1, due to changes to the inputs, the outputs, and the custom operator wrapper.

Details

Input Changes

  • A new argument softcap has been introduced.
  • window_size has been split into two parameters: window_size_left and window_size_right.

Output Changes

  • _flash_attn_forward now returns 4 values instead of the previous 8; an adapted call is sketched below.
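
For reference, here is a minimal sketch of how the call site in ring_flash_attn (shown in the next section) can be adapted to the 2.7.0 signature. The keyword names come from the changes listed above and from the original call; the softcap value of 0.0 (assumed to keep the pre-2.7.0 behaviour, i.e. no capping) and the meaning of the two discarded return values are assumptions to verify against flash_attn/flash_attn_interface.py in the installed version:

    from flash_attn.flash_attn_interface import _flash_attn_forward

    # flash-attn >= 2.7.0: window_size is split into left/right, softcap is a
    # new argument, and the forward returns 4 values instead of 8.
    block_out, block_lse, _, _ = _flash_attn_forward(
        q=q,
        k=k,
        v=v,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal and step == 0,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=0.0,  # assumed: 0.0 disables soft-capping
        alibi_slopes=alibi_slopes,
        return_softmax=dropout_p > 0,
    )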

Custom Operator Update

_flash_attn_forward is now wrapped as a torch custom op (CustomOpDef), and the existing call in ring_flash_attn fails with a TypeError:

            params = get_default_args(_flash_attn_forward).copy()
            params.update(
                {
                    "q": q,
                    "k": k,
                    "v": v,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": causal and step == 0,
                    "window_size": window_size,
                    "alibi_slopes": alibi_slopes,
                    "return_softmax": True and dropout_p > 0,
                }
            )
            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)

This fails at runtime with:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/users/xhaoxuan/ring-flash-attention/test/test_ring_flash_attn_func.py", line 60, in <module>
[rank1]:     ring_out, ring_lse, _ = fn(
[rank1]:   File "/users/xhaoxuan/miniconda3/envs/ring-ulysses/lib/python3.10/site-packages/ring_flash_attn/ring_flash_attn.py", line 224, in ring_flash_attn_qkvpacked_func
[rank1]:     return RingFlashAttnFunc.apply(
[rank1]:   File "/users/xhaoxuan/miniconda3/envs/ring-ulysses/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/users/xhaoxuan/miniconda3/envs/ring-ulysses/lib/python3.10/site-packages/ring_flash_attn/ring_flash_attn.py", line 169, in forward
[rank1]:     out, softmax_lse = ring_flash_attn_forward(
[rank1]:   File "/users/xhaoxuan/miniconda3/envs/ring-ulysses/lib/python3.10/site-packages/ring_flash_attn/ring_flash_attn.py", line 49, in ring_flash_attn_forward
[rank1]:     block_out, block_lse, _, _ = _flash_attn_forward(**params)
[rank1]: TypeError: CustomOpDef.__call__() got multiple values for argument 'self'

Temporary Fix

  1. The input and output changes above were straightforward to adjust.
  2. For the custom op wrapper, I applied this fix:
            if torch.__version__ >= "2.4.0":
                params = {}
            else:
                params = get_default_args(_flash_attn_forward).copy()
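
A side note on the version check itself: in recent PyTorch, torch.__version__ is a TorchVersion object whose comparisons are version-aware, so the check above works, but an explicit parse keeps the intent unambiguous and also behaves correctly on plain version strings. A sketch, assuming the packaging module is available and that 2.4.0 really is the right cutoff:

    import torch
    from packaging import version

    # Compare the (major, minor, micro) release tuple so that build/local
    # suffixes such as "+cu121" do not affect the comparison.
    if version.parse(torch.__version__).release >= (2, 4, 0):
        # With torch >= 2.4, flash-attn 2.7.0 wraps _flash_attn_forward as a
        # custom op (CustomOpDef), so introspecting its defaults breaks;
        # start from an empty dict and pass everything explicitly.
        params = {}
    else:
        params = get_default_args(_flash_attn_forward).copy()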

Remaining Issue

After applying these fixes, the tests pass with acceptable numerical error; however, the following warnings appear during the backward pass:

##############################
# forward:
##############################
out: max 2.94, mean 0.041
lse: max 8.98, mean 7.75
out diff:
[0] max 0.00391, mean 7.77e-05
[1] max 0.00195, mean 8.44e-05
[2] max 0.000977, mean 6.77e-05
[3] max 0.000977, mean 5.77e-05
lse diff:
[0] max 9.54e-07, mean 1.37e-07
[1] max 1.43e-06, mean 2.19e-07
[2] max 1.91e-06, mean 2.45e-07
[3] max 1.91e-06, mean 2.9e-07
##############################
# backward:
##############################
/users/xhaoxuan/miniconda3/envs/ring-ulysses/lib/python3.10/site-packages/torch/autograd/graph.py:769: UserWarning: c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1724789115564/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
 return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[the same warning is printed once by each of the 4 ranks]
local_dqkv:
[0] max 4.47, mean 0.0747
[1] max 0.516, mean 0.0299
[2] max 0.217, mean 0.0199
[3] max 0.297, mean 0.0126
dq diff:
[0] max 0.0156, mean 0.00066
[1] max 0.000977, mean 5.56e-05
[2] max 0.000488, mean 4.46e-05
[3] max 0.000488, mean 1.96e-05
dk diff:
[0] max 0.0156, mean 0.00058
[1] max 0.000977, mean 6.58e-05
[2] max 0.000488, mean 4.01e-05
[3] max 0.000488, mean 1.74e-05
dv diff:
[0] max 0.0156, mean 0.000576
[1] max 0.000977, mean 6.72e-05
[2] max 0.000488, mean 3.65e-05
[3] max 0.000488, mean 1.73e-05
[rank0]:[W1114 20:14:11.788433858 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

I reran the tests with Flash Attention 2.6.3 and these warnings still appear, so they are most likely caused by torch 2.4.1 itself rather than by the flash-attn update.
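
For anyone who just wants to keep the c10d::broadcast_ warning out of their logs, one user-level workaround (not a fix from either project; the helper name below is hypothetical) is to keep the collective out of the autograd graph by broadcasting a detached buffer:

    import torch
    import torch.distributed as dist

    def broadcast_no_grad(tensor: torch.Tensor, src: int, group=None) -> torch.Tensor:
        # Broadcast a detached copy so c10d::broadcast_ never appears in the
        # autograd graph, which is what triggers the UserWarning above.
        with torch.no_grad():
            buf = tensor.detach().clone()
            dist.broadcast(buf, src=src, group=group)
        return buf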
