After the recent Flash Attention update to version 2.7.0 (specifically commit 83e41b3), compatibility issues arise with PyTorch 2.4.1 due to changes in the inputs, the outputs, and the custom operator wrapper.
Details
Input Changes
A new argument softcap has been introduced.
window_size has been split into two parameters: window_size_left and window_size_right.
Output Changes
_flash_attn_forward now returns 4 values instead of the previous 8.
Custom Operator Update
Flash Attention now uses torch.library.custom_op to wrap its custom operations when the torch version is higher than 2.4.0, which causes an error at the existing call sites.
Temporary Fix
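A minimal version-gated wrapper along these lines addresses the input and output changes; this is a sketch based only on the changes listed above, with the helper name and exact argument order as assumptions to verify against the installed flash-attn:

```python
# Sketch of a version-gated wrapper for the interface changes above.
# flash_attn_forward_compat and the fallback positional order are
# assumptions, not code from flash-attn or this repo; check the keyword
# names against the installed flash-attn version.
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward

FA_2_7 = tuple(int(v) for v in flash_attn.__version__.split(".")[:2]) >= (2, 7)

def flash_attn_forward_compat(
    q, k, v, dropout_p, softmax_scale, causal=False,
    window_size=(-1, -1), softcap=0.0, alibi_slopes=None,
    return_softmax=False,
):
    """Return (out, softmax_lse) regardless of the flash-attn version."""
    if FA_2_7:
        # 2.7.0: window_size is split into left/right arguments, softcap
        # is new, and the op returns 4 values. On torch >= 2.4.0 this call
        # goes through the torch.library.custom_op wrapper.
        out, softmax_lse, _, _ = _flash_attn_forward(
            q, k, v, dropout_p, softmax_scale, causal,
            window_size[0], window_size[1], softcap,
            alibi_slopes, return_softmax,
        )
    else:
        # Pre-2.7.0: window_size is a single tuple, there is no softcap,
        # and 8 values come back with softmax_lse in the sixth slot.
        out, _, _, _, _, softmax_lse, _, _ = _flash_attn_forward(
            q, k, v, dropout_p, softmax_scale, causal,
            window_size, alibi_slopes, return_softmax,
        )
    return out, softmax_lse
```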
Remaining Issue
After applying these fixes, the tests pass with acceptable error. However, warnings persist during the backward stage:
############################### forward:##############################
out: max 2.94, mean 0.041
lse: max 8.98, mean 7.75
out diff:
[0] max 0.00391, mean 7.77e-05
[1] max 0.00195, mean 8.44e-05
[2] max 0.000977, mean 6.77e-05
[3] max 0.000977, mean 5.77e-05
lse diff:
[0] max 9.54e-07, mean 1.37e-07
[1] max 1.43e-06, mean 2.19e-07
[2] max 1.91e-06, mean 2.45e-07
[3] max 1.91e-06, mean 2.9e-07
############################### backward:##############################
/users/xhaoxuan/miniconda3/envs/ring-ulysses/lib/python3.10/site-packages/torch/autograd/graph.py:769: UserWarning: c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1724789115564/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
(the same warning is emitted once per rank; three identical copies omitted)
local_dqkv:
[0] max 4.47, mean 0.0747
[1] max 0.516, mean 0.0299
[2] max 0.217, mean 0.0199
[3] max 0.297, mean 0.0126
dq diff:
[0] max 0.0156, mean 0.00066
[1] max 0.000977, mean 5.56e-05
[2] max 0.000488, mean 4.46e-05
[3] max 0.000488, mean 1.96e-05
dk diff:
[0] max 0.0156, mean 0.00058
[1] max 0.000977, mean 6.58e-05
[2] max 0.000488, mean 4.01e-05
[3] max 0.000488, mean 1.74e-05
dv diff:
[0] max 0.0156, mean 0.000576
[1] max 0.000977, mean 6.72e-05
[2] max 0.000488, mean 3.65e-05
[3] max 0.000488, mean 1.73e-05
[rank0]:[W1114 20:14:11.788433858 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
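This last warning names its own remedy: destroy the process group explicitly before the process exits. A minimal sketch, assuming the tests manage torch.distributed themselves:

```python
import torch.distributed as dist

# ... run the forward/backward tests ...

# Explicitly destroy the process group so any pending NCCL operations
# finish before interpreter shutdown; PyTorch 2.4 added this warning.
if dist.is_initialized():
    dist.destroy_process_group()
```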
I reran the tests with Flash Attention 2.6.3 and these warnings persist, so they are likely caused by torch 2.4.1 itself rather than by the flash-attn update.
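That would be consistent with the c10d::broadcast_ warning, which fires whenever autograd backprops through the in-place broadcast used to give every rank the same inputs. A sketch of one way to keep the collective out of the autograd graph, assuming the tests build their inputs roughly like this (the helper and shapes are hypothetical):

```python
import torch
import torch.distributed as dist

def make_shared_input(shape, device, dtype=torch.bfloat16):
    # Hypothetical helper: the shape and dtype are illustrative.
    x = torch.randn(shape, device=device, dtype=dtype)
    # Broadcasting under no_grad keeps the in-place c10d::broadcast_ op
    # out of the autograd graph, so the warning is never triggered.
    with torch.no_grad():
        dist.broadcast(x, src=0)
    return x.requires_grad_()
```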