From 6cbde22744ffb965a5cdb7ab51188cb0687c7dfd Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 28 Nov 2023 14:32:57 +0100 Subject: [PATCH] improve the robustness of the creation of tf tensors --- .../v2/api050/tensor/TensorBuilder.java | 66 +++++++++++++++---- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java index f0eb079..122b465 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java @@ -22,7 +22,7 @@ package io.bioimage.modelrunner.tensorflow.v2.api050.tensor; import io.bioimage.modelrunner.tensor.Utils; - +import net.imglib2.Cursor; import net.imglib2.RandomAccessibleInterval; import net.imglib2.blocks.PrimitiveBlocks; import net.imglib2.img.Img; @@ -33,6 +33,10 @@ import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.util.Util; +import net.imglib2.view.Views; + +import java.nio.ByteBuffer; +import java.util.Arrays; import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; @@ -131,8 +135,10 @@ public static TUint8 buildUByte(RandomAccessibleInterval tenso throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< UnsignedByteType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -140,7 +146,13 @@ public static TUint8 buildUByte(RandomAccessibleInterval tenso int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().getByte(); + } ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TUint8 ndarray = Tensor.of(TUint8.class, Shape.of(ogShape), dataBuffer); return ndarray; @@ -160,8 +172,10 @@ public static TInt32 buildInt(RandomAccessibleInterval tensor) throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -169,7 +183,13 @@ public static TInt32 buildInt(RandomAccessibleInterval tensor) int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } IntDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer); @@ -190,8 +210,10 @@ private static TInt64 buildLong(RandomAccessibleInterval tensor) throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< LongType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -199,7 +221,13 @@ private static TInt64 buildLong(RandomAccessibleInterval tensor) int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } LongDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer); @@ -221,8 +249,10 @@ public static TFloat32 buildFloat( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -230,7 +260,13 @@ public static TFloat32 buildFloat( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray; @@ -251,8 +287,10 @@ private static TFloat64 buildDouble( throws IllegalArgumentException { long[] ogShape = tensor.dimensionsAsLongArray(); + if (CommonUtils.int32Overflows(ogShape)) + throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape) + + " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE); tensor = Utils.transpose(tensor); - PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor ); long[] tensorShape = tensor.dimensionsAsLongArray(); int size = 1; for (long ll : tensorShape) size *= ll; @@ -260,7 +298,13 @@ private static TFloat64 buildDouble( int[] sArr = new int[tensorShape.length]; for (int i = 0; i < sArr.length; i ++) sArr[i] = (int) tensorShape[i]; - blocks.copy( tensor.minAsLongArray(), flatArr, sArr ); + + Cursor cursor = Views.flatIterable(tensor).cursor(); + int i = 0; + while (cursor.hasNext()) { + cursor.fwd(); + flatArr[i ++] = cursor.get().get(); + } DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false); TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer); return ndarray;