From f3fe4a6376a4a3dd3c77da719588211ab3c547b9 Mon Sep 17 00:00:00 2001 From: Bulat Date: Tue, 20 Mar 2018 11:47:19 +0300 Subject: [PATCH] Fix/map tensor (#10) * Map reimplementation and cleanup * Typed tensor refactor and DataGenerator tests * Shape class and generation fixes --- .../contract/utils/ContractBuilders.scala | 38 +- .../contract/utils/DataGenerator.scala | 75 ++-- .../contract/utils/SignatureChecker.scala | 89 ---- .../serving/contract/utils/TypedTensor.scala | 181 --------- .../description/ContractDescription.scala | 4 +- .../utils/description/FieldDescription.scala | 2 +- .../description/SignatureDescription.scala | 45 +- .../contract/utils/ops/Implicits.scala | 2 + .../contract/utils/ops/ModelContractOps.scala | 4 +- .../contract/utils/ops/ModelFieldOps.scala | 90 ++-- .../utils/ops/ModelSignatureOps.scala | 34 +- .../contract/utils/ops/TensorInfoOps.scala | 38 -- .../contract/utils/validation/package.scala | 4 +- .../serving/tensorflow/TensorShape.scala | 37 ++ .../tensorflow/tensor/BoolTensor.scala | 25 ++ .../tensorflow/tensor/DComplexTensor.scala | 24 ++ .../tensorflow/tensor/DoubleTensor.scala | 24 ++ .../tensorflow/tensor/FloatTensor.scala | 24 ++ .../tensorflow/tensor/Int16Tensor.scala | 18 + .../tensorflow/tensor/Int32Tensor.scala | 18 + .../tensorflow/tensor/Int64Tensor.scala | 24 ++ .../tensorflow/tensor/Int8Tensor.scala | 18 + .../serving/tensorflow/tensor/IntTensor.scala | 15 + .../serving/tensorflow/tensor/MapTensor.scala | 33 ++ .../tensorflow/tensor/SComplexTensor.scala | 24 ++ .../tensorflow/tensor/StringTensor.scala | 25 ++ .../tensorflow/tensor/TensorProtoLens.scala | 13 + .../tensorflow/tensor/TypedTensor.scala | 24 ++ .../tensor/TypedTensorFactory.scala | 95 +++++ .../tensorflow/tensor/Uint16Tensor.scala | 18 + .../tensorflow/tensor/Uint32Tensor.scala | 18 + .../tensorflow/tensor/Uint64Tensor.scala | 23 ++ .../tensorflow/tensor/Uint8Tensor.scala | 18 + .../tensorflow/tensor/UintTensor.scala | 9 + .../tensorflow/utils/ops/DataTypeOps.scala | 15 + .../utils/ops/TensorShapeProtoOps.scala | 31 +- .../src/test/scala/ContractOpsSpecs.scala | 19 +- .../src/test/scala/DataGeneratorSpecs.scala | 164 ++++++++ .../test/scala/SignatureCheckerSpecs.scala | 384 ++++++------------ .../src/test/scala/TypedTensorSpecs.scala | 69 ++++ .../contract/model_field.proto | 15 +- src/hydro_serving_grpc/tf/api/predict.proto | 2 +- src/hydro_serving_grpc/tf/tensor.proto | 6 +- src/hydro_serving_grpc/tf/tensor_info.proto | 15 - src/hydro_serving_grpc/tf/tensor_slice.proto | 37 -- src/hydro_serving_grpc/tf/types.proto | 29 -- version | 2 +- 47 files changed, 1097 insertions(+), 824 deletions(-) delete mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/SignatureChecker.scala delete mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/TypedTensor.scala delete mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorInfoOps.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/TensorShape.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/BoolTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DComplexTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DoubleTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/FloatTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int16Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int32Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int64Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int8Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/IntTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/MapTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/SComplexTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/StringTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TensorProtoLens.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensorFactory.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint16Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint32Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint64Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint8Tensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/UintTensor.scala create mode 100644 scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/DataTypeOps.scala rename scala-package/src/main/scala/io/hydrosphere/serving/{contract => tensorflow}/utils/ops/TensorShapeProtoOps.scala (51%) create mode 100644 scala-package/src/test/scala/DataGeneratorSpecs.scala create mode 100644 scala-package/src/test/scala/TypedTensorSpecs.scala delete mode 100644 src/hydro_serving_grpc/tf/tensor_info.proto delete mode 100644 src/hydro_serving_grpc/tf/tensor_slice.proto diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ContractBuilders.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ContractBuilders.scala index 01cbd78..0c750ae 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ContractBuilders.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ContractBuilders.scala @@ -1,37 +1,18 @@ package io.hydrosphere.serving.contract.utils import io.hydrosphere.serving.contract.model_field.ModelField -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo +import io.hydrosphere.serving.tensorflow.TensorShape import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto import io.hydrosphere.serving.tensorflow.types.DataType +import io.hydrosphere.serving.tensorflow.utils.ops.TensorShapeProtoOps object ContractBuilders { - def createUnknownTensorShape(): TensorShapeProto = { - TensorShapeProto(unknownRank = true) - } - - def createTensorShape(dims: Seq[Long], unknownRank: Boolean = false): TensorShapeProto = { - TensorShapeProto( - dim = dims.map { d => - TensorShapeProto.Dim(d) - }, - unknownRank = unknownRank - ) - } - - def createTensorInfo( - dataType: DataType, - shape: Option[Seq[Long]], - unknownRank: Boolean = false - ): TensorInfo = { - TensorInfo(dataType, shape.map(s => createTensorShape(s, unknownRank))) - } - - def complexField(name: String, subFields: Seq[ModelField]): ModelField = { + def complexField(name: String, shape: Option[TensorShapeProto], subFields: Seq[ModelField]): ModelField = { ModelField( name, - ModelField.InfoOrSubfields.Subfields( - ModelField.ComplexField( + shape, + ModelField.TypeOrSubfields.Subfields( + ModelField.Subfield( subFields ) ) @@ -43,7 +24,7 @@ object ContractBuilders { dataType: DataType, shape: Option[TensorShapeProto] ): ModelField = { - ModelField(name, ModelField.InfoOrSubfields.Info(TensorInfo(dataType, shape))) + ModelField(name, shape, ModelField.TypeOrSubfields.Dtype(dataType)) } def simpleTensorModelField( @@ -54,7 +35,8 @@ object ContractBuilders { ): ModelField = { ModelField( name, - ModelField.InfoOrSubfields.Info(createTensorInfo(dataType, shape, unknownRank)) + TensorShape(shape, unknownRank).toProto, + ModelField.TypeOrSubfields.Dtype(dataType) ) } -} +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/DataGenerator.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/DataGenerator.scala index 09c2001..5d3ef78 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/DataGenerator.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/DataGenerator.scala @@ -2,24 +2,19 @@ package io.hydrosphere.serving.contract.utils import io.hydrosphere.serving.contract.model_contract.ModelContract import io.hydrosphere.serving.contract.model_field.ModelField -import io.hydrosphere.serving.contract.model_field.ModelField.InfoOrSubfields.{ - Empty, - Info, - Subfields -} +import io.hydrosphere.serving.contract.model_field.ModelField.TypeOrSubfields.{Dtype, Empty, Subfields} import io.hydrosphere.serving.contract.model_signature.ModelSignature -import io.hydrosphere.serving.tensorflow.tensor.TensorProto -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo -import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.tensor._ import io.hydrosphere.serving.tensorflow.types.DataType import io.hydrosphere.serving.tensorflow.types.DataType._ class DataGenerator(val modelApi: ModelSignature) { - def generateInputs: Map[String, TensorProto] = { + def generateInputs: Map[String, TypedTensor[_]] = { modelApi.inputs.flatMap(DataGenerator.generateField).toMap } - def generateOutputs: Map[String, TensorProto] = { + def generateOutputs: Map[String, TypedTensor[_]] = { modelApi.outputs.flatMap(DataGenerator.generateField).toMap } } @@ -31,7 +26,7 @@ object DataGenerator { modelContract.signatures.find(_.signatureName == signature).map(DataGenerator.apply) } - def generateData(dataType: DataType): Any = { + def generateScalarData[T <: DataType](dataType: T): Any = { dataType match { case DT_FLOAT | DT_COMPLEX64 => 1.0F case DT_DOUBLE | DT_COMPLEX128 => 1.0D @@ -44,46 +39,46 @@ object DataGenerator { case DT_INVALID => throw new IllegalArgumentException( - s"Can't convert data to DT_INVALID has an invalid dtype" + s"Can't convert data to DT_INVALID" ) - case DT_VARIANT => - throw new IllegalArgumentException(s"Cannot process DT_VARIANT Tensor. Not supported yet.") - case Unrecognized(value) => - throw new IllegalArgumentException(s"Cannot process Tensor with Unrecognized($value) dtype") case x => throw new IllegalArgumentException(s"Cannot process Tensor with $x dtype") // refs } } - def generateData(dataType: DataType, shape: Option[TensorShapeProto]): List[Any] = { - shape match { + def createFlatTensor[T](shape: TensorShape, generator: => T): Seq[T] = { + shape.dims match { case Some(sh) => - sh.dim.map(_.size.max(1)).reverse.foldLeft(List.empty[Any]) { - case (Nil, y) => - 1L.to(y).map(_ => generateData(dataType)).toList - case (x, y) => - 1L.to(y).map(_ => x).toList - } - case None => List(generateData(dataType)) + val flatLen = sh.map(_.max(1)).product + (1L to flatLen).map(_ => generator) + case None => List(generator) } } - def generateTensor(tensorInfo: TensorInfo): TensorProto = { - val tensor = TensorProto(dtype = tensorInfo.dtype, tensorShape = tensorInfo.tensorShape) - val data = generateData(tensorInfo.dtype, tensorInfo.tensorShape) - val typedTensor = TypedTensor(tensor) - typedTensor.putAny(data).right.get + def generateTensor(shape: TensorShape, dtype: DataType): Option[TypedTensor[_]] = { + val factory = TypedTensorFactory(dtype) + val data = createFlatTensor(shape, generateScalarData(dtype)) + val s = factory.createFromAny(data, shape) + println(s.map(_.toProto)) + s } - def generateField(field: ModelField): Map[String, TensorProto] = { - val tensor = field.infoOrSubfields match { - case Empty => TensorProto() - case Info(value) => generateTensor(value) - case Subfields(value) => - val tensor = TensorProto(dtype = DataType.DT_MAP) - tensor.withMapVal( - value.data.flatMap(generateField).toMap - ) + def generateField(field: ModelField): Map[String, TypedTensor[_]] = { + val shape = TensorShape.fromProto(field.shape) + val fieldValue = field.typeOrSubfields match { + case Empty => None + case Dtype(value) => generateTensor(shape, value) + case Subfields(value) => generateNestedTensor(shape, value) } - Map(field.fieldName -> tensor) + fieldValue.map(x => field.name -> x).toMap + } + + private def generateNestedTensor(shape: TensorShape, value: ModelField.Subfield): Option[MapTensor] = { + val map = generateMap(value) + val tensorData = createFlatTensor(shape, map) + Some(MapTensor(shape, tensorData)) + } + + private def generateMap(value: ModelField.Subfield): Map[String, TypedTensor[_]] = { + value.data.flatMap(generateField).toMap } } diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/SignatureChecker.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/SignatureChecker.scala deleted file mode 100644 index d81bf6b..0000000 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/SignatureChecker.scala +++ /dev/null @@ -1,89 +0,0 @@ -package io.hydrosphere.serving.contract.utils - -import io.hydrosphere.serving.contract.model_field.ModelField -import io.hydrosphere.serving.contract.model_field.ModelField.InfoOrSubfields.{ - Empty, - Info, - Subfields -} -import io.hydrosphere.serving.contract.model_signature.ModelSignature -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo -import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto - -object SignatureChecker { - - def areCompatible( - first: Seq[TensorShapeProto.Dim], - second: Seq[TensorShapeProto.Dim] - ): Boolean = { - if (first.lengthCompare(second.length) != 0) { - false - } else { - first.zip(second).forall { - case (em, re) => - if (re.size == -1) { - true - } else { - em.size == re.size - } - } - } - } - - def areSequentiallyCompatible(emitter: TensorInfo, receiver: TensorInfo): Boolean = { - if (emitter.dtype != receiver.dtype) { - false - } else { - emitter.tensorShape -> receiver.tensorShape match { - case (em, re) if em == re => true // two identical tensors - case (em, re) if em.isDefined != re.isDefined => - false // comparing scalar and dimensional tensor - case (Some(_), Some(re)) if re.unknownRank => - true // receiver has unknown rank - runtime check - case (Some(em), Some(_)) if em.unknownRank => false - case (Some(em), Some(re)) => areCompatible(em.dim, re.dim) - } - } - } - - def areSequentiallyCompatible( - emitter: ModelField.ComplexField, - receiver: ModelField.ComplexField - ): Boolean = { - receiver.data.forall { field => - val emitterField = emitter.data.find(_.fieldName == field.fieldName) - emitterField.exists(areSequentiallyCompatible(_, field)) - } - } - - def areSequentiallyCompatible(emitter: ModelField, receiver: ModelField): Boolean = { - if (emitter == receiver) { - true - } else if (emitter.fieldName == receiver.fieldName) { - emitter.infoOrSubfields match { - case Empty => receiver.infoOrSubfields.isEmpty - case Subfields(fields) => - receiver.infoOrSubfields.subfields - .exists(areSequentiallyCompatible(fields, _)) - case Info(tensor) => - receiver.infoOrSubfields.info - .exists(areSequentiallyCompatible(tensor, _)) - } - } else { - false - } - } - - def areSequentiallyCompatible(emitter: ModelSignature, receiver: ModelSignature): Boolean = { - if (receiver.inputs.isEmpty) { - false - } else { - val outputMap = emitter.outputs.map(i => i.fieldName -> i).toMap - receiver.inputs.forall { input => - outputMap - .get(input.fieldName) - .exists(in => areSequentiallyCompatible(in, input)) - } - } - } -} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/TypedTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/TypedTensor.scala deleted file mode 100644 index 3194095..0000000 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/TypedTensor.scala +++ /dev/null @@ -1,181 +0,0 @@ -package io.hydrosphere.serving.contract.utils - -import io.hydrosphere.serving.contract.utils.validation.{ - InvalidFieldData, - UnsupportedFieldTypeError, - ValidationError -} -import com.google.protobuf.ByteString -import io.hydrosphere.serving.tensorflow.tensor.TensorProto -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo -import io.hydrosphere.serving.tensorflow.types.DataType -import io.hydrosphere.serving.tensorflow.types.DataType._ - -import scala.reflect.ClassTag - -/** - * At one time, `TensorProto` can only have one type of data. - * `TypedTensor` is a wrapper that ensures it, and provides simple access. - * @tparam T type of data field - */ -trait TypedTensor[T] { - def tensorProto: TensorProto - - /** - * Returns tensor contents from a field as flat `Seq` - * @return flat data `Seq` - */ - def get: Seq[T] - - /** - * Puts data to a field - * @param data data - * @return tensor with new data - */ - def put(data: Seq[T]): TensorProto - - /** - * Tries to convert data to tensor-specific type and puts it to a field - * @param data data - * @param ct class tag to retrieve class info - * @return tensor with new data or error - */ - def putAny(data: Seq[Any])(implicit ct: ClassTag[T]): Either[ValidationError, TensorProto] = { - castData(data).right.map { converted => - put(converted) - } - } - - /** - * Tries to convert `data` to tensor-specific type. - * @param data target data - * @param ct class tag to retrieve class info - * @return converted data or error - */ - def castData(data: Seq[Any])(implicit ct: ClassTag[T]): Either[ValidationError, Seq[T]] = { - try { - Right(data.map(_.asInstanceOf[T])) - } catch { - case _: ClassCastException => Left(new InvalidFieldData(ct.runtimeClass)) - } - } -} - -object TypedTensor { - def apply(dataType: DataType): TypedTensor[_] = { - TypedTensor(TensorProto.defaultInstance.withDtype(dataType)) - } - - def apply(tensorProto: TensorProto): TypedTensor[_] = { - tensorProto.dtype match { - case DT_FLOAT => FloatTensor(tensorProto) - - case DT_DOUBLE => DoubleTensor(tensorProto) - - case DT_INT8 | DT_INT16 | DT_INT32 | DT_QINT8 | DT_QINT16 | DT_QINT32 => - IntTensor(tensorProto) - - case DT_UINT8 | DT_UINT16 | DT_UINT32 | DT_QUINT8 | DT_QUINT16 => - UintTensor(tensorProto) - - case DT_INT64 => Int64Tensor(tensorProto) - - case DT_UINT64 => Uint64Tensor(tensorProto) - - case DT_COMPLEX64 => SComplexTensor(tensorProto) - case DT_COMPLEX128 => DComplexTensor(tensorProto) - - case DT_STRING => StringTensor(tensorProto) - case DT_BOOL => BoolTensor(tensorProto) - - case x => throw new UnsupportedFieldTypeError(x) - } - } - - /** - * Creates tensor with `data` and `tensorInfo` - * - * @param data contents to be put in tensor - * @param tensorInfo tensor info - * @return tensor with data or error - */ - def constructTensor( - data: Seq[Any], - tensorInfo: TensorInfo - ): Either[ValidationError, TensorProto] = { - val tensor = TensorProto(dtype = tensorInfo.dtype, tensorShape = tensorInfo.tensorShape) - val typedTensor = TypedTensor(tensor) - typedTensor.putAny(data) - } - - case class FloatTensor(tensorProto: TensorProto) extends TypedTensor[Float] { - override def get: Seq[Float] = tensorProto.floatVal - - override def put(data: Seq[Float]): TensorProto = - tensorProto.addAllFloatVal(data) - } - - case class SComplexTensor(tensorProto: TensorProto) extends TypedTensor[Float] { - override def get: Seq[Float] = tensorProto.scomplexVal - - override def put(data: Seq[Float]): TensorProto = - tensorProto.addAllScomplexVal(data) - } - - case class DoubleTensor(tensorProto: TensorProto) extends TypedTensor[Double] { - override def get: Seq[Double] = tensorProto.doubleVal - - override def put(data: Seq[Double]): TensorProto = - tensorProto.addAllDoubleVal(data) - } - - case class DComplexTensor(tensorProto: TensorProto) extends TypedTensor[Double] { - override def get: Seq[Double] = tensorProto.dcomplexVal - - override def put(data: Seq[Double]): TensorProto = - tensorProto.addAllDcomplexVal(data) - } - - case class Uint64Tensor(tensorProto: TensorProto) extends TypedTensor[Long] { - override def get: Seq[Long] = tensorProto.uint64Val - - override def put(data: Seq[Long]): TensorProto = - tensorProto.addAllUint64Val(data) - } - - case class Int64Tensor(tensorProto: TensorProto) extends TypedTensor[Long] { - override def get: Seq[Long] = tensorProto.int64Val - - override def put(data: Seq[Long]): TensorProto = - tensorProto.addAllInt64Val(data) - } - - case class IntTensor(tensorProto: TensorProto) extends TypedTensor[Int] { - override def get: Seq[Int] = tensorProto.intVal - - override def put(data: Seq[Int]): TensorProto = - tensorProto.addAllIntVal(data) - } - - case class UintTensor(tensorProto: TensorProto) extends TypedTensor[Int] { - override def get: Seq[Int] = tensorProto.uint32Val - - override def put(data: Seq[Int]): TensorProto = - tensorProto.addAllUint32Val(data) - } - - case class StringTensor(tensorProto: TensorProto) extends TypedTensor[String] { - override def get: Seq[String] = tensorProto.stringVal.map(_.toStringUtf8) - - override def put(data: Seq[String]): TensorProto = - tensorProto.addAllStringVal(data.map(ByteString.copyFromUtf8)) - } - - case class BoolTensor(tensorProto: TensorProto) extends TypedTensor[Boolean] { - override def get: Seq[Boolean] = tensorProto.boolVal - - override def put(data: Seq[Boolean]): TensorProto = - tensorProto.addAllBoolVal(data) - } - -} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/ContractDescription.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/ContractDescription.scala index 98b9e6e..857cf90 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/ContractDescription.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/ContractDescription.scala @@ -3,7 +3,7 @@ package io.hydrosphere.serving.contract.utils.description import io.hydrosphere.serving.contract.model_contract.ModelContract case class ContractDescription( - signatures: List[SignatureDescription] + signatures: Seq[SignatureDescription] ) { def toContract: ModelContract = ContractDescription.toContract(this) } @@ -14,4 +14,4 @@ object ContractDescription { signatures = contractDescription.signatures.map(SignatureDescription.toSignature) ) } -} +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/FieldDescription.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/FieldDescription.scala index fb9e255..2bb8b32 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/FieldDescription.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/FieldDescription.scala @@ -5,5 +5,5 @@ import io.hydrosphere.serving.tensorflow.types.DataType case class FieldDescription( fieldName: String, dataType: DataType, - shape: Option[List[Long]] + shape: Option[Seq[Long]] ) diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/SignatureDescription.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/SignatureDescription.scala index 8d3d01c..901ff09 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/SignatureDescription.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/description/SignatureDescription.scala @@ -2,35 +2,42 @@ package io.hydrosphere.serving.contract.utils.description import io.hydrosphere.serving.contract.model_field.ModelField import io.hydrosphere.serving.contract.model_signature.ModelSignature -import io.hydrosphere.serving.contract.utils.ContractBuilders -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto +import io.hydrosphere.serving.tensorflow.types.DataType +import io.hydrosphere.serving.tensorflow.utils.ops.TensorShapeProtoOps import scala.collection.mutable case class SignatureDescription( signatureName: String, - inputs: List[FieldDescription], - outputs: List[FieldDescription] + inputs: Seq[FieldDescription], + outputs: Seq[FieldDescription] ) { def toSignature: ModelSignature = SignatureDescription.toSignature(this) } object SignatureDescription { + class Converter() { sealed trait ANode { def name: String - def toInfoOrDict: ModelField.InfoOrSubfields + def shape: Option[Seq[Long]] + + def toTypeOrSubfields: ModelField.TypeOrSubfields + + def shapeProto: Option[TensorShapeProto] = TensorShape.fromSeq(shape).toProto } - case class FTensor(name: String, tensorInfo: TensorInfo) extends ANode { - override def toInfoOrDict: ModelField.InfoOrSubfields = { - ModelField.InfoOrSubfields.Info(tensorInfo) + case class FTensor(name: String, shape: Option[Seq[Long]], dtype: DataType) extends ANode { + override def toTypeOrSubfields = { + ModelField.TypeOrSubfields.Dtype(dtype) } } - case class FMap(name: String = "", data: mutable.ListBuffer[ANode] = mutable.ListBuffer.empty) + case class FMap(name: String = "", shape: Option[Seq[Long]] = None, data: mutable.ListBuffer[ANode] = mutable.ListBuffer.empty) extends ANode { def getOrUpdate(segment: String, map: FMap): ANode = { data.find(_.name == segment) match { @@ -44,11 +51,11 @@ object SignatureDescription { def +=(node: ANode): data.type = data += node - override def toInfoOrDict: ModelField.InfoOrSubfields = { - ModelField.InfoOrSubfields.Subfields( - ModelField.ComplexField( + override def toTypeOrSubfields = { + ModelField.TypeOrSubfields.Subfields( + ModelField.Subfield( data.map { node => - ModelField(node.name, node.toInfoOrDict) + ModelField(node.name, node.shapeProto, node.toTypeOrSubfields) } ) ) @@ -68,7 +75,8 @@ object SignatureDescription { case (tensorName :: Nil) => tree += FTensor( tensorName, - ContractBuilders.createTensorInfo(field.dataType, field.shape) + field.shape, + field.dataType ) case (root :: segments) => var last = tree.getOrUpdate(root, FMap(root)).asInstanceOf[FMap] @@ -82,7 +90,8 @@ object SignatureDescription { val lastName = segments.last last += FTensor( lastName, - ContractBuilders.createTensorInfo(field.dataType, field.shape) + field.shape, + field.dataType ) case Nil => throw new IllegalArgumentException( @@ -93,7 +102,7 @@ object SignatureDescription { def result: Seq[ModelField] = { tree.data.map { node => - ModelField(node.name, node.toInfoOrDict) + ModelField(node.name, node.shapeProto, node.toTypeOrSubfields) } } } @@ -101,8 +110,8 @@ object SignatureDescription { def toSignature(signatureDescription: SignatureDescription): ModelSignature = { ModelSignature( signatureName = signatureDescription.signatureName, - inputs = SignatureDescription.toFields(signatureDescription.inputs), - outputs = SignatureDescription.toFields(signatureDescription.outputs) + inputs = SignatureDescription.toFields(signatureDescription.inputs), + outputs = SignatureDescription.toFields(signatureDescription.outputs) ) } diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/Implicits.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/Implicits.scala index fa80554..39155c2 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/Implicits.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/Implicits.scala @@ -1,5 +1,7 @@ package io.hydrosphere.serving.contract.utils.ops +import io.hydrosphere.serving.tensorflow.utils.ops.TensorShapeProtoOps + trait Implicits extends ModelContractOps with ModelSignatureOps diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelContractOps.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelContractOps.scala index 112ee94..b70bbf4 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelContractOps.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelContractOps.scala @@ -11,10 +11,10 @@ trait ModelContractOps { ) } } -} -object ModelContractOps { def flatten(modelContract: ModelContract): List[SignatureDescription] = { modelContract.signatures.map(ModelSignatureOps.flatten).toList } } + +object ModelContractOps extends ModelContractOps \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelFieldOps.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelFieldOps.scala index 5a34d5d..e657795 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelFieldOps.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelFieldOps.scala @@ -2,12 +2,10 @@ package io.hydrosphere.serving.contract.utils.ops import io.hydrosphere.serving.contract.utils.description.FieldDescription import io.hydrosphere.serving.contract.model_field.ModelField -import io.hydrosphere.serving.contract.model_field.ModelField.InfoOrSubfields.{ - Empty, - Info, - Subfields -} +import io.hydrosphere.serving.contract.model_field.ModelField.TypeOrSubfields.{Dtype, Empty, Subfields} import io.hydrosphere.serving.contract.utils.ContractBuilders +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.utils.ops.{DataTypeOps, TensorShapeProtoOps} trait ModelFieldOps { @@ -18,31 +16,31 @@ trait ModelFieldOps { } def insert(name: String, fieldInfo: ModelField): Option[ModelField] = { - modelField.infoOrSubfields match { + modelField.typeOrSubfields match { case Subfields(fields) => - fields.data.find(_.fieldName == name) match { + fields.data.find(_.name == name) match { case Some(_) => None case None => val newData = fields.data :+ fieldInfo - Some(ContractBuilders.complexField(modelField.fieldName, newData)) + Some(ContractBuilders.complexField(modelField.name, fieldInfo.shape, newData)) } case _ => None } } def child(name: String): Option[ModelField] = { - modelField.infoOrSubfields match { + modelField.typeOrSubfields match { case Subfields(value) => - value.data.find(_.fieldName == name) + value.data.find(_.name == name) case _ => None } } def search(name: String): Option[ModelField] = { - modelField.infoOrSubfields match { + modelField.typeOrSubfields match { case Subfields(value) => - value.data.find(_.fieldName == name).orElse { + value.data.find(_.name == name).orElse { value.data.flatMap(_.search(name)).headOption } case _ => None @@ -51,13 +49,10 @@ trait ModelFieldOps { } -} - -object ModelFieldOps { - def merge(inputs: Seq[ModelField], inputs1: Seq[ModelField]): Seq[ModelField] = { + def mergeAll(inputs: Seq[ModelField], inputs1: Seq[ModelField]): Seq[ModelField] = { inputs.zip(inputs1).flatMap { case (in1, in2) => - if (in1.fieldName == in2.fieldName) { + if (in1.name == in2.name) { val merged = merge(in1, in2) .getOrElse(throw new IllegalArgumentException(s"$in1 and $in2 aren't mergeable")) List(merged) @@ -70,31 +65,33 @@ object ModelFieldOps { def merge(first: ModelField, second: ModelField): Option[ModelField] = { if (first == second) { Some(first) - } else if (first.fieldName == second.fieldName) { - val fieldContents = first.infoOrSubfields -> second.infoOrSubfields match { - case (Subfields(fDict), Subfields(sDict)) => - mergeComplexFields(fDict, sDict).map(ModelField.InfoOrSubfields.Subfields.apply) - case (Info(fInfo), Info(sInfo)) => - TensorInfoOps.merge(fInfo, sInfo).map(ModelField.InfoOrSubfields.Info.apply) - case _ => None + } else if (first.name == second.name) { + TensorShapeProtoOps.merge(first.shape, second.shape).flatMap { shape => + val fieldContents = first.typeOrSubfields -> second.typeOrSubfields match { + case (Subfields(fDict), Subfields(sDict)) => + mergeSubfields(fDict, sDict).map(ModelField.TypeOrSubfields.Subfields.apply) + case (Dtype(fInfo), Dtype(sInfo)) => + DataTypeOps.merge(fInfo, sInfo).map(ModelField.TypeOrSubfields.Dtype.apply) + case _ => None + } + fieldContents.map(ModelField(first.name, shape, _)) } - fieldContents.map(ModelField(first.fieldName, _)) } else { None } } - def mergeComplexFields( - first: ModelField.ComplexField, - second: ModelField.ComplexField - ): Option[ModelField.ComplexField] = { + def mergeSubfields( + first: ModelField.Subfield, + second: ModelField.Subfield + ): Option[ModelField.Subfield] = { val fields = second.data.map { field => - val emitterField = first.data.find(_.fieldName == field.fieldName) + val emitterField = first.data.find(_.name == field.name) emitterField.flatMap(merge(_, field)) } if (fields.forall(_.isDefined)) { val exactFields = fields.flatten - Some(ModelField.ComplexField(exactFields)) + Some(ModelField.Subfield(exactFields)) } else { None } @@ -105,15 +102,38 @@ object ModelFieldOps { } def flatten(rootName: String, field: ModelField): Seq[FieldDescription] = { - val name = s"$rootName/${field.fieldName}" - field.infoOrSubfields match { + val name = s"$rootName/${field.name}" + field.typeOrSubfields match { case Empty => List.empty case Subfields(value) => value.data.flatMap { subfield => flatten(name, subfield) } - case Info(value) => - List(TensorInfoOps.flatten(name, value)) + case Dtype(value) => + List( + FieldDescription( + name, + value, + TensorShape.fromProto(field.shape).dims + ) + ) } } + + def appendAll(outputs: Seq[ModelField], inputs: Seq[ModelField]): Option[Seq[ModelField]] = { + val fields = inputs.map { input => + outputs.find(_.name == input.name).flatMap { output => + merge(output, input) + } + } + + if (fields.exists(_.isEmpty)) { + None + } else { + Some(fields.flatten) + } + } + } + +object ModelFieldOps extends ModelFieldOps \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelSignatureOps.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelSignatureOps.scala index eb3e2ec..aad4f59 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelSignatureOps.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/ModelSignatureOps.scala @@ -1,22 +1,13 @@ package io.hydrosphere.serving.contract.utils.ops +import io.hydrosphere.serving.contract.model_field.ModelField import io.hydrosphere.serving.contract.utils.description.SignatureDescription import io.hydrosphere.serving.contract.model_signature.ModelSignature trait ModelSignatureOps { - - implicit class ModelSignaturePumped(modelSignature: ModelSignature) { - def +++(other: ModelSignature): ModelSignature = { - ModelSignatureOps.merge(modelSignature, other) - } - } - -} - -object ModelSignatureOps { def merge(signature1: ModelSignature, signature2: ModelSignature): ModelSignature = { - val mergedIns = ModelFieldOps.merge(signature1.inputs, signature2.inputs) - val mergedOuts = ModelFieldOps.merge(signature1.outputs, signature2.outputs) + val mergedIns = ModelFieldOps.mergeAll(signature1.inputs, signature2.inputs) + val mergedOuts = ModelFieldOps.mergeAll(signature1.outputs, signature2.outputs) ModelSignature( s"${signature1.signatureName}&${signature2.signatureName}", mergedIns, @@ -25,8 +16,25 @@ object ModelSignatureOps { } def flatten(modelSignature: ModelSignature): SignatureDescription = { - val inputs = ModelFieldOps.flatten(modelSignature.inputs) + val inputs = ModelFieldOps.flatten(modelSignature.inputs) val outputs = ModelFieldOps.flatten(modelSignature.outputs) SignatureDescription(modelSignature.signatureName, inputs, outputs) } + + def append(head: ModelSignature, tail: ModelSignature): Option[ModelSignature] = { + if (tail.inputs.isEmpty) { + None + } else { + val maybeFields: Option[Seq[ModelField]] = ModelFieldOps.appendAll(head.outputs, tail.inputs) + maybeFields.map { _ => + ModelSignature( + s"${head.signatureName}>${tail.signatureName}", + head.inputs, + tail.outputs + ) + } + } + } } + +object ModelSignatureOps extends ModelSignatureOps \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorInfoOps.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorInfoOps.scala deleted file mode 100644 index 89f4736..0000000 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorInfoOps.scala +++ /dev/null @@ -1,38 +0,0 @@ -package io.hydrosphere.serving.contract.utils.ops - -import io.hydrosphere.serving.contract.utils.description.FieldDescription -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo - -trait TensorInfoOps { - implicit class TensorInfoPumped(tensorInfo: TensorInfo) { - def flatten(rootName: String = ""): FieldDescription = { - TensorInfoOps.flatten(rootName, tensorInfo) - } - } -} - -object TensorInfoOps { - def merge(first: TensorInfo, second: TensorInfo): Option[TensorInfo] = { - if (first.dtype != second.dtype) { - None - } else { - first.tensorShape -> second.tensorShape match { - case (em, re) if em == re => Some(first) - case (Some(em), Some(re)) if re.unknownRank == em.unknownRank && re.unknownRank => - Some(first) - case (Some(em), Some(re)) => - val shape = TensorShapeProtoOps.merge(em, re) - Some(TensorInfo(first.dtype, shape)) - case _ => None - } - } - } - - def flatten(rootName: String, tensor: TensorInfo): FieldDescription = { - FieldDescription( - rootName, - tensor.dtype, - TensorShapeProtoOps.shapeToList(tensor.tensorShape) - ) - } -} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/validation/package.scala b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/validation/package.scala index 2fda4e6..ebcab32 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/validation/package.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/validation/package.scala @@ -11,7 +11,7 @@ package object validation { class SignatureMissingError(val expectedSignature: String, val modelContract: ModelContract) extends ValidationError( - s"Couldn't find '$expectedSignature' signature in '${modelContract.modelName} model contract'" + s"Couldn't find '$expectedSignature' signature among [${modelContract.signatures.map(_.signatureName)}] signatures" ) {} class SignatureValidationError( @@ -26,7 +26,7 @@ package object validation { class ComplexFieldValidationError(val suberrors: Seq[ValidationError], val field: ModelField) extends ValidationError( - s"Errors while validating subfields for '${field.fieldName}' field: ${suberrors.mkString("\n")}" + s"Errors while validating subfields for '${field.name}' field: ${suberrors.mkString("\n")}" ) class IncompatibleFieldTypeError(val field: String, val expectedType: DataType) diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/TensorShape.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/TensorShape.scala new file mode 100644 index 0000000..ee370d8 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/TensorShape.scala @@ -0,0 +1,37 @@ +package io.hydrosphere.serving.tensorflow + +import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto + +case class TensorShape(dims: Option[Seq[Long]], unknownRank: Boolean = false) { + def toProto: Option[TensorShapeProto] = { + dims.map { shapeDims => + TensorShapeProto( + dim = shapeDims.map(TensorShapeProto.Dim.apply(_)), + unknownRank = unknownRank + ) + } + } +} + +object TensorShape { + def scalar: TensorShape = TensorShape(None) + + def vector(size: Long) = TensorShape(Some(Seq(size))) + + def mat(dims: Long*) = TensorShape(Some(dims)) + + def fromProto(protoShape: Option[TensorShapeProto]): TensorShape = { + TensorShape( + dims = protoShape.map { shape => + shape.dim.map(_.size) + }, + unknownRank = protoShape.exists(_.unknownRank) + ) + } + + def fromSeq(dims: Option[Seq[Long]]): TensorShape = { + TensorShape( + dims = dims + ) + } +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/BoolTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/BoolTensor.scala new file mode 100644 index 0000000..db37c81 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/BoolTensor.scala @@ -0,0 +1,25 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class BoolTensor(shape: TensorShape, data: Seq[Boolean]) extends TypedTensor[DataType.DT_BOOL.type] { + override type Self = BoolTensor + + override type DataT = Boolean + + override def dtype = DataType.DT_BOOL + + override def factory = BoolTensor +} + +object BoolTensor extends TypedTensorFactory[BoolTensor] { + + override implicit def lens: TensorProtoLens[BoolTensor] = new TensorProtoLens[BoolTensor] { + override def getter: TensorProto => Seq[Boolean] = _.boolVal + + override def setter: (TensorProto, Seq[Boolean]) => TensorProto = _.withBoolVal(_) + } + + override def constructor = BoolTensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DComplexTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DComplexTensor.scala new file mode 100644 index 0000000..6d6250e --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DComplexTensor.scala @@ -0,0 +1,24 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class DComplexTensor(shape: TensorShape, data: Seq[Double]) extends TypedTensor[DataType.DT_COMPLEX128.type] { + override type Self = DComplexTensor + + override type DataT = Double + + override def dtype = DataType.DT_COMPLEX128 + + override def factory = DComplexTensor +} + +object DComplexTensor extends TypedTensorFactory[DComplexTensor] { + override implicit def lens: TensorProtoLens[DComplexTensor] = new TensorProtoLens[DComplexTensor] { + override def getter: TensorProto => Seq[Double] = _.dcomplexVal + + override def setter: (TensorProto, Seq[Double]) => TensorProto = _.withDcomplexVal(_) + } + + override def constructor = DComplexTensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DoubleTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DoubleTensor.scala new file mode 100644 index 0000000..702fe44 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/DoubleTensor.scala @@ -0,0 +1,24 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class DoubleTensor(shape: TensorShape, data: Seq[Double]) extends TypedTensor[DataType.DT_DOUBLE.type] { + override type Self = DoubleTensor + + override type DataT = Double + + override def dtype = DataType.DT_DOUBLE + + override def factory = DoubleTensor +} + +object DoubleTensor extends TypedTensorFactory[DoubleTensor] { + override implicit def lens: TensorProtoLens[DoubleTensor] = new TensorProtoLens[DoubleTensor] { + override def getter: TensorProto => Seq[Double] = _.doubleVal + + override def setter: (TensorProto, Seq[Double]) => TensorProto = _.withDoubleVal(_) + } + + override def constructor = DoubleTensor.apply +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/FloatTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/FloatTensor.scala new file mode 100644 index 0000000..564884c --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/FloatTensor.scala @@ -0,0 +1,24 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class FloatTensor(shape: TensorShape, data: Seq[Float]) extends TypedTensor[DataType.DT_FLOAT.type] { + override type Self = FloatTensor + + override type DataT = Float + + override def dtype = DataType.DT_FLOAT + + override def factory = FloatTensor +} + +object FloatTensor extends TypedTensorFactory[FloatTensor] { + override implicit def lens: TensorProtoLens[FloatTensor] = new TensorProtoLens[FloatTensor] { + override def getter: TensorProto => Seq[Float] = _.floatVal + + override def setter: (TensorProto, Seq[Float]) => TensorProto = _.withFloatVal(_) + } + + override def constructor = FloatTensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int16Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int16Tensor.scala new file mode 100644 index 0000000..6deebd4 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int16Tensor.scala @@ -0,0 +1,18 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Int16Tensor(shape: TensorShape, data: Seq[Int]) extends IntTensor[DataType.DT_INT16.type] { + override type Self = Int16Tensor + + override def dtype = DataType.DT_INT16 + + override def factory = Int16Tensor +} + +object Int16Tensor extends TypedTensorFactory[Int16Tensor] { + override implicit def lens: TensorProtoLens[Int16Tensor] = IntTensor.protoLens[Int16Tensor] + + override def constructor = Int16Tensor.apply +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int32Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int32Tensor.scala new file mode 100644 index 0000000..d97e376 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int32Tensor.scala @@ -0,0 +1,18 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Int32Tensor(shape: TensorShape, data: Seq[Int]) extends IntTensor[DataType.DT_INT32.type] { + override type Self = Int32Tensor + + override def dtype = DataType.DT_INT32 + + override def factory = Int32Tensor +} + +object Int32Tensor extends TypedTensorFactory[Int32Tensor] { + override implicit def lens: TensorProtoLens[Int32Tensor] = IntTensor.protoLens[Int32Tensor] + + override def constructor = Int32Tensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int64Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int64Tensor.scala new file mode 100644 index 0000000..b09a57d --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int64Tensor.scala @@ -0,0 +1,24 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Int64Tensor(shape: TensorShape, data: Seq[Long]) extends TypedTensor[DataType.DT_INT64.type] { + override type Self = Int64Tensor + + override type DataT = Long + + override def dtype = DataType.DT_INT64 + + override def factory = Int64Tensor +} + +object Int64Tensor extends TypedTensorFactory[Int64Tensor] { + override implicit def lens: TensorProtoLens[Int64Tensor] = new TensorProtoLens[Int64Tensor] { + override def getter: TensorProto => Seq[Long] = _.int64Val + + override def setter: (TensorProto, Seq[Long]) => TensorProto = _.withInt64Val(_) + } + + override def constructor = Int64Tensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int8Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int8Tensor.scala new file mode 100644 index 0000000..0c4cfc0 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Int8Tensor.scala @@ -0,0 +1,18 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Int8Tensor(shape: TensorShape, data: Seq[Int]) extends IntTensor[DataType.DT_INT8.type] { + override type Self = Int8Tensor + + override def dtype = DataType.DT_INT8 + + override def factory = Int8Tensor +} + +object Int8Tensor extends TypedTensorFactory[Int8Tensor] { + override implicit def lens = IntTensor.protoLens[Int8Tensor] + + override def constructor = Int8Tensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/IntTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/IntTensor.scala new file mode 100644 index 0000000..74c578d --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/IntTensor.scala @@ -0,0 +1,15 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.types.DataType + +trait IntTensor[T <: DataType] extends TypedTensor[T] { + final override type DataT = Int +} + +object IntTensor { + def protoLens[T <: IntTensor[_]] = new TensorProtoLens[T] { + override def getter: TensorProto => Seq[Int] = _.intVal + + override def setter: (TensorProto, Seq[Int]) => TensorProto = _.withIntVal(_) + } +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/MapTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/MapTensor.scala new file mode 100644 index 0000000..b7bbbc5 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/MapTensor.scala @@ -0,0 +1,33 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class MapTensor(shape: TensorShape, data: Seq[Map[String, TypedTensor[_]]]) extends TypedTensor[DataType.DT_MAP.type] { + override type Self = MapTensor + + override type DataT = Map[String, TypedTensor[_]] + + override def dtype = DataType.DT_MAP + + override def factory = MapTensor +} + +object MapTensor extends TypedTensorFactory[MapTensor] { + override implicit def lens: TensorProtoLens[MapTensor] = new TensorProtoLens[MapTensor] { + override def getter: TensorProto => Seq[Map[String, TypedTensor[_]]] = { tensor => + tensor.mapVal.map { + _.subtensors.mapValues(TypedTensorFactory.create) + } + } + + override def setter: (TensorProto, Seq[Map[String, TypedTensor[_]]]) => TensorProto = { (tensor, maps) => + val protoMaps = maps.map { tensorMap => + MapTensorData(tensorMap.mapValues(_.toProto)) + } + tensor.withMapVal(protoMaps) + } + } + + override def constructor = MapTensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/SComplexTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/SComplexTensor.scala new file mode 100644 index 0000000..5496f3f --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/SComplexTensor.scala @@ -0,0 +1,24 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class SComplexTensor(shape: TensorShape, data: Seq[Float]) extends TypedTensor[DataType.DT_COMPLEX64.type] { + override type Self = SComplexTensor + + override type DataT = Float + + override def dtype = DataType.DT_COMPLEX64 + + override def factory = SComplexTensor +} + +object SComplexTensor extends TypedTensorFactory[SComplexTensor] { + override implicit def lens: TensorProtoLens[SComplexTensor] = new TensorProtoLens[SComplexTensor] { + override def getter: TensorProto => Seq[Float] = _.scomplexVal + + override def setter: (TensorProto, Seq[Float]) => TensorProto = _.withScomplexVal(_) + } + + override def constructor = SComplexTensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/StringTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/StringTensor.scala new file mode 100644 index 0000000..e18e05e --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/StringTensor.scala @@ -0,0 +1,25 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import com.google.protobuf.ByteString +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class StringTensor(shape: TensorShape, data: Seq[String]) extends TypedTensor[DataType.DT_STRING.type] { + override type Self = StringTensor + + override type DataT = String + + override def dtype = DataType.DT_STRING + + override def factory = StringTensor +} + +object StringTensor extends TypedTensorFactory[StringTensor] { + override implicit def lens: TensorProtoLens[StringTensor] = new TensorProtoLens[StringTensor] { + override def getter: TensorProto => Seq[String] = _.stringVal.map(_.toStringUtf8) + + override def setter: (TensorProto, Seq[String]) => TensorProto = (t, d) => t.withStringVal(d.map(ByteString.copyFromUtf8)) + } + + override def constructor = StringTensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TensorProtoLens.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TensorProtoLens.scala new file mode 100644 index 0000000..8ea1955 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TensorProtoLens.scala @@ -0,0 +1,13 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import com.trueaccord.lenses.Lens + +trait TensorProtoLens[T <: TypedTensor[_]] { + def getter: TensorProto => Seq[T#DataT] + + def setter: (TensorProto, Seq[T#DataT]) => TensorProto + + final def lens: Lens[TensorProto, Seq[T#DataT]] = { + Lens[TensorProto, Seq[T#DataT]](getter)(setter) + } +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensor.scala new file mode 100644 index 0000000..84b02a5 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensor.scala @@ -0,0 +1,24 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +trait TypedTensor[DTypeT] { + type Self <: TypedTensor[DTypeT] + type DataT + + def data: Seq[Self#DataT] + + def shape: TensorShape + + def dtype: DataType + + def factory: TypedTensorFactory[Self] + + final def toProto: TensorProto = { + val pretensor = TensorProto(dtype, shape.toProto) + pretensor.update { _ => + factory.lens.lens := data + } + } +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensorFactory.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensorFactory.scala new file mode 100644 index 0000000..641d615 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/TypedTensorFactory.scala @@ -0,0 +1,95 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.contract.utils.validation.{InvalidFieldData, UnsupportedFieldTypeError, ValidationError} +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType +import io.hydrosphere.serving.tensorflow.types.DataType._ + +trait TypedTensorFactory[TensorT <: TypedTensor[_]] { + + implicit def lens: TensorProtoLens[TensorT] + + def constructor: (TensorShape, Seq[TensorT#DataT]) => TensorT + + /** + * Tries to convert `data` to tensor-specific type. + * + * @param data target data + * @return converted data or error + */ + final def castData(data: Seq[Any]): Either[ValidationError, Seq[TensorT#DataT]] = { + try { + Right(data.map(_.asInstanceOf[TensorT#DataT])) + } catch { + case ex: ClassCastException => Left(new InvalidFieldData(ex.getClass)) + } + } + + final def empty: TensorT = { + constructor(TensorShape.scalar, Seq.empty) + } + + final def create( + data: Seq[TensorT#DataT], + shape: TensorShape + ): TensorT = { + constructor(shape, data) + } + + final def fromProto(proto: TensorProto): TensorT = { + constructor(TensorShape.fromProto(proto.tensorShape), lens.lens.get(proto)) + } + + /** + * Creates tensor with `data` and `tensorInfo` + * + * @param data contents to be put in tensor + * @param shape data shape + * @return tensor with data or error + */ + final def createFromAny( + data: Seq[Any], + shape: TensorShape + ): Option[TensorT] = { + val tensorProto = TensorProto(dtype = empty.dtype, tensorShape = shape.toProto) + castData(data).right.toOption.map { converted => + val newTensorProto = tensorProto.update(_ => lens.lens := converted) + fromProto(newTensorProto) + } + } +} + +object TypedTensorFactory { + + def apply(dataType: DataType): TypedTensorFactory[_ <: TypedTensor[_]] = { + dataType match { + case DT_FLOAT => FloatTensor + case DT_DOUBLE => DoubleTensor + + case DT_INT8 | DT_QINT8 => Int8Tensor + case DT_INT16 | DT_QINT16 => Int16Tensor + case DT_INT32 => Int32Tensor + + case DT_UINT8 | DT_QUINT8 => Uint8Tensor + case DT_UINT16 | DT_QUINT16 => Uint16Tensor + case DT_UINT32 => Uint32Tensor + + case DT_INT64 => Int64Tensor + case DT_UINT64 => Uint64Tensor + + case DT_COMPLEX64 => SComplexTensor + case DT_COMPLEX128 => DComplexTensor + + case DT_STRING => StringTensor + case DT_BOOL => BoolTensor + + case DT_MAP => MapTensor + + case x => throw new UnsupportedFieldTypeError(x) + } + } + + def create(tensorProto: TensorProto): TypedTensor[_] = { + TypedTensorFactory(tensorProto.dtype).fromProto(tensorProto) + } +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint16Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint16Tensor.scala new file mode 100644 index 0000000..554583e --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint16Tensor.scala @@ -0,0 +1,18 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Uint16Tensor(shape: TensorShape, data: Seq[Int]) extends IntTensor[DataType.DT_UINT16.type] { + override type Self = Uint16Tensor + + override def dtype = DataType.DT_UINT16 + + override def factory = Uint16Tensor +} + +object Uint16Tensor extends TypedTensorFactory[Uint16Tensor] { + override implicit def lens: TensorProtoLens[Uint16Tensor] = UintTensor.protoLens[Uint16Tensor] + + override def constructor = Uint16Tensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint32Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint32Tensor.scala new file mode 100644 index 0000000..5938ecd --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint32Tensor.scala @@ -0,0 +1,18 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Uint32Tensor(shape: TensorShape, data: Seq[Int]) extends IntTensor[DataType.DT_UINT32.type] { + override type Self = Uint32Tensor + + override def dtype = DataType.DT_UINT32 + + override def factory = Uint32Tensor +} + +object Uint32Tensor extends TypedTensorFactory[Uint32Tensor] { + override implicit def lens: TensorProtoLens[Uint32Tensor] = UintTensor.protoLens[Uint32Tensor] + + override def constructor = Uint32Tensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint64Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint64Tensor.scala new file mode 100644 index 0000000..b204390 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint64Tensor.scala @@ -0,0 +1,23 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Uint64Tensor(shape: TensorShape, data: Seq[Long]) extends TypedTensor[DataType.DT_UINT64.type] { + override type Self = Uint64Tensor + override type DataT = Long + + override def dtype = DataType.DT_UINT64 + + override def factory = Uint64Tensor +} + +object Uint64Tensor extends TypedTensorFactory[Uint64Tensor] { + override implicit def lens: TensorProtoLens[Uint64Tensor] = new TensorProtoLens[Uint64Tensor] { + override def getter: TensorProto => Seq[Long] = _.uint64Val + + override def setter: (TensorProto, Seq[Long]) => TensorProto = _.withUint64Val(_) + } + + override def constructor = Uint64Tensor.apply +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint8Tensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint8Tensor.scala new file mode 100644 index 0000000..f6cf0e4 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/Uint8Tensor.scala @@ -0,0 +1,18 @@ +package io.hydrosphere.serving.tensorflow.tensor + +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.types.DataType + +case class Uint8Tensor(shape: TensorShape, data: Seq[Int]) extends IntTensor[DataType.DT_UINT8.type] { + override type Self = Uint8Tensor + + override def dtype = DataType.DT_UINT8 + + override def factory = Uint8Tensor +} + +object Uint8Tensor extends TypedTensorFactory[Uint8Tensor] { + override implicit def lens: TensorProtoLens[Uint8Tensor] = UintTensor.protoLens[Uint8Tensor] + + override def constructor = Uint8Tensor.apply +} \ No newline at end of file diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/UintTensor.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/UintTensor.scala new file mode 100644 index 0000000..55b5703 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/tensor/UintTensor.scala @@ -0,0 +1,9 @@ +package io.hydrosphere.serving.tensorflow.tensor + +object UintTensor { + def protoLens[T <: IntTensor[_]] = new TensorProtoLens[T] { + override def getter: TensorProto => Seq[Int] = _.uint32Val + + override def setter: (TensorProto, Seq[Int]) => TensorProto = _.withUint32Val(_) + } +} diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/DataTypeOps.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/DataTypeOps.scala new file mode 100644 index 0000000..4616ca0 --- /dev/null +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/DataTypeOps.scala @@ -0,0 +1,15 @@ +package io.hydrosphere.serving.tensorflow.utils.ops + +import io.hydrosphere.serving.tensorflow.types.DataType + +trait DataTypeOps { + def merge(d1: DataType, d2: DataType): Option[DataType] = { + if (d1 == d2) { + Some(d1) + } else { + None + } + } +} + +object DataTypeOps extends DataTypeOps diff --git a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorShapeProtoOps.scala b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/TensorShapeProtoOps.scala similarity index 51% rename from scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorShapeProtoOps.scala rename to scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/TensorShapeProtoOps.scala index dbaca2f..7cbf9c9 100644 --- a/scala-package/src/main/scala/io/hydrosphere/serving/contract/utils/ops/TensorShapeProtoOps.scala +++ b/scala-package/src/main/scala/io/hydrosphere/serving/tensorflow/utils/ops/TensorShapeProtoOps.scala @@ -1,8 +1,9 @@ -package io.hydrosphere.serving.contract.utils.ops +package io.hydrosphere.serving.tensorflow.utils.ops import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto trait TensorShapeProtoOps { + type Shape = Option[TensorShapeProto] implicit class TensorShapeProtoPumped(tensorShapeProto: TensorShapeProto) { def toDimList: List[Long] = { @@ -10,18 +11,26 @@ trait TensorShapeProtoOps { } } -} + def merge(firstShape: Shape, secondShape: Shape): Option[Shape] = { + firstShape -> secondShape match { + case (em, re) if em == re => Some(firstShape) + case (Some(em), Some(re)) if re.unknownRank == em.unknownRank && re.unknownRank => + Some(firstShape) + case (Some(em), Some(re)) => + TensorShapeProtoOps.merge(em, re).map(Some.apply) + case _ => None + } + } -object TensorShapeProtoOps { def merge(first: TensorShapeProto, second: TensorShapeProto): Option[TensorShapeProto] = { if (first.dim.lengthCompare(second.dim.length) != 0) { None } else { val dims = first.dim.zip(second.dim).map { case (fDim, sDim) if fDim.size == sDim.size => Some(fDim) - case (fDim, sDim) if fDim.size == -1 => Some(sDim) - case (fDim, sDim) if sDim.size == -1 => Some(fDim) - case _ => None + case (fDim, sDim) if fDim.size == -1 => Some(sDim) + case (fDim, sDim) if sDim.size == -1 => Some(fDim) + case _ => None } if (dims.forall(_.isDefined)) { Some(TensorShapeProto(dims.map(_.get))) @@ -30,12 +39,6 @@ object TensorShapeProtoOps { } } } - - def shapeToList(tensorShapeProto: Option[TensorShapeProto]): Option[List[Long]] = { - tensorShapeProto.map { shape => - shape.dim.map { dim => - dim.size - }.toList - } - } } + +object TensorShapeProtoOps extends TensorShapeProtoOps \ No newline at end of file diff --git a/scala-package/src/test/scala/ContractOpsSpecs.scala b/scala-package/src/test/scala/ContractOpsSpecs.scala index e4424b1..3ae4fc8 100644 --- a/scala-package/src/test/scala/ContractOpsSpecs.scala +++ b/scala-package/src/test/scala/ContractOpsSpecs.scala @@ -4,8 +4,10 @@ import io.hydrosphere.serving.contract.model_signature.ModelSignature import io.hydrosphere.serving.contract.utils.ContractBuilders import io.hydrosphere.serving.contract.utils.description.{ContractDescription, FieldDescription, SignatureDescription} import io.hydrosphere.serving.contract.utils.description._ +import io.hydrosphere.serving.contract.utils.ops.ModelSignatureOps import io.hydrosphere.serving.tensorflow.tensor.TensorProto import io.hydrosphere.serving.tensorflow.types.DataType +import io.hydrosphere.serving.tensorflow.utils.ops.TensorShapeProtoOps import org.scalatest.WordSpec @@ -48,7 +50,7 @@ class ContractOpsSpecs extends WordSpec { ) ) - assert(sig1 +++ sig2 === expectedSig) + assert(ModelSignatureOps.merge(sig1, sig2) === expectedSig) } "inputs are overlapping and compatible" in { @@ -82,7 +84,7 @@ class ContractOpsSpecs extends WordSpec { ) ) - assert(sig1 +++ sig2 === expectedSig) + assert(ModelSignatureOps.merge(sig1, sig2) === expectedSig) } } @@ -107,7 +109,7 @@ class ContractOpsSpecs extends WordSpec { ) ) - assertThrows[IllegalArgumentException](sig1 +++ sig2) + assertThrows[IllegalArgumentException](ModelSignatureOps.merge(sig1, sig2)) } "outputs overlap is conflicting" in { @@ -130,7 +132,7 @@ class ContractOpsSpecs extends WordSpec { ) ) - assertThrows[IllegalArgumentException](sig1 +++ sig2) + assertThrows[IllegalArgumentException](ModelSignatureOps.merge(sig1, sig2)) } } } @@ -155,7 +157,7 @@ class ContractOpsSpecs extends WordSpec { ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, Some(List(3))) ) ) - val contract = ModelContract("test", List(sig1, sig2)) + val contract = ModelContract("test", Seq(sig1, sig2)) val expected = ContractDescription( List( @@ -189,6 +191,7 @@ class ContractOpsSpecs extends WordSpec { List( ContractBuilders.complexField( "in", + None, Seq( ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, None), ContractBuilders.simpleTensorModelField("in2", DataType.DT_INT32, None) @@ -198,6 +201,7 @@ class ContractOpsSpecs extends WordSpec { List( ContractBuilders.complexField( "out", + None, Seq( ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(List(-1))), ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, None) @@ -214,7 +218,7 @@ class ContractOpsSpecs extends WordSpec { ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, Some(List(3))) ) ) - val contract = ModelContract("test", List(sig1, sig2)) + val contract = ModelContract("test", Seq(sig1, sig2)) val expected = ContractDescription( List( @@ -275,6 +279,7 @@ class ContractOpsSpecs extends WordSpec { List( ContractBuilders.complexField( "in", + None, Seq( ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, None), ContractBuilders.simpleTensorModelField("in2", DataType.DT_INT32, None) @@ -284,10 +289,10 @@ class ContractOpsSpecs extends WordSpec { List( ContractBuilders.complexField( "out", + None, Seq( ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(List(-1))), ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, None) - ) ) ) diff --git a/scala-package/src/test/scala/DataGeneratorSpecs.scala b/scala-package/src/test/scala/DataGeneratorSpecs.scala new file mode 100644 index 0000000..7849167 --- /dev/null +++ b/scala-package/src/test/scala/DataGeneratorSpecs.scala @@ -0,0 +1,164 @@ +import com.google.protobuf.ByteString +import io.hydrosphere.serving.contract.model_signature.ModelSignature +import io.hydrosphere.serving.contract.utils.{ContractBuilders, DataGenerator} +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.tensor.{MapTensorData, TensorProto, TypedTensorFactory} +import io.hydrosphere.serving.tensorflow.types.DataType +import org.scalatest.WordSpec + +class DataGeneratorSpecs extends WordSpec { + val fooString = ByteString.copyFromUtf8("foo") + + "DataGenerator" should { + "generate correct example" when { + "scalar flat signature" in { + val sig1 = ModelSignature( + "sig1", + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, None) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(List(-1))) + ) + ) + + val expected = Map( + "in1" -> TypedTensorFactory.create( + TensorProto( + dtype = DataType.DT_STRING, + tensorShape = None, + stringVal = List(fooString) + ) + ) + ) + + val generated = DataGenerator(sig1).generateInputs + assert(generated === expected) + } + + "vector flat signature" in { + val sig1 = ModelSignature( + "sig1", + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, Some(List(-1))), + ContractBuilders.simpleTensorModelField("in2", DataType.DT_INT32, Some(List(3))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(List(-1))) + ) + ) + + val expected = Map( + "in1" -> TypedTensorFactory.create( + TensorProto( + dtype = DataType.DT_STRING, + tensorShape = TensorShape.vector(-1).toProto, + stringVal = List(fooString) + ) + ), + "in2" -> TypedTensorFactory.create( + TensorProto( + dtype = DataType.DT_INT32, + tensorShape = TensorShape.vector(3).toProto, + intVal = List(1, 1, 1) + ) + ) + ) + + val generated = DataGenerator(sig1).generateInputs + assert(generated === expected) + } + + "nested singular signature" in { + val sig1 = ModelSignature( + "sig1", + List( + ContractBuilders.complexField( + "in1", + None, + Seq( + ContractBuilders.simpleTensorModelField("a", DataType.DT_STRING, None), + ContractBuilders.simpleTensorModelField("b", DataType.DT_STRING, None) + ) + ) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(List(-1))) + ) + ) + + val expected = Map( + "in1" -> + TypedTensorFactory.create( + TensorProto( + dtype = DataType.DT_MAP, + tensorShape = None, + mapVal = Seq( + MapTensorData( + Map( + "a" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)), + "b" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)) + ) + ) + ) + ) + ) + ) + + val generated = DataGenerator(sig1).generateInputs + assert(generated === expected) + } + + "nested vector signature" in { + val sig1 = ModelSignature( + "sig1", + List( + ContractBuilders.complexField( + "in1", + TensorShape.vector(3).toProto, + Seq( + ContractBuilders.simpleTensorModelField("a", DataType.DT_STRING, None), + ContractBuilders.simpleTensorModelField("b", DataType.DT_STRING, None) + ) + ) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(List(-1))) + ) + ) + + val expected = Map( + "in1" -> TypedTensorFactory.create( + TensorProto( + dtype = DataType.DT_MAP, + tensorShape = TensorShape.vector(3).toProto, + mapVal = Seq( + MapTensorData( + Map( + "a" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)), + "b" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)) + ) + ), + MapTensorData( + Map( + "a" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)), + "b" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)) + ) + ), + MapTensorData( + Map( + "a" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)), + "b" -> TensorProto(DataType.DT_STRING, None, stringVal = List(fooString)) + ) + ) + ) + ) + ) + ) + + val generated = DataGenerator(sig1).generateInputs + assert(generated === expected) + } + } + } +} \ No newline at end of file diff --git a/scala-package/src/test/scala/SignatureCheckerSpecs.scala b/scala-package/src/test/scala/SignatureCheckerSpecs.scala index 5013243..4807527 100644 --- a/scala-package/src/test/scala/SignatureCheckerSpecs.scala +++ b/scala-package/src/test/scala/SignatureCheckerSpecs.scala @@ -1,8 +1,6 @@ -import io.hydrosphere.serving.contract.model_field.ModelField import io.hydrosphere.serving.contract.model_signature.ModelSignature -import io.hydrosphere.serving.contract.utils.SignatureChecker -import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo -import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto +import io.hydrosphere.serving.contract.utils.ContractBuilders +import io.hydrosphere.serving.contract.utils.ops.ModelSignatureOps import io.hydrosphere.serving.tensorflow.types.DataType import org.scalatest.WordSpec @@ -13,202 +11,114 @@ class SignatureCheckerSpecs extends WordSpec { val sig1 = ModelSignature( "sig1", List( - ModelField("in1", ModelField.InfoOrSubfields.Info(TensorInfo(DataType.DT_STRING, None))) + ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, None) ), List( - ModelField("out1", ModelField.InfoOrSubfields.Info(TensorInfo(DataType.DT_STRING, None))) + ContractBuilders.simpleTensorModelField("out1", DataType.DT_STRING, None) ) ) val sig2 = ModelSignature( "sig2", List( - ModelField("out1", ModelField.InfoOrSubfields.Info(TensorInfo(DataType.DT_STRING, None))) + ContractBuilders.simpleTensorModelField("out1", DataType.DT_STRING, None) ), List( - ModelField("out2", ModelField.InfoOrSubfields.Info(TensorInfo(DataType.DT_STRING, None))) + ContractBuilders.simpleTensorModelField("out2", DataType.DT_STRING, None) ) ) - assert(SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isDefined) } "two identical signatures (Double[5],Double[5] -> Double[5],Double[5])" in { val sig1 = ModelSignature( "sig1", - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_DOUBLE, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: Nil)) - ) - ) - ) :: Nil, - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_DOUBLE, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(5))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(5))) + ) ) val sig2 = ModelSignature( "sig2", - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_DOUBLE, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: Nil)) - ) - ) - ) :: Nil, - ModelField("out2", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_DOUBLE, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(5))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_DOUBLE, Some(Seq(5))) + + ) ) - assert(SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isDefined) } "two compatible signatures (Int32[3] -> Int32[-1])" in { val sig1 = ModelSignature( "sig1", - - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_INT32, Some(Seq(3))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_INT32, Some(Seq(3))) + ) ) val sig2 = ModelSignature( "sig2", + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_INT32, Some(Seq(-1))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, Some(Seq(-1))) - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(-1) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out2", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(-1) :: Nil)) - ) - ) - ) :: Nil + ) ) - assert(SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isDefined) } "two identical signatures (Double[5, 2] -> Double[5, 2])" in { val sig1 = ModelSignature( "sig1", - - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(2) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(2) :: Nil)) - ) + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(5, 2))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(5, 2))) ) - ) :: Nil ) val sig2 = ModelSignature( "sig2", + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(5, 2))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_DOUBLE, Some(Seq(5, 2))) - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(2) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out2", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(2) :: Nil)) - ) - ) - ) :: Nil + ) ) - assert(SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isDefined) } - "two identical signatures (Double[5, 2] -> Double[5, -1])" in { + "two compatible signatures (Double[5, 2] -> Double[5, -1])" in { val sig1 = ModelSignature( "sig1", - - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(2) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(2) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(5, 2))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(5, 2))) + ) ) val sig2 = ModelSignature( "sig2", + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(5, -1))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_DOUBLE, Some(Seq(5, -1))) - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(-1) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out2", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_INT32, - Some(TensorShapeProto(TensorShapeProto.Dim(5) :: TensorShapeProto.Dim(-1) :: Nil)) - ) - ) - ) :: Nil + ) ) - assert(SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isDefined) } } @@ -218,160 +128,106 @@ class SignatureCheckerSpecs extends WordSpec { val sig1 = ModelSignature( "sig1", List( - ModelField("in1", ModelField.InfoOrSubfields.Info(TensorInfo(DataType.DT_STRING, None))) + ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, None) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_STRING, None) ) ) val sig2 = ModelSignature( "sig2", List( - ModelField("in2", ModelField.InfoOrSubfields.Info(TensorInfo(DataType.DT_INT32, None))) + ContractBuilders.simpleTensorModelField("out1", DataType.DT_INT32, None) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, None) ) ) - assert(! SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isEmpty) } "two completely different signatures (String[3] -> String[4])" in { val sig1 = ModelSignature( "sig1", - - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_STRING, Some(Seq(3))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_STRING, Some(Seq(3))) + ) ) val sig2 = ModelSignature( "sig2", - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(4) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out2", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(4) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_INT32, Some(Seq(4))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_INT32, Some(Seq(4))) + ) ) - assert(! SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isEmpty) } "two completely different signatures (Double[4] -> Double[3])" in { val sig1 = ModelSignature( "sig1", - - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(4) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(4) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(4))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(4))) + ) ) val sig2 = ModelSignature( "sig2", - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_DOUBLE, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out2", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_DOUBLE, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(3))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_DOUBLE, Some(Seq(3))) + ) ) - assert(! SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isEmpty) } "two signatures when receiver has empty input signature" in { val sig1 = ModelSignature( "sig1", - - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil, - - ModelField("out1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(4))) + ), + List( + ContractBuilders.simpleTensorModelField("out1", DataType.DT_DOUBLE, Some(Seq(4))) + ) ) - val sig2 = ModelSignature() - assert(! SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + val sig2 = ModelSignature( + "sig2", + List(), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_DOUBLE, Some(Seq(4))) + ) + ) + assert(ModelSignatureOps.append(sig1, sig2).isEmpty) } "two signatures when emitter has empty output signature" in { val sig1 = ModelSignature( "sig1", - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(4))) + ), + List() ) val sig2 = ModelSignature( "sig2", - ModelField("in1", - ModelField.InfoOrSubfields.Info( - TensorInfo( - DataType.DT_STRING, - Some(TensorShapeProto(TensorShapeProto.Dim(3) :: Nil)) - ) - ) - ) :: Nil + List( + ContractBuilders.simpleTensorModelField("in1", DataType.DT_DOUBLE, Some(Seq(4))) + ), + List( + ContractBuilders.simpleTensorModelField("out2", DataType.DT_DOUBLE, Some(Seq(4))) + ) ) - assert(! SignatureChecker.areSequentiallyCompatible(sig1, sig2)) + assert(ModelSignatureOps.append(sig1, sig2).isEmpty) } } } diff --git a/scala-package/src/test/scala/TypedTensorSpecs.scala b/scala-package/src/test/scala/TypedTensorSpecs.scala new file mode 100644 index 0000000..f4173b4 --- /dev/null +++ b/scala-package/src/test/scala/TypedTensorSpecs.scala @@ -0,0 +1,69 @@ +import com.google.protobuf.ByteString +import io.hydrosphere.serving.contract.utils.DataGenerator +import io.hydrosphere.serving.tensorflow.TensorShape +import io.hydrosphere.serving.tensorflow.tensor._ +import io.hydrosphere.serving.tensorflow.types.DataType +import org.scalatest.WordSpec + +class TypedTensorSpecs extends WordSpec { + val shape = TensorShape.mat(2, 2) + + def test[TFactory <: TypedTensorFactory[_ <: TypedTensor[_ <: DataType]]] + (factory: TFactory) = { + factory.getClass.getSimpleName in { + val tensor = DataGenerator.generateTensor(shape, factory.empty.dtype) + assert(tensor.get === factory.fromProto(tensor.get.toProto)) + } + } + + "Scalar tensors" should { + "convert to TensorProto" when { + test(BoolTensor) + + test(DoubleTensor) + test(FloatTensor) + + test(StringTensor) + + test(SComplexTensor) + test(DComplexTensor) + + test(Int8Tensor) + test(Int16Tensor) + test(Int32Tensor) + test(Int64Tensor) + + test(Uint8Tensor) + test(Uint16Tensor) + test(Uint32Tensor) + test(Uint64Tensor) + } + } + + "Map tensor" should { + "convert to TensorProto" in { + val expected = TensorProto( + dtype = DataType.DT_MAP, + tensorShape = None, + mapVal = Seq( + MapTensorData( + Map( + "name" -> TensorProto( + dtype = DataType.DT_STRING, + tensorShape = None, + stringVal = Seq(ByteString.copyFromUtf8("John")) + ), + "surname" -> TensorProto( + dtype = DataType.DT_STRING, + tensorShape = None, + stringVal = Seq(ByteString.copyFromUtf8("Doe")) + ) + ) + ) + ) + ) + + assert(TypedTensorFactory.create(expected).toProto === expected) + } + } +} diff --git a/src/hydro_serving_grpc/contract/model_field.proto b/src/hydro_serving_grpc/contract/model_field.proto index 66ebfda..a190894 100644 --- a/src/hydro_serving_grpc/contract/model_field.proto +++ b/src/hydro_serving_grpc/contract/model_field.proto @@ -3,17 +3,20 @@ syntax = "proto3"; package hydrosphere.contract; option java_package = "io.hydrosphere.serving.contract"; -import "hydro_serving_grpc/tf/tensor_info.proto"; +import "hydro_serving_grpc/tf/types.proto"; +import "hydro_serving_grpc/tf/tensor_shape.proto"; message ModelField { - message ComplexField { + message Subfield { repeated ModelField data = 1; } - string field_name = 1; + string name = 1; - oneof info_or_subfields { - ComplexField subfields = 2; - hydrosphere.tensorflow.TensorInfo info = 3; + tensorflow.TensorShapeProto shape = 2; + + oneof type_or_subfields { + Subfield subfields = 3; + tensorflow.DataType dtype = 4; } } \ No newline at end of file diff --git a/src/hydro_serving_grpc/tf/api/predict.proto b/src/hydro_serving_grpc/tf/api/predict.proto index 7cabb00..9c1037e 100644 --- a/src/hydro_serving_grpc/tf/api/predict.proto +++ b/src/hydro_serving_grpc/tf/api/predict.proto @@ -27,7 +27,7 @@ message PredictRequest { // Only tensors specified here will be run/fetched and returned, with the // exception that when none is specified, all tensors specified in the // named signature will be run/fetched and returned. - repeated string output_filter = 3; + reserved 3; } // Response for PredictRequest on successful run. diff --git a/src/hydro_serving_grpc/tf/tensor.proto b/src/hydro_serving_grpc/tf/tensor.proto index 6cc0a93..b4920bc 100644 --- a/src/hydro_serving_grpc/tf/tensor.proto +++ b/src/hydro_serving_grpc/tf/tensor.proto @@ -80,7 +80,7 @@ message TensorProto { // DT_UINT64 repeated uint64 uint64_val = 17 [packed = true]; - map map_val = 27; // Hydroserving + repeated MapTensorData map_val = 27; // Hydroserving }; // Protocol buffer representing the serialization format of DT_VARIANT tensors. @@ -92,3 +92,7 @@ message VariantTensorDataProto { // Tensors contained within objects being serialized. repeated TensorProto tensors = 3; } + +message MapTensorData { + map subtensors = 1; +} \ No newline at end of file diff --git a/src/hydro_serving_grpc/tf/tensor_info.proto b/src/hydro_serving_grpc/tf/tensor_info.proto deleted file mode 100644 index 32905fc..0000000 --- a/src/hydro_serving_grpc/tf/tensor_info.proto +++ /dev/null @@ -1,15 +0,0 @@ -syntax = "proto3"; - -package hydrosphere.tensorflow; -option cc_enable_arenas = true; -option java_multiple_files = true; -option java_package = "io.hydrosphere.serving.tensorflow"; - -import "hydro_serving_grpc/tf/types.proto"; -import "hydro_serving_grpc/tf/tensor_shape.proto"; - -message TensorInfo { - reserved 1; // hydroserving -- hide tensor name, move it to ModelField - DataType dtype = 2; - TensorShapeProto tensor_shape = 3; -} \ No newline at end of file diff --git a/src/hydro_serving_grpc/tf/tensor_slice.proto b/src/hydro_serving_grpc/tf/tensor_slice.proto deleted file mode 100644 index cfd1202..0000000 --- a/src/hydro_serving_grpc/tf/tensor_slice.proto +++ /dev/null @@ -1,37 +0,0 @@ -// Protocol buffer representing slices of a tensor - -syntax = "proto3"; - -package hydrosphere.tensorflow; - -option cc_enable_arenas = true; -option java_multiple_files = true; -option java_package = "io.hydrosphere.serving.tensorflow"; - -// Can only be interpreted if you know the corresponding TensorShape. -message TensorSliceProto { - // Extent of the slice in one dimension. - message Extent { - // Either both or no attributes must be set. When no attribute is set - // means: All data in that dimension. - - // Start index of the slice, starting at 0. - int64 start = 1; - - // Length of the slice: if the length is missing or -1 we will - // interpret this as "everything in this dimension". We use - // "oneof" to preserve information about whether the length is - // present without changing the serialization format from the - // prior proto2 version of this proto. - oneof has_length { - int64 length = 2; - } - }; - - // Extent of the slice in all tensor dimensions. - // - // Must have one entry for each of the dimension of the tensor that this - // slice belongs to. The order of sizes is the same as the order of - // dimensions in the TensorShape. - repeated Extent extent = 1; -}; diff --git a/src/hydro_serving_grpc/tf/types.proto b/src/hydro_serving_grpc/tf/types.proto index 4effa44..21a4319 100644 --- a/src/hydro_serving_grpc/tf/types.proto +++ b/src/hydro_serving_grpc/tf/types.proto @@ -38,33 +38,4 @@ enum DataType { DT_UINT64 = 23; DT_MAP = 27; // Hydroserving map structure inside the tensor - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; - - DT_MAP_REF = 127; // Hydroserving - } diff --git a/version b/version index 17d7380..4ecb664 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.0.16-SNAPSHOT \ No newline at end of file +0.1.0-SNAPSHOT \ No newline at end of file