Skip to content

Commit

Permalink
correct shm tensor conversion and handle crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 3, 2024
1 parent 05d1675 commit 6825604
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,11 @@ private void launchModelLoadOnProcess() throws IOException, InterruptedException
throw new RuntimeException("Model loading task was cancelled");
else if (task.status == TaskStatus.FAILED)
    throw new RuntimeException("Model loading task failed in the worker process");
else if (task.status == TaskStatus.CRASHED) {
    // The worker process died: dispose of it so a stale handle is never reused
    // and a fresh worker is spawned on the next attempt.
    this.runner.close();
    runner = null;
    throw new RuntimeException("Worker process crashed while loading the model");
}
}

/**
Expand Down Expand Up @@ -360,8 +363,11 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
throw new RuntimeException("Inference task was cancelled");
else if (task.status == TaskStatus.FAILED)
    throw new RuntimeException("Inference task failed in the worker process");
else if (task.status == TaskStatus.CRASHED) {
    // The worker crashed mid-inference: close it and null the handle so it is
    // not reused in a corrupt state; a new worker will be created later.
    this.runner.close();
    runner = null;
    throw new RuntimeException("Worker process crashed while running inference");
}
for (int i = 0; i < outputTensors.size(); i ++) {
String name = (String) Types.decode(encOuts.get(i)).get(MEM_NAME_KEY);
SharedMemoryArray shm = shmaOutputList.stream()
Expand Down Expand Up @@ -491,8 +497,11 @@ public void closeModel() {
throw new RuntimeException("Close-model task was cancelled");
else if (task.status == TaskStatus.FAILED)
    throw new RuntimeException("Close-model task failed in the worker process");
else if (task.status == TaskStatus.CRASHED) {
    // Worker already died; release the handle before propagating the error.
    this.runner.close();
    runner = null;
    throw new RuntimeException("Worker process crashed while closing the model");
}
// Normal shutdown path: close the worker and drop the reference.
this.runner.close();
this.runner = null;
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
ByteBuffer buff = shma.getDataBufferNoHeader();
// The shm-backed buffer is direct (no backing array), so buff.array() would
// throw; stage the tensor bytes in a heap array first...
byte[] flat = new byte[buff.capacity()];
tensor.asRawTensor().data().read(flat, 0, flat.length);
// ...then actually copy them into the shared-memory region. Merely wrapping
// the array and reassigning the local variable would leave the shm empty.
buff.put(flat);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -114,7 +117,10 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
ByteBuffer buff = shma.getDataBufferNoHeader();
// Direct shm buffer has no backing array: read into a heap array, then write
// the bytes into shared memory (the write was missing and left shm empty).
byte[] flat = new byte[buff.capacity()];
tensor.asRawTensor().data().read(flat, 0, flat.length);
buff.put(flat);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -127,7 +133,10 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
ByteBuffer buff = shma.getDataBufferNoHeader();
// Direct shm buffer has no backing array: read into a heap array, then write
// the bytes into shared memory (the write was missing and left shm empty).
byte[] flat = new byte[buff.capacity()];
tensor.asRawTensor().data().read(flat, 0, flat.length);
buff.put(flat);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -140,7 +149,10 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
ByteBuffer buff = shma.getDataBufferNoHeader();
// Direct shm buffer has no backing array: read into a heap array, then write
// the bytes into shared memory (the write was missing and left shm empty).
byte[] flat = new byte[buff.capacity()];
tensor.asRawTensor().data().read(flat, 0, flat.length);
buff.put(flat);
if (PlatformDetection.isWindows()) shma.close();
}

Expand All @@ -154,7 +166,10 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws

SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
ByteBuffer buff = shma.getDataBufferNoHeader();
// Direct shm buffer has no backing array: read into a heap array, then write
// the bytes into shared memory (the write was missing and left shm empty).
byte[] flat = new byte[buff.capacity()];
tensor.asRawTensor().data().read(flat, 0, flat.length);
buff.put(flat);
if (PlatformDetection.isWindows()) shma.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
import net.imglib2.util.Cast;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;

import org.tensorflow.Tensor;
Expand Down Expand Up @@ -102,7 +98,10 @@ private static TUint8 buildUByte(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
    throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
// Copy the shared-memory bytes into a heap array: the shm buffer is direct, so
// buff.array() is unsupported. Rewind afterwards in case the buffer is reused.
byte[] flat = new byte[buff.capacity()];
buff.get(flat);
buff.rewind();
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
TUint8 ndarray = Tensor.of(TUint8.class, Shape.of(ogShape), dataBuffer);
return ndarray;
}
Expand All @@ -117,8 +116,10 @@ private static TInt32 buildInt(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
    throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
// Bulk-copy via an int view: the direct shm buffer has no backing array, so
// the previous asIntBuffer().array() approach is unsupported.
int[] flat = new int[buff.capacity() / 4];
buff.asIntBuffer().get(flat);
buff.rewind();
IntDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -134,8 +135,10 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
    throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
// Bulk-copy via a long view: the direct shm buffer has no backing array, so
// the previous asLongBuffer().array() approach is unsupported.
long[] flat = new long[buff.capacity() / 8];
buff.asLongBuffer().get(flat);
buff.rewind();
LongDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
Expand All @@ -151,8 +154,10 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
    throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
// Bulk-copy via a float view: the direct shm buffer has no backing array, so
// the previous asFloatBuffer().array() approach is unsupported.
float[] flat = new float[buff.capacity() / 4];
buff.asFloatBuffer().get(flat);
buff.rewind();
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
Expand All @@ -167,8 +172,10 @@ private static TFloat64 buildDouble(SharedMemoryArray tensor)
if (!tensor.isNumpyFormat())
    throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
ByteBuffer buff = tensor.getDataBufferNoHeader();
// Bulk-copy via a double view: the direct shm buffer has no backing array, so
// the previous asDoubleBuffer().array() approach is unsupported.
double[] flat = new double[buff.capacity() / 8];
buff.asDoubleBuffer().get(flat);
buff.rewind();
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
return ndarray;
}
Expand Down

0 comments on commit 6825604

Please sign in to comment.