diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java
index 0423725a60d..83b01f7a1e8 100644
--- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java
+++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java
@@ -12,8 +12,11 @@
  */
 package ai.djl.pytorch.integration;
 
+import ai.djl.engine.EngineException;
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
 
 import org.testng.Assert;
 import org.testng.annotations.Test;
@@ -34,4 +37,12 @@ public void testStringTensor() {
             Assert.assertThrows(UnsupportedOperationException.class, () -> arr.get(0));
         }
     }
+
+    @Test
+    public void testLargeTensor() {
+        try (NDManager manager = NDManager.newBaseManager()) {
+            NDArray array = manager.zeros(new Shape(10 * 2850, 18944), DataType.FLOAT32);
+            Assert.assertThrows(EngineException.class, array::toByteArray);
+        }
+    }
 }
diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc
index f33f1061bf6..e6aa2c145b4 100644
--- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc
+++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc
@@ -308,8 +308,13 @@ JNIEXPORT jbyteArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDataPtr
   // sparse and mkldnn are required to be converted to dense to access data ptr
   auto tensor = (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) ? tensor_ptr->to_dense() : *tensor_ptr;
   tensor = (tensor.is_contiguous()) ? tensor : tensor.contiguous();
-  jbyteArray result = env->NewByteArray(tensor.nbytes());
-  env->SetByteArrayRegion(result, 0, tensor.nbytes(), static_cast<jbyte*>(tensor.data_ptr()));
+  size_t nbytes = tensor.nbytes();
+  if (nbytes > 0x7fffffff) {
+    env->ThrowNew(ENGINE_EXCEPTION_CLASS, "toByteBuffer() is not supported for large tensor");
+    return env->NewByteArray(0);
+  }
+  jbyteArray result = env->NewByteArray(nbytes);
+  env->SetByteArrayRegion(result, 0, nbytes, static_cast<jbyte*>(tensor.data_ptr()));
   return result;
   API_END_RETURN()
 }
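
For context on why the new test trips the guard: a JNI byte array is indexed with a signed 32-bit jsize, so NewByteArray cannot hold more than 0x7fffffff (2,147,483,647) bytes. The sketch below is not part of the patch; it only reuses the sizes already in the test to show the arithmetic, assuming 4 bytes per FLOAT32 element.

public class LargeTensorSizeCheck {
    public static void main(String[] args) {
        // Shape from the test: (10 * 2850) x 18944 elements, FLOAT32 = 4 bytes each
        long rows = 10L * 2850;                       // 28,500
        long cols = 18944;
        long bytesPerElement = 4;                     // DataType.FLOAT32
        long nbytes = rows * cols * bytesPerElement;  // 2,159,616,000 bytes
        // The guard added in the C++ change rejects anything above 0x7fffffff
        System.out.println(nbytes + " > " + Integer.MAX_VALUE + " : " + (nbytes > Integer.MAX_VALUE));
        // prints: 2159616000 > 2147483647 : true, so toByteArray() now throws EngineException
    }
}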