[AMD][Hardware][Misc][Bugfix] xformer cleanup and light navi logic and CI fixes and refactoring (vllm-project#4129)
hongxiayang authored Apr 22, 2024
1 parent a37d815 commit 95e5b08
Showing 6 changed files with 19 additions and 217 deletions.
2 changes: 0 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -15,10 +15,8 @@ steps:
   commands:
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
 
 - label: Core Test
   command: pytest -v -s core
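The two ROCM_FLASH entries are dropped from the shared CI pipeline only; on a ROCm machine the same checks can still be run by hand. A rough local equivalent, not part of this commit, assuming pytest is installed and the vLLM test tree is checked out:

# Hypothetical local stand-in for the removed CI steps; run from vLLM's
# tests/ directory, as the pipeline steps do.
import os

import pytest

os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_FLASH"
for test_file in ("basic_correctness/test_basic_correctness.py",
                  "basic_correctness/test_chunked_prefill.py"):
    exit_code = pytest.main(["-v", "-s", test_file])
    print(test_file, "->", exit_code)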
5 changes: 1 addition & 4 deletions Dockerfile.rocm
@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
 RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
 
-ARG FA_BRANCH="3d2b6f5"
+ARG FA_BRANCH="ae7928c"
 RUN echo "FA_BRANCH is $FA_BRANCH"
 
 # whether to build flash-attention
@@ -92,13 +92,10 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
 COPY ./ /app/vllm
 
 RUN python3 -m pip install --upgrade pip numba
-RUN python3 -m pip install xformers==0.0.23 --no-deps
 
 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && if [ "$BUILD_FA" = "1" ]; then \
-       bash patch_xformers.rocm.sh; fi \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..
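With the pinned xformers 0.0.23 wheel and the patch_xformers.rocm.sh step gone, the image relies only on the flash-attn build from FA_BRANCH. A rough post-build sanity check one might run inside the container (an illustration, not part of the commit):

# Hypothetical sanity check inside the built ROCm container: confirm that
# flash-attn is importable and that xformers is no longer a hard requirement.
import importlib.util

import torch

print("HIP runtime:", torch.version.hip)  # non-None only on a ROCm build
if torch.cuda.is_available():
    print("Device capability:", torch.cuda.get_device_capability())

for pkg in ("flash_attn", "xformers"):
    status = "found" if importlib.util.find_spec(pkg) else "absent"
    print(f"{pkg}: {status}")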
33 changes: 0 additions & 33 deletions patch_xformers.rocm.sh

This file was deleted.

13 changes: 0 additions & 13 deletions rocm_patch/commonpy_xformers-0.0.23.rocm.patch

This file was deleted.

152 changes: 0 additions & 152 deletions rocm_patch/flashpy_xformers-0.0.23.rocm.patch

This file was deleted.

31 changes: 18 additions & 13 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -154,25 +154,30 @@ def __init__(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {suppored_head_sizes}.")
 
-        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = (os.environ.get(
             "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
-        if self.use_naive_attn:
-            # AMD Radeon 7900 series (gfx1100) currently does not support
-            # xFormers nor FlashAttention. As a temporary workaround, we use
-            # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
-        elif self.use_triton_flash_attn:
+        if self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            from flash_attn import flash_attn_varlen_func  # noqa: F401
-            self.attn_func = flash_attn_varlen_func
-            logger.debug("Using CK FA in ROCmBackend")
+            # if not using triton, navi3x not use flash-attn either
+            if torch.cuda.get_device_capability()[0] == 11:
+                self.use_naive_attn = True
+            else:
+                try:
+                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
+                    logger.debug("Using CK FA in ROCmBackend")
+                except ModuleNotFoundError:
+                    self.use_naive_attn = True
+
+            if self.use_naive_attn:
+                self.attn_func = _naive_attention
+                logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -247,13 +252,13 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_naive_attn or self.use_triton_flash_attn:
+                if self.use_triton_flash_attn or self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
                     if self.use_naive_attn:
-                        out = self.attn_fuc(
+                        out = self.attn_func(
                             query,
                             key,
                             value,
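Net effect of the two hunks: Triton FA remains the default, gfx11 (Navi3x) skips the CK flash-attn path, and naive attention becomes a fallback rather than a capability-keyed default. A minimal standalone sketch of that selection order (the function name here is hypothetical; the real logic lives in ROCmFlashAttentionImpl.__init__):

# Standalone sketch of the backend fallback introduced above (illustrative
# only). Assumes a ROCm PyTorch build with at least one visible GPU.
import os

import torch


def choose_attention_impl() -> str:
    """Mirror the new order: Triton FA -> CK FA -> naive attention."""
    use_triton = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN",
                                "True").lower() in ("true", "1")
    if use_triton:
        return "triton"
    # Navi3x (gfx11xx) reports major capability 11 and skips flash-attn.
    if torch.cuda.get_device_capability()[0] == 11:
        return "naive"
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
        return "ck_flash_attn"
    except ModuleNotFoundError:
        return "naive"


if __name__ == "__main__":
    print("Selected attention path:", choose_attention_impl())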
