From b5924a7321a9665e29b7c6de39476d668cd6cd9c Mon Sep 17 00:00:00 2001 From: "Cui, Yifeng" Date: Thu, 2 Jan 2025 00:56:13 -0800 Subject: [PATCH] Init MKL for Pytorch XPU and enable fft_c2c --- CMakeLists.txt | 1 + cmake/Modules/FindONEMKL.cmake | 66 ++++ cmake/ONEMKL.cmake | 11 + src/ATen/CMakeLists.txt | 3 + src/ATen/native/xpu/SpectralOps.cpp | 28 ++ src/ATen/native/xpu/XPUFallback.template | 1 - src/ATen/native/xpu/mkl/SpectralOps.cpp | 403 +++++++++++++++++++++++ src/ATen/native/xpu/mkl/SpectralOps.h | 18 + src/BuildOnLinux.cmake | 1 + src/BuildOnWindows.cmake | 3 +- src/CMakeLists.txt | 4 + test/xpu/skip_list_common.py | 25 ++ test/xpu/test_spectral_ops_xpu.py | 81 +++++ test/xpu/xpu_test_utils.py | 21 +- yaml/native/native_functions.yaml | 11 + 15 files changed, 674 insertions(+), 3 deletions(-) create mode 100644 cmake/Modules/FindONEMKL.cmake create mode 100644 cmake/ONEMKL.cmake create mode 100644 src/ATen/native/xpu/SpectralOps.cpp create mode 100644 src/ATen/native/xpu/mkl/SpectralOps.cpp create mode 100644 src/ATen/native/xpu/mkl/SpectralOps.h create mode 100644 test/xpu/test_spectral_ops_xpu.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a0ba1fd99..49cf232a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ set(TORCH_XPU_OPS_ROOT ${PROJECT_SOURCE_DIR}) list(APPEND CMAKE_MODULE_PATH ${TORCH_XPU_OPS_ROOT}/cmake/Modules) include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake) +include(${TORCH_XPU_OPS_ROOT}/cmake/ONEMKL.cmake) include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake) if(BUILD_TEST) diff --git a/cmake/Modules/FindONEMKL.cmake b/cmake/Modules/FindONEMKL.cmake new file mode 100644 index 000000000..09339d24a --- /dev/null +++ b/cmake/Modules/FindONEMKL.cmake @@ -0,0 +1,66 @@ +set(ONEMKL_FOUND FALSE) + +set(ONEMKL_LIBRARIES) + +# In order to be compatible with various situations of Pytorch development +# bundle setup, ENV{MKLROOT} and SYCL_ROOT will be checked sequentially to get +# the root directory of oneMKL. +if(DEFINED ENV{MKLROOT}) + # Directly get the root directory of oneMKL if ENV{MKLROOT} exists. + set(ONEMKL_ROOT $ENV{MKLROOT}) +elseif(SYCL_FOUND) + # oneMKL configuration may not be imported into the build system. Get the root + # directory of oneMKL based on the root directory of compiler relatively. + get_filename_component(ONEMKL_ROOT "${SYCL_ROOT}/../../mkl/latest" REALPATH) +endif() + +if(NOT DEFINED ONEMKL_ROOT) + message( + WARNING + "Cannot find either ENV{MKLROOT} or SYCL_ROOT, please setup oneAPI environment before building!!" + ) + return() +endif() + +if(NOT EXISTS ${ONEMKL_ROOT}) + message( + WARNING + "${ONEMKL_ROOT} not found, please setup oneAPI environment before building!!" + ) + return() +endif() + +find_file( + ONEMKL_INCLUDE_DIR + NAMES include + HINTS ${ONEMKL_ROOT} + NO_DEFAULT_PATH) + +find_file( + ONEMKL_LIB_DIR + NAMES lib + HINTS ${ONEMKL_ROOT} + NO_DEFAULT_PATH) + +if((ONEMKL_INCLUDE_DIR STREQUAL "ONEMKL_INCLUDE_DIR-NOTFOUND") + OR (ONEMKL_LIB_DIR STREQUAL "ONEMKL_LIB_DIR-NOTFOUND")) + message(WARNING "oneMKL sdk is incomplete!!") + return() +endif() + +if(WIN32) + set(MKL_LIB_NAMES "mkl_sycl" "mkl_intel_lp64" "mkl_intel_thread" "mkl_core") +else() + set(MKL_LIB_NAMES "mkl_sycl_dft" "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core") +endif() + +foreach(LIB_NAME IN LISTS MKL_LIB_NAMES) + find_library( + ${LIB_NAME}_library + NAMES ${LIB_NAME} + HINTS ${ONEMKL_LIB_DIR} + NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH) + list(APPEND ONEMKL_LIBRARIES ${${LIB_NAME}_library}) +endforeach() + +set(ONEMKL_FOUND TRUE) diff --git a/cmake/ONEMKL.cmake b/cmake/ONEMKL.cmake new file mode 100644 index 000000000..2d40ebccf --- /dev/null +++ b/cmake/ONEMKL.cmake @@ -0,0 +1,11 @@ +find_package(ONEMKL) +if(NOT ONEMKL_FOUND) + message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!") +endif() + +set(TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR ${ONEMKL_INCLUDE_DIR}) + +set(TORCH_XPU_OPS_ONEMKL_LIBRARIES ${ONEMKL_LIBRARIES}) + +list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 1 "-Wl,--start-group") +list(APPEND TORCH_XPU_OPS_ONEMKL_LIBRARIES "-Wl,--end-group") diff --git a/src/ATen/CMakeLists.txt b/src/ATen/CMakeLists.txt index af9ef7d94..515ac2a29 100644 --- a/src/ATen/CMakeLists.txt +++ b/src/ATen/CMakeLists.txt @@ -2,14 +2,17 @@ file(GLOB xpu_h "xpu/*.h") file(GLOB xpu_cpp "xpu/*.cpp") +file(GLOB xpu_mkl "native/xpu/mkl/*.cpp") file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp") file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp") list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp}) +list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl}) list(APPEND ATen_XPU_NATIVE_CPP_SRCS ${xpu_native_cpp}) list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl}) set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE) +set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE) set(ATen_XPU_NATIVE_CPP_SRCS ${ATen_XPU_NATIVE_CPP_SRCS} PARENT_SCOPE) set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE) diff --git a/src/ATen/native/xpu/SpectralOps.cpp b/src/ATen/native/xpu/SpectralOps.cpp new file mode 100644 index 000000000..af82394f1 --- /dev/null +++ b/src/ATen/native/xpu/SpectralOps.cpp @@ -0,0 +1,28 @@ +#include +#include +#include + +namespace at::native { + +Tensor _fft_c2c_xpu( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward) { + TORCH_CHECK(self.is_complex()); + + return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); +} + +Tensor& _fft_c2c_xpu_out( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward, + Tensor& out) { + TORCH_CHECK(self.is_complex()); + + return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out); +} + +} // namespace at::native diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 72f2aacdd..c4f20b2c1 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -158,7 +158,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "_cholesky_solve_helper", "dot", "_efficient_attention_forward", - "_fft_c2c", "_fft_c2r", "_fft_r2c", "_flash_attention_forward", diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp new file mode 100644 index 000000000..4479dd9c3 --- /dev/null +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -0,0 +1,403 @@ +#include +#include +#include +#include +#include +#include +#include + +using namespace oneapi::mkl::dft; + +namespace at::native::xpu { + +namespace impl { + +constexpr int64_t mkl_max_ndim = 3; + +// Sort transform dimensions by input layout, for best performance +// exclude_last is for onesided transforms where the last dimension cannot be +// reordered +static DimVector _sort_dims( + const Tensor& self, + IntArrayRef dim, + bool exclude_last = false) { + DimVector sorted_dims(dim.begin(), dim.end()); + auto self_strides = self.strides(); + std::sort( + sorted_dims.begin(), + sorted_dims.end() - exclude_last, + [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; }); + return sorted_dims; +} + +template +void _mkl_dft( + const Tensor& input, + Tensor& output, + int64_t signal_ndim, + bool complex_input, + bool complex_output, + bool inverse, + IntArrayRef checked_signal_sizes, + bool onesided, + int64_t batch) { + auto& queue = at::xpu::getCurrentSYCLQueue(); + std::vector mkl_signal_sizes( + checked_signal_sizes.begin() + 1, checked_signal_sizes.end()); + + auto istrides = input.strides(); + auto ostrides = output.strides(); + int64_t idist = istrides[0]; + int64_t odist = ostrides[0]; + + std::vector fwd_strides(1 + signal_ndim, 0), + bwd_strides(1 + signal_ndim, 0); + + for (int64_t i = 1; i <= signal_ndim; i++) { + if (!inverse) { + fwd_strides[i] = istrides[i]; + bwd_strides[i] = ostrides[i]; + } else { + fwd_strides[i] = ostrides[i]; + bwd_strides[i] = istrides[i]; + } + } + + auto desc = descriptor(mkl_signal_sizes); + desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE); + desc.set_value(config_param::NUMBER_OF_TRANSFORMS, batch); + + if (!inverse) { + desc.set_value(config_param::FWD_DISTANCE, idist); + desc.set_value(config_param::BWD_DISTANCE, odist); + } else { + desc.set_value(config_param::FWD_DISTANCE, odist); + desc.set_value(config_param::BWD_DISTANCE, idist); + } + + if (!fwd_strides.empty()) { + desc.set_value(config_param::FWD_STRIDES, fwd_strides.data()); + } + if (!bwd_strides.empty()) { + desc.set_value(config_param::BWD_STRIDES, bwd_strides.data()); + } + + if (!complex_input || !complex_output) { + desc.set_value( + config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX); + } + + desc.set_value( + oneapi::mkl::dft::config_param::WORKSPACE, + oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); + desc.commit(queue); + + // Obtain the size of workspace required after commit. + size_t workspaceSizeBytes = 0; + desc.get_value( + oneapi::mkl::dft::config_param::WORKSPACE_BYTES, &workspaceSizeBytes); + + // Allocate USM workspace and provide it to the descriptor. + Tensor workspaceBuf = at::empty( + {(long)(workspaceSizeBytes / sizeof(double))}, + input.options().dtype(at::kDouble), + c10::nullopt); + desc.set_workspace((double*)workspaceBuf.mutable_data_ptr()); + + auto in_data = (scalar_t*)input.const_data_ptr(); + auto out_data = (scalar_t*)output.mutable_data_ptr(); + + sycl::event event; + if (!inverse) { + event = compute_forward(desc, in_data, out_data); + } else { + event = compute_backward(desc, in_data, out_data); + } + event.wait_and_throw(); + queue.throw_asynchronous(); +} + +void _fft_with_size( + Tensor& output, + const Tensor& self, + int64_t signal_ndim, + bool complex_input, + bool complex_output, + bool inverse, + IntArrayRef checked_signal_sizes, + bool onesided) { + int64_t batch = self.size(0); + Tensor input_ = self; + // real/imag dimension must aligned when viewed as of complex type + + if (complex_input) { + bool need_contiguous = input_.stride(-1) != 1; + + for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) { + need_contiguous |= input_.stride(i) % 2 != 0; + } + + if (need_contiguous) { + input_ = input_.contiguous(); + } + } + + bool complex_type = inverse ? complex_output : complex_input; + + void (*dft_func)( + const class at::Tensor&, + class at::Tensor&, + int64_t, + bool, + bool, + bool, + class c10::ArrayRef, + bool, + int64_t); + Tensor input = input_; + + if (input.scalar_type() == ScalarType::Float || + input.scalar_type() == ScalarType::ComplexFloat) { + dft_func = complex_type + ? _mkl_dft + : _mkl_dft; + } else if ( + input.scalar_type() == ScalarType::Double || + input.scalar_type() == ScalarType::ComplexDouble) { + dft_func = complex_type + ? _mkl_dft + : _mkl_dft; + } else { + AT_ERROR("MKL FFT doesn't support tensor of type"); + } + + dft_func( + input, + output, + signal_ndim, + complex_input, + complex_output, + inverse, + checked_signal_sizes, + onesided, + batch); +} + +// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r) +Tensor& _exec_fft( + Tensor& out, + Tensor self, + IntArrayRef out_sizes, + IntArrayRef dim, + bool onesided, + bool forward) { + const auto ndim = self.dim(); + const int64_t signal_ndim = dim.size(); + const auto batch_dims = ndim - signal_ndim; + + // Permute dimensions so batch dimensions come first, and in stride order + // This maximizes data locality when collapsing to a single batch dimension + DimVector dim_permute(ndim); + std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0}); + + c10::SmallVector is_transformed_dim(ndim); + for (const auto& d : dim) { + is_transformed_dim[d] = true; + } + + auto batch_end = + std::partition(dim_permute.begin(), dim_permute.end(), [&](int64_t d) { + return !is_transformed_dim[d]; + }); + + auto self_strides = self.strides(); + std::sort(dim_permute.begin(), batch_end, [&](int64_t a, int64_t b) { + return self_strides[a] > self_strides[b]; + }); + std::copy(dim.cbegin(), dim.cend(), batch_end); + + auto input = self.permute(dim_permute); + + // Collapse batch dimensions into a single dimension + DimVector batched_sizes(signal_ndim + 1); + batched_sizes[0] = -1; + std::copy( + input.sizes().cbegin() + batch_dims, + input.sizes().cend(), + batched_sizes.begin() + 1); + input = input.reshape(batched_sizes); + + const auto batch_size = input.sizes()[0]; + DimVector signal_size(signal_ndim + 1); + signal_size[0] = batch_size; + + for (int64_t i = 0; i < signal_ndim; ++i) { + auto in_size = input.sizes()[i + 1]; + auto out_size = out_sizes[dim[i]]; + signal_size[i + 1] = std::max(in_size, out_size); + TORCH_INTERNAL_ASSERT( + in_size == signal_size[i + 1] || + in_size == (signal_size[i + 1] / 2) + 1); + TORCH_INTERNAL_ASSERT( + out_size == signal_size[i + 1] || + out_size == (signal_size[i + 1] / 2) + 1); + } + + batched_sizes[0] = batch_size; + DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end()); + + for (size_t i = 0; i < dim.size(); ++i) { + batched_out_sizes[i + 1] = out_sizes[dim[i]]; + } + + out.resize_(batched_out_sizes, MemoryFormat::Contiguous); + + // run the FFT + _fft_with_size( + out, + input, + signal_ndim, + input.is_complex(), + out.is_complex(), + !forward, + signal_size, + onesided); + + // Inplace reshaping to original batch shape and inverting the dimension + // permutation + DimVector out_strides(ndim); + int64_t batch_numel = 1; + + for (int64_t i = batch_dims - 1; i >= 0; --i) { + out_strides[dim_permute[i]] = batch_numel * out.strides()[0]; + batch_numel *= out_sizes[dim_permute[i]]; + } + + for (int64_t i = batch_dims; i < ndim; ++i) { + out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)]; + } + + out.as_strided_(out_sizes, out_strides, out.storage_offset()); + + return out; +} + +double _dft_scale( + IntArrayRef dim, + IntArrayRef input_sizes, + IntArrayRef out_sizes, + int64_t normalization) { + const auto norm = static_cast(normalization); + double double_scale = 1.0; + + if (norm == fft_norm_mode::none) { + return double_scale; + } + + const int64_t signal_ndim = dim.size(); + int64_t signal_numel = 1; + + for (int64_t i = 0; i < signal_ndim; ++i) { + auto in_size = input_sizes[dim[i]]; + auto out_size = out_sizes[dim[i]]; + auto signal_size = std::max(in_size, out_size); + + signal_numel *= signal_size; + TORCH_INTERNAL_ASSERT( + in_size == signal_size || in_size == (signal_size / 2) + 1); + TORCH_INTERNAL_ASSERT( + out_size == signal_size || out_size == (signal_size / 2) + 1); + } + + if (norm == fft_norm_mode::by_root_n) { + double_scale = 1.0 / std::sqrt(signal_numel); + } else { + double_scale = 1.0 / static_cast(signal_numel); + } + + return double_scale; +} + +const Tensor& _fft_apply_normalization( + const Tensor& self, + int64_t normalization, + IntArrayRef sizes, + IntArrayRef dims) { + auto scale = _dft_scale(dims, sizes, self.sizes(), normalization); + return (scale == 1.0) ? self : self.mul_(scale); +} + +Tensor& _fft_apply_normalization_out( + Tensor& out, + const Tensor& self, + int64_t normalization, + IntArrayRef sizes, + IntArrayRef dims) { + auto scale = _dft_scale(dims, sizes, self.sizes(), normalization); + return at::mul_out(out, self, c10::scalar_to_tensor(scale)); +} + +} // namespace impl + +Tensor _fft_c2c_mkl( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward) { + if (dim.empty()) { + return self.clone(); + } + + auto sorted_dims = impl::_sort_dims(self, dim); + auto out_sizes = self.sizes(); + auto out = at::empty(out_sizes, self.options()); + auto input_sizes = self.sizes(); + auto working_tensor = self; + + while (!sorted_dims.empty()) { + const auto max_dims = + std::min(static_cast(impl::mkl_max_ndim), sorted_dims.size()); + auto fft_dims = + IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims); + + impl::_exec_fft( + out, + working_tensor, + out_sizes, + fft_dims, + /*onesided=*/false, + forward); + + sorted_dims.resize(sorted_dims.size() - max_dims); + + if (sorted_dims.empty()) { + break; + } + + sorted_dims = impl::_sort_dims(self, sorted_dims); + + if (working_tensor.is_same(self)) { + working_tensor = std::move(out); + out = at::empty(out_sizes, self.options()); + } else { + std::swap(out, working_tensor); + } + } + + return impl::_fft_apply_normalization(out, normalization, input_sizes, dim); +} + +Tensor& _fft_c2c_mkl_out( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward, + Tensor& out) { + auto result = _fft_c2c_mkl( + self, dim, static_cast(fft_norm_mode::none), forward); + at::native::resize_output(out, result.sizes()); + return impl::_fft_apply_normalization_out( + out, result, normalization, result.sizes(), dim); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/mkl/SpectralOps.h b/src/ATen/native/xpu/mkl/SpectralOps.h new file mode 100644 index 000000000..0d66a6dae --- /dev/null +++ b/src/ATen/native/xpu/mkl/SpectralOps.h @@ -0,0 +1,18 @@ +#pragma once + +namespace at::native::xpu { + +TORCH_XPU_API Tensor _fft_c2c_mkl( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward); + +TORCH_XPU_API Tensor& _fft_c2c_mkl_out( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward, + Tensor& out); + +} // namespace at::native::xpu diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index 1590919c0..1842482fa 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -7,6 +7,7 @@ add_library( torch_xpu_ops STATIC ${ATen_XPU_CPP_SRCS} + ${ATen_XPU_MKL_SRCS} ${ATen_XPU_NATIVE_CPP_SRCS} ${ATen_XPU_GEN_SRCS}) diff --git a/src/BuildOnWindows.cmake b/src/BuildOnWindows.cmake index b4757acb1..1fda3582e 100644 --- a/src/BuildOnWindows.cmake +++ b/src/BuildOnWindows.cmake @@ -6,7 +6,8 @@ set(SYCL_LINK_LIBRARIES_KEYWORD PRIVATE) add_library( torch_xpu_ops STATIC - ${ATen_XPU_CPP_SRCS}) + ${ATen_XPU_CPP_SRCS} + ${ATen_XPU_MKL_SRCS}) set(PATH_TO_TORCH_XPU_OPS_ATEN_LIB \"torch_xpu_ops_aten.dll\") target_compile_options(torch_xpu_ops PRIVATE -DPATH_TO_TORCH_XPU_OPS_ATEN_LIB=${PATH_TO_TORCH_XPU_OPS_ATEN_LIB}) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0716ca5af..de2a5ea7b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,6 +2,7 @@ include(${TORCH_XPU_OPS_ROOT}/cmake/Codegen.cmake) set(ATen_XPU_CPP_SRCS) +set(ATen_XPU_MKL_SRCS) set(ATen_XPU_NATIVE_CPP_SRCS) set(ATen_XPU_SYCL_SRCS) @@ -27,3 +28,6 @@ if(CLANG_FORMAT) add_custom_target(CL_FORMAT_CSRCS COMMAND ${CLANG_FORMAT_EXEC} -i -style=file ${ALL_CSRCS}) add_dependencies(torch_xpu_ops CL_FORMAT_CSRCS) endif() + +target_include_directories(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR}) +target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES}) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index fdf481f9c..b623036e8 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -676,6 +676,26 @@ # TODO: passed from source code building version, investigate "test_python_ref__refs_log2_xpu_complex128", + + # The following dtypes did not work in backward but are listed by the OpInfo: {torch.bfloat16}. + "test_dtypes_fft_fft2_xpu", + "test_dtypes_fft_fft_xpu", + "test_dtypes_fft_fftn_xpu", + "test_dtypes_fft_hfft2_xpu", + "test_dtypes_fft_hfft_xpu", + "test_dtypes_fft_hfftn_xpu", + "test_dtypes_fft_ifft2_xpu", + "test_dtypes_fft_ifft_xpu", + "test_dtypes_fft_ifftn_xpu", + "test_dtypes_fft_ihfft2_xpu", + "test_dtypes_fft_ihfft_xpu", + "test_dtypes_fft_ihfftn_xpu", + "test_dtypes_fft_irfft2_xpu", + "test_dtypes_fft_irfft_xpu", + "test_dtypes_fft_irfftn_xpu", + "test_dtypes_fft_rfft2_xpu", + "test_dtypes_fft_rfft_xpu", + "test_dtypes_fft_rfftn_xpu", ), "test_binary_ufuncs_xpu.py": ( @@ -3300,6 +3320,11 @@ "test_set_default_dtype_works_with_foreach_SGD_xpu_float64", ), + "test_spectral_ops_xpu.py": ( + # CUDA specific case + "test_cufft_plan_cache_xpu_float64", + ), + "test_sparse_xpu.py": ( "test_bmm_deterministic_xpu_float64", # - AssertionError: Torch not compiled with CUDA enabled "test_bmm_oob_xpu", # - NotImplementedError: Could not run 'aten::bmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was ... diff --git a/test/xpu/test_spectral_ops_xpu.py b/test/xpu/test_spectral_ops_xpu.py new file mode 100644 index 000000000..bc60cf2ae --- /dev/null +++ b/test/xpu/test_spectral_ops_xpu.py @@ -0,0 +1,81 @@ +# Owner(s): ["module: intel"] + +import torch +import numpy as np +from packaging import version +from itertools import product + +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, ops, onlyNativeDeviceTypes) +from torch.testing._internal.common_methods_invocations import ( + spectral_funcs, SpectralFuncType) +from torch.testing._internal.common_utils import run_tests + +try: + from .xpu_test_utils import XPUPatchForImport +except Exception as e: + from ..xpu_test_utils import XPUPatchForImport + +with XPUPatchForImport(False): + from test_spectral_ops import TestFFT + +has_scipy_fft = False +try: + import scipy.fft + has_scipy_fft = True +except ModuleNotFoundError: + pass + +REFERENCE_NORM_MODES = ( + (None, "forward", "backward", "ortho") + if version.parse(np.__version__) >= version.parse('1.20.0') and ( + not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0')) + else (None, "ortho")) + +@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD], + allowed_dtypes=(torch.float, torch.cfloat)) +def _test_reference_1d(self, device, dtype, op): + if op.ref is None: + raise unittest.SkipTest("No reference implementation") + + norm_modes = REFERENCE_NORM_MODES + test_args = [ + *product( + # input + (torch.randn(67, device=device, dtype=dtype), + torch.randn(80, device=device, dtype=dtype), + torch.randn(12, 14, device=device, dtype=dtype), + torch.randn(9, 6, 3, device=device, dtype=dtype)), + # n + (None, 50, 6), + # dim + (-1, 0), + # norm + norm_modes + ), + # Test transforming middle dimensions of multi-dim tensor + *product( + (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),), + (None,), + (1, 2, -2,), + norm_modes + ) + ] + + for iargs in test_args: + args = list(iargs) + input = args[0] + args = args[1:] + + expected = op.ref(input.cpu().numpy(), *args) + exact_dtype = dtype in (torch.double, torch.complex128) + actual = op(input, *args) + self.assertEqual(actual, expected, exact_dtype=exact_dtype, atol=1e-4, rtol=1e-5) + +TestFFT.test_reference_1d = _test_reference_1d + +instantiate_device_type_tests(TestFFT, globals(), only_for=("xpu"), allow_xpu=True) + + +if __name__ == "__main__": + run_tests() diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 1d18a27e2..92e0011ec 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -31,7 +31,6 @@ _xpu_computation_op_list = [ "empty", "eye", - "fill", "zeros", "zeros_like", "clone", @@ -74,6 +73,26 @@ "exp2", "expm1", "exponential", + "fft.fft", + "fft.fft2", + "fft.fftn", + "fft.hfft", + "fft.hfft2", + "fft.hfftn", + "fft.rfft", + "fft.rfft2", + "fft.rfftn", + "fft.ifft", + "fft.ifft2", + "fft.ifftn", + "fft.ihfft", + "fft.ihfft2", + "fft.ihfftn", + "fft.irfft", + "fft.irfft2", + "fft.irfftn", + "fft.fftshift", + "fft.ifftshift", "fill", "fmod", "gcd", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 7a257f0fd..1c354cc5c 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -8579,3 +8579,14 @@ dispatch: SparseXPU: copy_sparse_ autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out + +# Standard complex to complex FFT (forward or backward) +- func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + variants: function + dispatch: + XPU: _fft_c2c_xpu + +- func: _fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + variants: function + dispatch: + XPU: _fft_c2c_xpu_out