diff --git a/src/ATen/native/xpu/PointwiseOps.cpp b/src/ATen/native/xpu/PointwiseOps.cpp
index 210cec3e6..a01bdc391 100644
--- a/src/ATen/native/xpu/PointwiseOps.cpp
+++ b/src/ATen/native/xpu/PointwiseOps.cpp
@@ -6,6 +6,63 @@ namespace at {
 
+TensorIterator addcdiv_meta(
+    const Tensor& self,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    const Scalar& value,
+    Tensor& out) {
+  if (isIntegralType(tensor1.scalar_type(), /*includeBool=*/true) &&
+      isIntegralType(tensor2.scalar_type(), /*includeBool=*/true)) {
+    TORCH_CHECK(
+        false,
+        "Integer division with addcdiv is no longer supported, and in a future ",
+        "release addcdiv will perform a true division of tensor1 and tensor2. ",
+        "The historic addcdiv behavior can be implemented as ",
+        "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
+        "for integer inputs and as ",
+        "(input + value * tensor1 / tensor2) for float inputs. ",
+        "The future addcdiv behavior is just the latter implementation: ",
+        "(input + value * tensor1 / tensor2), for all dtypes.");
+  }
+
+  TensorIterator iter;
+  iter.build_ternary_op(out, self, tensor1, tensor2);
+  return iter;
+}
+
+Tensor& XPUNativeFunctions::addcdiv_out(
+    const Tensor& self,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    const Scalar& value,
+    Tensor& out) {
+  auto iter = addcdiv_meta(self, tensor1, tensor2, value, out);
+  native::xpu::addcdiv_kernel(iter, value);
+  return out;
+}
+
+Tensor XPUNativeFunctions::addcdiv(
+    const Tensor& self,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    const Scalar& value) {
+  Tensor out;
+  auto iter = addcdiv_meta(self, tensor1, tensor2, value, out);
+  native::xpu::addcdiv_kernel(iter, value);
+  return iter.output();
+}
+
+Tensor& XPUNativeFunctions::addcdiv_(
+    Tensor& self,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    const Scalar& value) {
+  auto iter = addcdiv_meta(self, tensor1, tensor2, value, self);
+  native::xpu::addcdiv_kernel(iter, value);
+  return self;
+}
+
 TensorIterator addcmul_meta(
     const Tensor& self,
     const Tensor& tensor1,
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 496eb00f1..b9f9708f9 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -163,7 +163,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "adaptive_max_pool2d.out",
       "adaptive_max_pool3d_backward.grad_input",
       "adaptive_max_pool3d.out",
-      "addcdiv.out",
       "aminmax.out",
       "angle",
       "argmin.out",
diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
index 7b00d09e3..d38f511d7 100644
--- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
@@ -1,6 +1,6 @@
 #include <ATen/Dispatch.h>
+#include <ATen/AccumulateType.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/OpMathType.h>
 #include <ATen/native/xpu/sycl/Loops.h>
 #include <ATen/native/xpu/sycl/PointwiseOpsKernels.h>
@@ -8,31 +8,98 @@ namespace at::native::xpu {
 
 template <typename scalar_t>
-struct AddcmulKernelFunctor {
-  using opmath_t = at::opmath_type<scalar_t>;
+struct AddcmulFunctor {
+  using accscalar_t = at::acc_type<scalar_t, true>;
   scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
-    return static_cast<opmath_t>(a) +
-        alpha_ * static_cast<opmath_t>(b) * static_cast<opmath_t>(c);
+    return static_cast<accscalar_t>(a) +
+        alpha_ * static_cast<accscalar_t>(b) * static_cast<accscalar_t>(c);
   }
 
-  AddcmulKernelFunctor(opmath_t alpha) : alpha_(alpha) {}
+  AddcmulFunctor(accscalar_t alpha) : alpha_(alpha) {}
 
  private:
-  opmath_t alpha_;
+  accscalar_t alpha_;
+};
+
+template <typename scalar_t>
+struct AddcmulComplexFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
+    return a + alpha_ * b * c;
+  }
+
+  AddcmulComplexFunctor(scalar_t alpha) : alpha_(alpha) {}
+
+ private:
+  scalar_t alpha_;
 };
 
 void addcmul_kernel(TensorIterator& iter, Scalar value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      iter.dtype(),
-      "addcmul_xpu",
-      [&]() {
-        using opmath_t = at::opmath_type<scalar_t>;
-        auto alpha = value.to<opmath_t>();
-        AddcmulKernelFunctor<scalar_t> f(alpha);
-        gpu_kernel(iter, f);
-      });
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_xpu", [&]() {
+      auto alpha = value.to<scalar_t>();
+      gpu_kernel(iter, AddcmulComplexFunctor<scalar_t>(alpha));
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        iter.dtype(),
+        "addcmul_xpu",
+        [&]() {
+          using accscalar_t = at::acc_type<scalar_t, true>;
+          auto alpha = value.to<accscalar_t>();
+          gpu_kernel(iter, AddcmulFunctor<scalar_t>(alpha));
+        });
+  }
+}
+
+template <typename scalar_t>
+struct AddcdivFunctor {
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
+    return a + alpha_ * (b / static_cast<accscalar_t>(c));
+  }
+
+  AddcdivFunctor(accscalar_t alpha) : alpha_(alpha) {}
+
+ private:
+  accscalar_t alpha_;
+};
+
+template <typename scalar_t>
+struct AddcdivComplexFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
+    return a + alpha_ * (b / c);
+  }
+
+  AddcdivComplexFunctor(scalar_t alpha) : alpha_(alpha) {}
+
+ private:
+  scalar_t alpha_;
+};
+
+void addcdiv_kernel(TensorIterator& iter, Scalar value) {
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_xpu", [&]() {
+      auto alpha = value.to<scalar_t>();
+      AddcdivComplexFunctor<scalar_t> f(alpha);
+      gpu_kernel(iter, f);
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        iter.dtype(),
+        "addcdiv_xpu",
+        [&]() {
+          using accscalar_t = at::acc_type<scalar_t, true>;
+          auto alpha = value.to<accscalar_t>();
+          AddcdivFunctor<scalar_t> f(alpha);
+          gpu_kernel(iter, f);
+        });
+  }
 }
 
 template <typename scalar_t>
diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
index fdb216dbd..c775b88e5 100644
--- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
@@ -6,6 +6,8 @@ namespace at::native::xpu {
 
 void addcmul_kernel(TensorIterator& iter, Scalar value);
 
+void addcdiv_kernel(TensorIterator& iter, Scalar value);
+
 void mse_backward_kernel(TensorIterator& iter, const Scalar& value);
 
 } // namespace at::native::xpu
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 6511f4120..e7c9da34c 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -44,6 +44,7 @@
     "bitwise_or",
     "bitwise_xor",
     "addcmul",
+    "addcdiv",
     "clamp",
     "clamp_max",
     "clamp_min",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 2cd535394..db6f77415 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -496,6 +496,9 @@ supported:
   - avg_pool2d.out
   - avg_pool2d_backward
   - avg_pool2d_backward.grad_input
+  - addcdiv.out
+  - addcdiv
+  - addcdiv_
  - addcmul.out
  - addcmul
  - addcmul_
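
For reference, the semantics the new XPU kernels register are the ones quoted in the deprecation message above: out = input + value * tensor1 / tensor2, using true division for non-integer dtypes. A minimal sketch of exercising the op from Python follows; it runs on CPU for portability, and substituting device="xpu" assumes a working torch-xpu-ops build.

import torch

# addcdiv computes: out = input + value * tensor1 / tensor2
inp = torch.randn(3, 3)
t1 = torch.randn(3, 3)
t2 = torch.randn(3, 3).abs() + 1e-3  # keep the divisor away from zero

out = torch.addcdiv(inp, t1, t2, value=0.5)
ref = inp + 0.5 * t1 / t2
assert torch.allclose(out, ref)

# the in-place and out= variants map to the addcdiv_ / addcdiv.out entries registered above
inp.addcdiv_(t1, t2, value=0.5)
torch.addcdiv(inp, t1, t2, value=0.5, out=out)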