diff --git a/engines/pytorch/pytorch-jni/build.gradle b/engines/pytorch/pytorch-jni/build.gradle index 450c832e8032..c2b0ee9dc7b7 100644 --- a/engines/pytorch/pytorch-jni/build.gradle +++ b/engines/pytorch/pytorch-jni/build.gradle @@ -24,7 +24,13 @@ processResources { "osx-x86_64/cpu/libdjl_torch.dylib", "win-x86_64/cpu/djl_torch.dll" ] - if (ptVersion.startsWith("2.0.")) { + if (ptVersion.startsWith("2.1.")) { + files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so") + files.add("linux-x86_64/cu121/libdjl_torch.so") + files.add("linux-x86_64/cu121-precxx11/libdjl_torch.so") + files.add("win-x86_64/cu121/djl_torch.dll") + files.add("osx-aarch64/cpu/libdjl_torch.dylib") + } else if (ptVersion.startsWith("2.0.")) { files.add("linux-aarch64/cpu-precxx11/libdjl_torch.so") files.add("linux-x86_64/cu118/libdjl_torch.so") files.add("linux-x86_64/cu118-precxx11/libdjl_torch.so") diff --git a/engines/pytorch/pytorch-native/CMakeLists.txt b/engines/pytorch/pytorch-native/CMakeLists.txt index 4453186be6f4..4eddbfc3566e 100644 --- a/engines/pytorch/pytorch-native/CMakeLists.txt +++ b/engines/pytorch/pytorch-native/CMakeLists.txt @@ -64,7 +64,7 @@ add_library(djl_torch SHARED ${SOURCE_FILES}) if(NOT BUILD_ANDROID) target_link_libraries(djl_torch "${TORCH_LIBRARIES}") target_include_directories(djl_torch PUBLIC build/include ${JNI_INCLUDE_DIRS} ${UTILS_INCLUDE_DIR}) - set_property(TARGET djl_torch PROPERTY CXX_STANDARD 14) + set_property(TARGET djl_torch PROPERTY CXX_STANDARD 17) # We have to kill the default rpath and use current dir set(CMAKE_SKIP_RPATH TRUE) if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") diff --git a/engines/pytorch/pytorch-native/build.gradle b/engines/pytorch/pytorch-native/build.gradle index b4a195e109f0..99a658bf3ed2 100644 --- a/engines/pytorch/pytorch-native/build.gradle +++ b/engines/pytorch/pytorch-native/build.gradle @@ -24,6 +24,8 @@ if (project.hasProperty("cu11")) { FLAVOR = "cu117" } else if (VERSION.startsWith("2.0.")) { FLAVOR = "cu118" + } else if (VERSION.startsWith("2.1.")) { + FLAVOR = "cu121" } else { throw new GradleException("Unsupported PyTorch version: ${VERSION}") } @@ -88,15 +90,17 @@ def prepareNativeLib(String binaryRoot, String ver) { def officialPytorchUrl = "https://download.pytorch.org/libtorch" def aarch64PytorchUrl = "https://djl-ai.s3.amazonaws.com/publish/pytorch" - String cu11 + String cuda if (ver.startsWith("1.11.")) { - cu11 = "cu113" + cuda = "cu113" } else if (ver.startsWith("1.12.")) { - cu11 = "cu116" + cuda = "cu116" } else if (ver.startsWith("1.13.")) { - cu11 = "cu117" + cuda = "cu117" } else if (ver.startsWith("2.0.")) { - cu11 = "cu118" + cuda = "cu118" + } else if (ver.startsWith("2.1.")) { + cuda = "cu121" } else { throw new GradleException("Unsupported PyTorch version: ${ver}") } @@ -105,10 +109,10 @@ def prepareNativeLib(String binaryRoot, String ver) { "cpu/libtorch-cxx11-abi-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/linux-x86_64", "cpu/libtorch-macos-${ver}.zip" : "cpu/osx-x86_64", "cpu/libtorch-win-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/win-x86_64", - "${cu11}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cu11}.zip": "${cu11}/linux-x86_64", - "${cu11}/libtorch-win-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}/win-x86_64", + "${cuda}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cuda}.zip": "${cuda}/linux-x86_64", + "${cuda}/libtorch-win-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}/win-x86_64", "cpu/libtorch-shared-with-deps-${ver}%2Bcpu.zip" : "cpu-precxx11/linux-x86_64", - "${cu11}/libtorch-shared-with-deps-${ver}%2B${cu11}.zip" : "${cu11}-precxx11/linux-x86_64" + "${cuda}/libtorch-shared-with-deps-${ver}%2B${cuda}.zip" : "${cuda}-precxx11/linux-x86_64" ] def aarch64Files = [ @@ -138,17 +142,12 @@ def copyNativeLibToOutputDir(Map fileStoreMap, String binaryRoot from zipTree(file) into outputDir } - // CPU dependencies - copy { - from("${outputDir}/libtorch/lib/") { - include "libc10.*", "c10.dll", "libiomp5*.*", "libarm_compute*.*", "libgomp*.*", "libnvfuser_codegen.so", "libtorch.*", "libtorch_cpu.*", "torch.dll", "torch_cpu.dll", "fbgemm.dll", "asmjit.dll", "uv.dll", "nvfuser_codegen.dll" - } - into("${outputDir}/native/lib") - } - // GPU dependencies + delete "${outputDir}/libtorch/lib/*.lib" + delete "${outputDir}/libtorch/lib/*.a" + copy { from("${outputDir}/libtorch/lib/") { - include "libtorch_cuda*.so", "torch_cuda*.dll", "libc10_cuda.so", "c10_cuda.dll", "libcaffe2_nvrtc.so", "libnvrtc*.so.*", "libcudart*.*", "*nvToolsExt*.*", "cudnn*.dll", "caffe2_nvrtc.dll", "nvrtc64*.dll", "uv.dll", "libcublas*", "zlibwapi.dll" + include "libarm_compute*", "libc10_cuda.so", "libc10.*", "libcaffe2_nvrtc.so", "libcu*", "libgfortran-*", "libgomp*", "libiomp*", "libnv*", "libopenblasp-*", "libtorch_cpu.*", "libtorch_cuda*.so", "libtorch.*", "asmjit.dll", "c10_cuda.dll", "c10.dll", "caffe2_nvrtc.dll", "cu*.dll", "fbgemm.dll", "nv*.dll", "torch_cpu.dll", "torch_cuda*.dll", "torch.dll", "uv.dll", "zlibwapi.dll" } into("${outputDir}/native/lib") } @@ -287,9 +286,9 @@ tasks.register('uploadS3') { "${BINARY_ROOT}/cpu/win-x86_64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-aarch64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-x86_64/native/lib/", - "${BINARY_ROOT}/cu118/linux-x86_64/native/lib/", - "${BINARY_ROOT}/cu118/win-x86_64/native/lib/", - "${BINARY_ROOT}/cu118-precxx11/linux-x86_64/native/lib/" + "${BINARY_ROOT}/cu121/linux-x86_64/native/lib/", + "${BINARY_ROOT}/cu121/win-x86_64/native/lib/", + "${BINARY_ROOT}/cu121-precxx11/linux-x86_64/native/lib/" ] uploadDirs.each { item -> fileTree(item).files.name.each { diff --git a/engines/pytorch/pytorch-native/build.sh b/engines/pytorch/pytorch-native/build.sh index 78c59d6bf2a7..d8060144ef39 100755 --- a/engines/pytorch/pytorch-native/build.sh +++ b/engines/pytorch/pytorch-native/build.sh @@ -23,7 +23,7 @@ ARCH=$4 if [[ ! -d "libtorch" ]]; then if [[ $PLATFORM == 'linux' ]]; then - if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118)$ ]]; then + if [[ ! "$FLAVOR" =~ ^(cpu|cu102|cu113|cu116|cu117|cu118|cu121)$ ]]; then echo "$FLAVOR is not supported." exit 1 fi diff --git a/gradle.properties b/gradle.properties index 87dc4fe5a15a..ff506da43cae 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,7 +13,7 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true djl_version=0.24.0 mxnet_version=1.9.1 -pytorch_version=2.0.1 +pytorch_version=2.1.0 tensorflow_version=2.10.1 tflite_version=2.6.2 trt_version=8.4.1