diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java index bc0a7c0..f0eb079 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java @@ -140,7 +140,7 @@ public static TUint8 buildUByte(RandomAccessibleInterval tenso int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TUint8 ndarray = Tensor.of(TUint8.class, Shape.of(ogShape), dataBuffer); return ndarray; @@ -169,7 +169,7 @@ public static TInt32 buildInt(RandomAccessibleInterval tensor) int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); IntDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer); @@ -199,7 +199,7 @@ private static TInt64 buildLong(RandomAccessibleInterval tensor) int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); LongDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer); @@ -230,7 +230,7 @@ public static TFloat32 buildFloat( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray; @@ -260,7 +260,7 @@ private static TFloat64 buildDouble( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( new long[tensorShape.length], flatArr, sArr ); + blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray;