-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Init MKL for Pytorch XPU and enable fft_c2c
- Loading branch information
Showing
15 changed files
with
674 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
set(ONEMKL_FOUND FALSE) | ||
|
||
set(ONEMKL_LIBRARIES) | ||
|
||
# In order to be compatible with various situations of Pytorch development | ||
# bundle setup, ENV{MKLROOT} and SYCL_ROOT will be checked sequentially to get | ||
# the root directory of oneMKL. | ||
if(DEFINED ENV{MKLROOT}) | ||
# Directly get the root directory of oneMKL if ENV{MKLROOT} exists. | ||
set(ONEMKL_ROOT $ENV{MKLROOT}) | ||
elseif(SYCL_FOUND) | ||
# oneMKL configuration may not be imported into the build system. Get the root | ||
# directory of oneMKL based on the root directory of compiler relatively. | ||
get_filename_component(ONEMKL_ROOT "${SYCL_ROOT}/../../mkl/latest" REALPATH) | ||
endif() | ||
|
||
if(NOT DEFINED ONEMKL_ROOT) | ||
message( | ||
WARNING | ||
"Cannot find either ENV{MKLROOT} or SYCL_ROOT, please setup oneAPI environment before building!!" | ||
) | ||
return() | ||
endif() | ||
|
||
if(NOT EXISTS ${ONEMKL_ROOT}) | ||
message( | ||
WARNING | ||
"${ONEMKL_ROOT} not found, please setup oneAPI environment before building!!" | ||
) | ||
return() | ||
endif() | ||
|
||
find_file( | ||
ONEMKL_INCLUDE_DIR | ||
NAMES include | ||
HINTS ${ONEMKL_ROOT} | ||
NO_DEFAULT_PATH) | ||
|
||
find_file( | ||
ONEMKL_LIB_DIR | ||
NAMES lib | ||
HINTS ${ONEMKL_ROOT} | ||
NO_DEFAULT_PATH) | ||
|
||
if((ONEMKL_INCLUDE_DIR STREQUAL "ONEMKL_INCLUDE_DIR-NOTFOUND") | ||
OR (ONEMKL_LIB_DIR STREQUAL "ONEMKL_LIB_DIR-NOTFOUND")) | ||
message(WARNING "oneMKL sdk is incomplete!!") | ||
return() | ||
endif() | ||
|
||
if(WIN32) | ||
set(MKL_LIB_NAMES "mkl_sycl" "mkl_intel_lp64" "mkl_intel_thread" "mkl_core") | ||
else() | ||
set(MKL_LIB_NAMES "mkl_sycl_dft" "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core") | ||
endif() | ||
|
||
foreach(LIB_NAME IN LISTS MKL_LIB_NAMES) | ||
find_library( | ||
${LIB_NAME}_library | ||
NAMES ${LIB_NAME} | ||
HINTS ${ONEMKL_LIB_DIR} | ||
NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH) | ||
list(APPEND ONEMKL_LIBRARIES ${${LIB_NAME}_library}) | ||
endforeach() | ||
|
||
set(ONEMKL_FOUND TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
find_package(ONEMKL) | ||
if(NOT ONEMKL_FOUND) | ||
message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!") | ||
endif() | ||
|
||
set(TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR ${ONEMKL_INCLUDE_DIR}) | ||
|
||
set(TORCH_XPU_OPS_ONEMKL_LIBRARIES ${ONEMKL_LIBRARIES}) | ||
|
||
list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 1 "-Wl,--start-group") | ||
list(APPEND TORCH_XPU_OPS_ONEMKL_LIBRARIES "-Wl,--end-group") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#include <ATen/native/Resize.h> | ||
#include <ATen/native/xpu/mkl/SpectralOps.h> | ||
#include <comm/xpu_aten.h> | ||
|
||
namespace at::native { | ||
|
||
Tensor _fft_c2c_xpu( | ||
const Tensor& self, | ||
IntArrayRef dim, | ||
int64_t normalization, | ||
bool forward) { | ||
TORCH_CHECK(self.is_complex()); | ||
|
||
return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); | ||
} | ||
|
||
Tensor& _fft_c2c_xpu_out( | ||
const Tensor& self, | ||
IntArrayRef dim, | ||
int64_t normalization, | ||
bool forward, | ||
Tensor& out) { | ||
TORCH_CHECK(self.is_complex()); | ||
|
||
return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out); | ||
} | ||
|
||
} // namespace at::native |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.