Skip to content

Commit

Permalink
Fix/map tensor (#10)
Browse files Browse the repository at this point in the history
* Map reimplementation and cleanup

* Typed tensor refactor and DataGenerator tests

* Shape class and generation fixes
  • Loading branch information
KineticCookie authored Mar 20, 2018
1 parent ab805f7 commit f3fe4a6
Show file tree
Hide file tree
Showing 47 changed files with 1,097 additions and 824 deletions.
Original file line number Diff line number Diff line change
@@ -1,37 +1,18 @@
package io.hydrosphere.serving.contract.utils

import io.hydrosphere.serving.contract.model_field.ModelField
import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo
import io.hydrosphere.serving.tensorflow.TensorShape
import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto
import io.hydrosphere.serving.tensorflow.types.DataType
import io.hydrosphere.serving.tensorflow.utils.ops.TensorShapeProtoOps

object ContractBuilders {
def createUnknownTensorShape(): TensorShapeProto = {
TensorShapeProto(unknownRank = true)
}

def createTensorShape(dims: Seq[Long], unknownRank: Boolean = false): TensorShapeProto = {
TensorShapeProto(
dim = dims.map { d =>
TensorShapeProto.Dim(d)
},
unknownRank = unknownRank
)
}

def createTensorInfo(
dataType: DataType,
shape: Option[Seq[Long]],
unknownRank: Boolean = false
): TensorInfo = {
TensorInfo(dataType, shape.map(s => createTensorShape(s, unknownRank)))
}

def complexField(name: String, subFields: Seq[ModelField]): ModelField = {
def complexField(name: String, shape: Option[TensorShapeProto], subFields: Seq[ModelField]): ModelField = {
ModelField(
name,
ModelField.InfoOrSubfields.Subfields(
ModelField.ComplexField(
shape,
ModelField.TypeOrSubfields.Subfields(
ModelField.Subfield(
subFields
)
)
Expand All @@ -43,7 +24,7 @@ object ContractBuilders {
dataType: DataType,
shape: Option[TensorShapeProto]
): ModelField = {
ModelField(name, ModelField.InfoOrSubfields.Info(TensorInfo(dataType, shape)))
ModelField(name, shape, ModelField.TypeOrSubfields.Dtype(dataType))
}

def simpleTensorModelField(
Expand All @@ -54,7 +35,8 @@ object ContractBuilders {
): ModelField = {
ModelField(
name,
ModelField.InfoOrSubfields.Info(createTensorInfo(dataType, shape, unknownRank))
TensorShape(shape, unknownRank).toProto,
ModelField.TypeOrSubfields.Dtype(dataType)
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,19 @@ package io.hydrosphere.serving.contract.utils

import io.hydrosphere.serving.contract.model_contract.ModelContract
import io.hydrosphere.serving.contract.model_field.ModelField
import io.hydrosphere.serving.contract.model_field.ModelField.InfoOrSubfields.{
Empty,
Info,
Subfields
}
import io.hydrosphere.serving.contract.model_field.ModelField.TypeOrSubfields.{Dtype, Empty, Subfields}
import io.hydrosphere.serving.contract.model_signature.ModelSignature
import io.hydrosphere.serving.tensorflow.tensor.TensorProto
import io.hydrosphere.serving.tensorflow.tensor_info.TensorInfo
import io.hydrosphere.serving.tensorflow.tensor_shape.TensorShapeProto
import io.hydrosphere.serving.tensorflow.TensorShape
import io.hydrosphere.serving.tensorflow.tensor._
import io.hydrosphere.serving.tensorflow.types.DataType
import io.hydrosphere.serving.tensorflow.types.DataType._

class DataGenerator(val modelApi: ModelSignature) {
def generateInputs: Map[String, TensorProto] = {
def generateInputs: Map[String, TypedTensor[_]] = {
modelApi.inputs.flatMap(DataGenerator.generateField).toMap
}

def generateOutputs: Map[String, TensorProto] = {
def generateOutputs: Map[String, TypedTensor[_]] = {
modelApi.outputs.flatMap(DataGenerator.generateField).toMap
}
}
Expand All @@ -31,7 +26,7 @@ object DataGenerator {
modelContract.signatures.find(_.signatureName == signature).map(DataGenerator.apply)
}

def generateData(dataType: DataType): Any = {
def generateScalarData[T <: DataType](dataType: T): Any = {
dataType match {
case DT_FLOAT | DT_COMPLEX64 => 1.0F
case DT_DOUBLE | DT_COMPLEX128 => 1.0D
Expand All @@ -44,46 +39,46 @@ object DataGenerator {

case DT_INVALID =>
throw new IllegalArgumentException(
s"Can't convert data to DT_INVALID has an invalid dtype"
s"Can't convert data to DT_INVALID"
)
case DT_VARIANT =>
throw new IllegalArgumentException(s"Cannot process DT_VARIANT Tensor. Not supported yet.")
case Unrecognized(value) =>
throw new IllegalArgumentException(s"Cannot process Tensor with Unrecognized($value) dtype")
case x => throw new IllegalArgumentException(s"Cannot process Tensor with $x dtype") // refs
}
}

def generateData(dataType: DataType, shape: Option[TensorShapeProto]): List[Any] = {
shape match {
def createFlatTensor[T](shape: TensorShape, generator: => T): Seq[T] = {
shape.dims match {
case Some(sh) =>
sh.dim.map(_.size.max(1)).reverse.foldLeft(List.empty[Any]) {
case (Nil, y) =>
1L.to(y).map(_ => generateData(dataType)).toList
case (x, y) =>
1L.to(y).map(_ => x).toList
}
case None => List(generateData(dataType))
val flatLen = sh.map(_.max(1)).product
(1L to flatLen).map(_ => generator)
case None => List(generator)
}
}

def generateTensor(tensorInfo: TensorInfo): TensorProto = {
val tensor = TensorProto(dtype = tensorInfo.dtype, tensorShape = tensorInfo.tensorShape)
val data = generateData(tensorInfo.dtype, tensorInfo.tensorShape)
val typedTensor = TypedTensor(tensor)
typedTensor.putAny(data).right.get
def generateTensor(shape: TensorShape, dtype: DataType): Option[TypedTensor[_]] = {
val factory = TypedTensorFactory(dtype)
val data = createFlatTensor(shape, generateScalarData(dtype))
val s = factory.createFromAny(data, shape)
println(s.map(_.toProto))
s
}

def generateField(field: ModelField): Map[String, TensorProto] = {
val tensor = field.infoOrSubfields match {
case Empty => TensorProto()
case Info(value) => generateTensor(value)
case Subfields(value) =>
val tensor = TensorProto(dtype = DataType.DT_MAP)
tensor.withMapVal(
value.data.flatMap(generateField).toMap
)
def generateField(field: ModelField): Map[String, TypedTensor[_]] = {
val shape = TensorShape.fromProto(field.shape)
val fieldValue = field.typeOrSubfields match {
case Empty => None
case Dtype(value) => generateTensor(shape, value)
case Subfields(value) => generateNestedTensor(shape, value)
}
Map(field.fieldName -> tensor)
fieldValue.map(x => field.name -> x).toMap
}

private def generateNestedTensor(shape: TensorShape, value: ModelField.Subfield): Option[MapTensor] = {
val map = generateMap(value)
val tensorData = createFlatTensor(shape, map)
Some(MapTensor(shape, tensorData))
}

private def generateMap(value: ModelField.Subfield): Map[String, TypedTensor[_]] = {
value.data.flatMap(generateField).toMap
}
}

This file was deleted.

Loading

0 comments on commit f3fe4a6

Please sign in to comment.