Skip to content

Commit

Permalink
improve creatoon of tensors from mglib2
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent 9e064ba commit 8343fbd
Showing 1 changed file with 52 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
*/
package io.bioimage.modelrunner.tensorflow.v2.api050.tensor;

import io.bioimage.modelrunner.utils.IndexingUtils;
import io.bioimage.modelrunner.tensor.Utils;

import net.imglib2.Cursor;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
Expand All @@ -41,7 +40,7 @@
import org.tensorflow.types.family.TType;

/**
* A {@link Img} builder for TensorFlow {@link Tensor} objects.
* A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects.
* Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor})
* from Tensorflow 2 {@link Tensor}
*
Expand All @@ -57,36 +56,36 @@ private ImgLib2Builder()
}

/**
* Creates a {@link Img} from a given {@link TType} tensor
* Creates a {@link RandomAccessibleInterval} from a given {@link TType} tensor
*
* @param <T>
* the possible ImgLib2 datatypes of the image
* @param tensor
* The {@link TType} tensor data is read from.
* @return The {@link Img} built from the {@link TType} tensor.
* @return The {@link RandomAccessibleInterval} built from the {@link TType} tensor.
* @throws IllegalArgumentException If the {@link TType} tensor type is not supported.
*/
public static <T extends Type<T>> Img<T> build(TType tensor) throws IllegalArgumentException
public static <T extends Type<T>> RandomAccessibleInterval<T> build(TType tensor) throws IllegalArgumentException
{
if (tensor instanceof TUint8)
{
return (Img<T>) buildFromTensorUByte((TUint8) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorUByte((TUint8) tensor);
}
else if (tensor instanceof TInt32)
{
return (Img<T>) buildFromTensorInt((TInt32) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorInt((TInt32) tensor);
}
else if (tensor instanceof TFloat32)
{
return (Img<T>) buildFromTensorFloat((TFloat32) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorFloat((TFloat32) tensor);
}
else if (tensor instanceof TFloat64)
{
return (Img<T>) buildFromTensorDouble((TFloat64) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorDouble((TFloat64) tensor);
}
else if (tensor instanceof TInt64)
{
return (Img<T>) buildFromTensorLong((TInt64) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorLong((TInt64) tensor);
}
else
{
Expand All @@ -95,140 +94,102 @@ else if (tensor instanceof TInt64)
}

/**
* Builds a {@link Img} from a unsigned byte-typed {@link TUint8} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor.
*
* @param tensor
* The {@link TUint8} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link UnsignedByteType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}.
*/
private static Img<UnsignedByteType> buildFromTensorUByte(TUint8 tensor)
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(TUint8 tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< UnsignedByteType > factory = new ArrayImgFactory<>( new UnsignedByteType() );
final Img< UnsignedByteType > outputImg = factory.create(tensorShape);
Cursor<UnsignedByteType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
byte[] flatArr = new byte[totalSize];
tensor.asRawTensor().data().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
byte val = flatArr[flatPos];
if (val < 0)
tensorCursor.get().set(256 + (int) val);
else
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<UnsignedByteType> rai = ArrayImgs.unsignedBytes(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned int32-typed {@link TInt32} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor.
*
* @param tensor
* The {@link TInt32} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link IntType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}.
*/
private static Img<IntType> buildFromTensorInt(TInt32 tensor)
private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< IntType > factory = new ArrayImgFactory<>( new IntType() );
final Img< IntType > outputImg = factory.create(tensorShape);
Cursor<IntType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
int[] flatArr = new int[totalSize];
tensor.asRawTensor().data().asInts().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
int val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<IntType> rai = ArrayImgs.ints(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned float32-typed {@link TFloat32} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor.
*
* @param tensor
* The {@link TFloat32} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link FloatType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}.
*/
private static Img<FloatType> buildFromTensorFloat(TFloat32 tensor)
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32 tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< FloatType > factory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > outputImg = factory.create(tensorShape);
Cursor<FloatType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
float[] flatArr = new float[totalSize];
tensor.asRawTensor().data().asFloats().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
float val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned float64-typed {@link TFloat64} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor.
*
* @param tensor
* The {@link TFloat64} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link DoubleType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}.
*/
private static Img<DoubleType> buildFromTensorDouble(TFloat64 tensor)
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat64 tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< DoubleType > factory = new ArrayImgFactory<>( new DoubleType() );
final Img< DoubleType > outputImg = factory.create(tensorShape);
Cursor<DoubleType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
double[] flatArr = new double[totalSize];
tensor.asRawTensor().data().asDoubles().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
double val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<DoubleType> rai = ArrayImgs.doubles(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned int64-typed {@link TInt64} tensor.
* Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor.
*
* @param tensor
* The {@link TInt64} tensor data is read from.
* @return The {@link Img} built from the tensor, of type {@link LongType}.
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}.
*/
private static Img<LongType> buildFromTensorLong(TInt64 tensor)
private static RandomAccessibleInterval<LongType> buildFromTensorLong(TInt64 tensor)
{
long[] tensorShape = tensor.shape().asArray();
final ArrayImgFactory< LongType > factory = new ArrayImgFactory<>( new LongType() );
final Img< LongType > outputImg = factory.create(tensorShape);
Cursor<LongType> tensorCursor= outputImg.cursor();
long[] arrayShape = tensor.shape().asArray();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {totalSize *= i;}
long[] flatArr = new long[totalSize];
tensor.asRawTensor().data().asLongs().read(flatArr);
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos, tensorShape);
long val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<LongType> rai = ArrayImgs.longs(flatArr, tensorShape);
return Utils.transpose(rai);
}
}

0 comments on commit 8343fbd

Please sign in to comment.