From 4df8a019d506e3186f9d5253ba5c105bdf827fe3 Mon Sep 17 00:00:00 2001
From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Date: Thu, 10 Oct 2024 15:04:05 -0500
Subject: [PATCH] [rocm6.3_internal_testing] Fix SWDEV-459623 - 2 (#1629)

Fix faulty conflict merge when cherry-picking
https://github.com/ROCm/pytorch/commit/aea038675ddb56a473066ca7d0ef80d345ce2123
to rocm6.3_internal_testing

---------

Co-authored-by: Xinya Zhang
---
 .../src/ATen/native/transformers/hip/flash_attn/flash_api.hip | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip
index 174b253d11499..1cc0789f53f6e 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip
@@ -96,6 +96,10 @@ mha_fwd(const at::Tensor &q,        // batch_size x seqlen_q x num_heads x head
         int window_size_right,
         const bool return_softmax,
         std::optional<at::Generator> gen_) {
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
+    auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
     check_gpu_arch(stream);
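
Note (not part of the patch): the lines added above follow the standard ATen device-guard pattern. Below is a minimal sketch of that pattern written against the CUDA-side ATen API, assuming a CUDA or HIP build of PyTorch; in HIP builds, hipify rewrites the CUDA guard and stream calls into the "MasqueradingAsCUDA" names seen in the diff. The helper name launch_on_input_device is illustrative only.

// Sketch of the device-guard pattern the patch applies (illustrative only).
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void launch_on_input_device(const at::Tensor& q) {  // hypothetical helper
  // Without the guard, work is enqueued on the thread's current device
  // (often cuda:0) even when q lives on another GPU. The guard switches
  // the current device to q's device for the lifetime of this scope.
  // The char cast avoids a compiler warning about narrowing, as in the patch.
  c10::cuda::CUDAGuard device_guard{(char)q.get_device()};

  // Fetched after the guard, this stream belongs to q's device; kernels
  // would be launched on it (here it is only retrieved, not used).
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  (void)stream;
}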