[Release/2.6] Fix torch.lerp RuntimeError when weight is CPU scalar while input & end are XPU tensors (#1201)

Solves issue #1200.
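A minimal reproducer for the failure mode described in the issue (a sketch, assuming an XPU-enabled PyTorch build with torch-xpu-ops; shapes and values are illustrative):

    import torch

    start = torch.randn(8, device="xpu")
    end = torch.randn(8, device="xpu")
    weight = torch.tensor(0.5)  # 0-dim weight tensor living on the CPU

    # Before this fix, mixing XPU start/end tensors with a CPU scalar weight
    # raised a RuntimeError from the XPU lerp kernel; with the patch the call
    # is routed to the scalar-weight kernel instead.
    out = torch.lerp(start, end, weight)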

---------

Co-authored-by: Yutao Xu <[email protected]>
chunhuanMeng and xytintel authored Dec 24, 2024
1 parent 7ecb0b1 commit 49acdfc
Showing 3 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/scripts/apply_torch_pr.py
@@ -12,7 +12,7 @@
     # Fallback to CPU for XPU FP64
     "https://github.com/pytorch/pytorch/pull/126516",
     # Modify the tolerance level in TIMM benchmark
-    "https://github.com/pytorch/pytorch/pull/129735",
+    "https://github.com/pytorch/pytorch/pull/143739",
     # [XPU] Update XPU C Shim Header
     "https://github.com/pytorch/pytorch/pull/141086",
 ]
2 changes: 2 additions & 0 deletions src/ATen/native/transformers/SDPUtils.cpp
@@ -4,6 +4,8 @@
 
 namespace sdp {
 
+using c10::array_of;
+
 bool check_all_tensors_on_device(sdp_params const& params, bool debug) {
   // Check that all tensors are on the GPU device
   // This should be handled by the stub dispatch, but whe call
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/sycl/LerpKernels.cpp
@@ -57,15 +57,29 @@ struct LerpScalarFunctor {
   opmath_t weight_val_;
 };
 
+void lerp_scalar_kernel(
+    at::TensorIteratorBase& iter,
+    const c10::Scalar& weight);
+
 void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
     AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_xpu", [&] {
+      if (iter.is_cpu_scalar(3)) {
+        auto weight_val = iter.scalar_value<scalar_t>(3);
+        iter.remove_operand(3);
+        return lerp_scalar_kernel(iter, weight_val);
+      }
       gpu_kernel(iter, LerpTensorComplexFunctor<scalar_t>());
     });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "lerp_xpu", [&] {
+          if (iter.is_cpu_scalar(3)) {
+            auto weight_val = iter.scalar_value<scalar_t>(3);
+            iter.remove_operand(3);
+            return lerp_scalar_kernel(iter, weight_val);
+          }
           gpu_kernel(iter, LerpTensorFunctor<scalar_t>());
         });
   }
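The added branches handle the case where operand 3 of the TensorIterator (the weight) is a 0-dim CPU tensor: the value is read on the host, the operand is removed from the iterator, and the work is re-dispatched to lerp_scalar_kernel, so gpu_kernel never has to touch a CPU pointer. A quick equivalence check of the new path (a sketch, assuming an XPU-enabled build with this patch; the weight value is illustrative):

    import torch

    start = torch.randn(8, device="xpu")
    end = torch.randn(8, device="xpu")

    cpu_weight = torch.tensor(0.25)           # CPU scalar, exercises the new branch
    ref = torch.lerp(start, end, 0.25)        # plain scalar weight, existing scalar kernel
    out = torch.lerp(start, end, cpu_weight)  # previously raised RuntimeError

    # Both calls should now produce the same result.
    torch.testing.assert_close(out, ref)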
