Skip to content

Commit

Permalink
improve the robustness of the creation of tf tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 28, 2023
1 parent 1def8ab commit 6cbde22
Showing 1 changed file with 55 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package io.bioimage.modelrunner.tensorflow.v2.api050.tensor;

import io.bioimage.modelrunner.tensor.Utils;

import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.img.Img;
Expand All @@ -33,6 +33,10 @@
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

import java.nio.ByteBuffer;
import java.util.Arrays;

import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
Expand Down Expand Up @@ -131,16 +135,24 @@ public static TUint8 buildUByte(RandomAccessibleInterval<UnsignedByteType> tenso
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< UnsignedByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final byte[] flatArr = new byte[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<UnsignedByteType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().getByte();
}
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
TUint8 ndarray = Tensor.of(TUint8.class, Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -160,16 +172,24 @@ public static TInt32 buildInt(RandomAccessibleInterval<IntType> tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final int[] flatArr = new int[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<IntType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
IntDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape),
dataBuffer);
Expand All @@ -190,16 +210,24 @@ private static TInt64 buildLong(RandomAccessibleInterval<LongType> tensor)
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< LongType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final long[] flatArr = new long[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<LongType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
LongDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape),
dataBuffer);
Expand All @@ -221,16 +249,24 @@ public static TFloat32 buildFloat(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final float[] flatArr = new float[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<FloatType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -251,16 +287,24 @@ private static TFloat64 buildDouble(
throws IllegalArgumentException
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final double[] flatArr = new double[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<DoubleType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flatArr, false);
TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand Down

0 comments on commit 6cbde22

Please sign in to comment.