diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java index 1b5dccd..3984a93 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java @@ -36,7 +36,6 @@ import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TType; -import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.numeric.integer.IntType; import net.imglib2.type.numeric.integer.LongType; import net.imglib2.type.numeric.integer.UnsignedByteType; @@ -44,11 +43,10 @@ import net.imglib2.type.numeric.real.FloatType; /** - * A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects. - * Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor}) - * from Tensorflow 2 {@link Tensor} + * A utility class that converts {@link Tensor}s into {@link SharedMemoryArray}s for + * interprocessing communication * - * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + * @author Carlos Garcia Lopez de Haro */ public final class ShmBuilder { @@ -59,16 +57,15 @@ private ShmBuilder() { } - /** - * Creates a {@link RandomAccessibleInterval} from a given {@link TType} tensor - * - * @param - * the possible ImgLib2 datatypes of the image - * @param tensor - * The {@link TType} tensor data is read from. - * @throws IllegalArgumentException If the {@link TType} tensor type is not supported. - * @throws IOException - */ + /** + * Create a {@link SharedMemoryArray} from a {@link Tensor} + * @param tensor + * the tensor to be passed into the other process through the shared memory + * @param memoryName + * the name of the memory region where the tensor is going to be copied + * @throws IllegalArgumentException if the data type of the tensor is not supported + * @throws IOException if there is any error creating the shared memory array + */ @SuppressWarnings("unchecked") public static void build(Tensor tensor, String memoryName) throws IllegalArgumentException, IOException { @@ -89,14 +86,6 @@ public static void build(Tensor tensor, String memoryName) thr } } - /** - * Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor. - * - * @param tensor - * The {@link TUint8} tensor data is read from. - * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}. - * @throws IOException - */ private static void buildFromTensorUByte(Tensor tensor, String memoryName) throws IOException { long[] arrayShape = tensor.shape().asArray(); @@ -114,14 +103,6 @@ private static void buildFromTensorUByte(Tensor tensor, String memoryNam if (PlatformDetection.isWindows()) shma.close(); } - /** - * Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor. - * - * @param tensor - * The {@link TInt32} tensor data is read from. - * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}. - * @throws IOException - */ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws IOException { long[] arrayShape = tensor.shape().asArray(); @@ -140,14 +121,6 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) if (PlatformDetection.isWindows()) shma.close(); } - /** - * Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor. - * - * @param tensor - * The {@link TFloat32} tensor data is read from. - * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}. - * @throws IOException - */ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throws IOException { long[] arrayShape = tensor.shape().asArray(); @@ -166,14 +139,6 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryN if (PlatformDetection.isWindows()) shma.close(); } - /** - * Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor. - * - * @param tensor - * The {@link TFloat64} tensor data is read from. - * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}. - * @throws IOException - */ private static void buildFromTensorDouble(Tensor tensor, String memoryName) throws IOException { long[] arrayShape = tensor.shape().asArray(); @@ -192,14 +157,6 @@ private static void buildFromTensorDouble(Tensor tensor, String memory if (PlatformDetection.isWindows()) shma.close(); } - /** - * Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor. - * - * @param tensor - * The {@link TInt64} tensor data is read from. - * @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}. - * @throws IOException - */ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws IOException { long[] arrayShape = tensor.shape().asArray(); diff --git a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java index 7b7e84b..9714139 100644 --- a/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java +++ b/src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java @@ -55,10 +55,9 @@ import org.tensorflow.types.family.TType; /** - * A TensorFlow 2 {@link Tensor} builder from {@link Img} and - * {@link io.bioimage.modelrunner.tensor.Tensor} objects. + * Utility class to build Tensorflow tensors from shm segments using {@link SharedMemoryArray} * - * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando + * @author Carlos Garcia Lopez de Haro */ public final class TensorBuilder { @@ -68,16 +67,13 @@ public final class TensorBuilder { private TensorBuilder() {} /** - * Creates {@link TType} instance with the same size and information as the - * given {@link RandomAccessibleInterval}. + * Creates {@link Tensor} instance from a {@link SharedMemoryArray} * - * @param - * the ImgLib2 data types the {@link RandomAccessibleInterval} can be * @param array - * the {@link RandomAccessibleInterval} that is going to be converted into - * a {@link TType} tensor - * @return a {@link TType} tensor - * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval} + * the {@link SharedMemoryArray} that is going to be converted into + * a {@link Tensor} tensor + * @return the Tensorflow {@link Tensor} as the one stored in the shared memory segment + * @throws IllegalArgumentException if the type of the {@link SharedMemoryArray} * is not supported */ public static Tensor build(SharedMemoryArray array) throws IllegalArgumentException @@ -103,17 +99,7 @@ else if (array.getOriginalDataType().equals("int64")) { } } - /** - * Creates a {@link TType} tensor of type {@link TUint8} from an - * {@link RandomAccessibleInterval} of type {@link UnsignedByteType} - * - * @param tensor - * The {@link RandomAccessibleInterval} to fill the tensor with. - * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data. - * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is - * not compatible - */ - public static Tensor buildUByte(SharedMemoryArray tensor) + private static Tensor buildUByte(SharedMemoryArray tensor) throws IllegalArgumentException { long[] ogShape = tensor.getOriginalShape(); @@ -128,17 +114,7 @@ public static Tensor buildUByte(SharedMemoryArray tensor) return ndarray; } - /** - * Creates a {@link TInt32} tensor of type {@link TInt32} from an - * {@link RandomAccessibleInterval} of type {@link IntType} - * - * @param tensor - * The {@link RandomAccessibleInterval} to fill the tensor with. - * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data. - * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is - * not compatible - */ - public static Tensor buildInt(SharedMemoryArray tensor) + private static Tensor buildInt(SharedMemoryArray tensor) throws IllegalArgumentException { long[] ogShape = tensor.getOriginalShape(); @@ -156,16 +132,6 @@ public static Tensor buildInt(SharedMemoryArray tensor) return ndarray; } - /** - * Creates a {@link TInt64} tensor of type {@link TInt64} from an - * {@link RandomAccessibleInterval} of type {@link LongType} - * - * @param tensor - * The {@link RandomAccessibleInterval} to fill the tensor with. - * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data. - * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is - * not compatible - */ private static Tensor buildLong(SharedMemoryArray tensor) throws IllegalArgumentException { @@ -184,17 +150,7 @@ private static Tensor buildLong(SharedMemoryArray tensor) return ndarray; } - /** - * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an - * {@link RandomAccessibleInterval} of type {@link FloatType} - * - * @param tensor - * The {@link RandomAccessibleInterval} to fill the tensor with. - * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data. - * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is - * not compatible - */ - public static Tensor buildFloat(SharedMemoryArray tensor) + private static Tensor buildFloat(SharedMemoryArray tensor) throws IllegalArgumentException { long[] ogShape = tensor.getOriginalShape(); @@ -212,16 +168,6 @@ public static Tensor buildFloat(SharedMemoryArray tensor) return ndarray; } - /** - * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an - * {@link RandomAccessibleInterval} of type {@link DoubleType} - * - * @param tensor - * The {@link RandomAccessibleInterval} to fill the tensor with. - * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data. - * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is - * not compatible - */ private static Tensor buildDouble(SharedMemoryArray tensor) throws IllegalArgumentException {