diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 94a1c4a678cf7..8c3e74e502a55 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -93,20 +93,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 
 // Launch activation and gating kernel.
 #ifdef USE_ROCM
-#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                         \
-  int d = input.size(-1) / 2;                                                \
-  int64_t num_tokens = input.numel() / input.size(-1);                       \
-  dim3 grid(num_tokens);                                                     \
-  dim3 block(std::min(d, 1024));                                             \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));          \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();              \
-  VLLM_DISPATCH_FLOATING_TYPES(                                              \
-      input.scalar_type(), "scaled_act_and_mul_kernel", [&] {                \
-        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>          \
-            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
-                                         input.data_ptr<scalar_t>(), d,      \
-                                         1.0 / (*scale.data_ptr<float>()));  \
-      });
+  #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                       \
+    int d = input.size(-1) / 2;                                              \
+    int64_t num_tokens = input.numel() / input.size(-1);                     \
+    dim3 grid(num_tokens);                                                   \
+    dim3 block(std::min(d, 1024));                                           \
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));        \
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();            \
+    VLLM_DISPATCH_FLOATING_TYPES(                                            \
+        input.scalar_type(), "scaled_act_and_mul_kernel", [&] {              \
+          vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>        \
+              <<<grid, block, 0, stream>>>(                                  \
+                  out.data_ptr<c10::Float8_e4m3fnuz>(),                      \
+                  input.data_ptr<scalar_t>(), d,                             \
+                  1.0 / (*scale.data_ptr<float>()));                         \
+        });
 #endif
 
 void silu_and_mul(torch::Tensor& out,  // [..., d]
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index de098a9ee0c19..405ba213628f6 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -247,7 +247,7 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
     LAUNCH_RMS_NORM(0);
   }
 #else
-  LAUNCH_RMS_NORM(0);
+  LAUNCH_RMS_NORM(0);
 #endif
 }
 
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 591bedfa3a6f1..77eadb2997689 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -218,12 +218,6 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             max_encoder_seq_len=self.max_encoder_seq_len,
             cross_slot_mapping=self.cross_slot_mapping,
             cross_block_tables=self.cross_block_tables)
-        # Batch may be composed of prefill|decodes, adjust query start indices
-        # to refer to the start of decodes when the two are split apart.
-        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
-        if self._cached_decode_metadata.query_start_loc is not None:
-            qs = self._cached_decode_metadata.query_start_loc
-            self._cached_decode_metadata.query_start_loc = qs - qs[0]
         return self._cached_decode_metadata
 
     def advance_step(self,
@@ -459,10 +453,12 @@ def __init__(
         if blocksparse_params is not None:
             raise ValueError(
                 "ROCmFlashAttention does not support blocksparse attention.")
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "ROCmFlashAttention does not support attention logits soft "
-                "capping.")
+
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap
+
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -487,6 +483,14 @@ def __init__(
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
         if self.use_triton_flash_attn:
+            if logits_soft_cap:
+                raise ValueError(
+                    "ROCm Triton FlashAttention does not support attention"
+                    " logits soft capping."
+                    " Please try using the ROCm CK "
+                    "FA backend instead by setting the env var "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
+
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
@@ -511,6 +515,11 @@ def __init__(
                 self.use_naive_attn = True
 
         if self.use_naive_attn:
+            if logits_soft_cap:
+                raise ValueError(
+                    "ROCm Naive FlashAttention does not support"
+                    " attention logits soft capping.")
+
             self.attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
@@ -716,6 +725,7 @@ def forward(
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    softcap=self.logits_soft_cap,
                 )
 
         # common code for prefill
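Note on the softcap semantics assumed by this patch: in flash-attn, softcap=0 disables capping, while a nonzero cap squashes the scaled attention logits with tanh before softmax (the convention popularized by Gemma-2-style models). The snippet below is a plain-PyTorch reference sketch of that computation, not vLLM or flash-attn code; the function name, tensor shapes, and the omission of masking are illustrative assumptions only.

# Reference sketch of attention logits soft capping (illustration only,
# not vLLM/flash-attn code). softcap == 0.0 means "no capping", matching
# the "logits_soft_cap = 0 means no soft cap" convention in the patch.
import torch


def attention_with_softcap(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           scale: float, softcap: float = 0.0) -> torch.Tensor:
    # q, k, v: [num_heads, seq_len, head_size]; masking omitted for brevity.
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    if softcap > 0:
        # Bound the scaled logits to (-softcap, softcap) with tanh.
        scores = softcap * torch.tanh(scores / softcap)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


if __name__ == "__main__":
    h, s, d = 2, 4, 8
    q, k, v = (torch.randn(h, s, d) for _ in range(3))
    capped = attention_with_softcap(q, k, v, scale=d**-0.5, softcap=30.0)
    plain = attention_with_softcap(q, k, v, scale=d**-0.5)
    print(capped.shape, plain.shape)  # torch.Size([2, 4, 8]) twice

Because a cap of 0.0 leaves the scores untouched, the patch can pass softcap=self.logits_soft_cap unconditionally to the CK flash-attn call, while the Triton and naive (SDPA) paths reject any nonzero cap up front.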