-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add codec support to ser / deser pre-aggregated IR tiles (#523)
* Add codec support to ser / deser pre-aggregated IR tiles
* Remove trailing comma for Scala 2.11
* Use unpack instead of UnpackedAggregations
- Loading branch information
1 parent
45b9cae
commit 8c464a8
Showing
2 changed files
with
146 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package ai.chronon.online | ||
|
||
import ai.chronon.aggregator.row.RowAggregator | ||
import ai.chronon.api.{BooleanType, DataType, GroupBy, StructType} | ||
import org.apache.avro.generic.GenericData | ||
import ai.chronon.api.Extensions.{AggregationOps, MetadataOps} | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
object TileCodec {

  /**
    * Builds a [[RowAggregator]] covering every aggregation declared on the given group-by.
    *
    * Each aggregation on the group-by is expanded via `unpack`, and the expanded parts of all
    * aggregations are concatenated (in declaration order) before being handed to the aggregator.
    *
    * @param groupBy     group-by whose `aggregations` are expanded
    * @param inputSchema (column name, type) pairs describing the rows fed to the aggregator
    * @return a row aggregator over the expanded aggregation parts
    */
  def buildRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
    // flatten each declared aggregation into its individual parts
    val aggregationParts = for {
      aggregation <- groupBy.aggregations.asScala
      part <- aggregation.unpack
    } yield part

    new RowAggregator(inputSchema, aggregationParts)
  }
}
|
||
/**
  * Helper for encoding and decoding pre-aggregated tiles of feature values.
  *
  * A tile pairs the intermediate results (IR) produced by a [[RowAggregator]] with a flag that
  * says whether the tile is complete or still holds partial aggregates. Tiles are serialized
  * as Avro so that the serving layer can combine them with the batch pre-aggregates produced by
  * GroupByUploads when computing final feature values.
  *
  * @param rowAggregator aggregator whose IR schema defines the tile payload and which is used to
  *                      normalize / denormalize the IR around serialization
  * @param groupBy       group-by whose cleaned metadata name is used to name the tile schema
  */
class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
  // Schema of the aggregator's intermediate results, wrapped as a named struct.
  val windowedIrSchema: StructType = StructType.from("WindowedIr", rowAggregator.irSchema)

  // A tile record has exactly two fields: the IR payload and the completeness flag.
  val fields: Array[(String, DataType)] = Array(
    "collapsedIr" -> windowedIrSchema,
    "isComplete" -> BooleanType
  )

  val tileChrononSchema: StructType =
    StructType.from(s"${groupBy.metaData.cleanName}_TILE_IR", fields)
  val tileAvroSchema: String = AvroConversions.fromChrononSchema(tileChrononSchema).toString()
  val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
  private val irToBytesFn = AvroConversions.encodeBytes(tileChrononSchema, null)

  /**
    * Serializes an intermediate result together with its completeness flag.
    *
    * @param ir         intermediate result to store in the tile; it is normalized by the row
    *                   aggregator before encoding
    * @param isComplete whether the tile is complete (as opposed to holding partial aggregates)
    * @return the encoded tile bytes
    */
  def makeTileIr(ir: Array[Any], isComplete: Boolean): Array[Byte] =
    irToBytesFn(Array[Any](rowAggregator.normalize(ir), Boolean.box(isComplete)))

  /**
    * Deserializes tile bytes back into a denormalized intermediate result and completeness flag.
    *
    * @param tileIr encoded tile bytes, laid out according to the tile Avro schema
    * @return the denormalized intermediate result paired with the tile's completeness flag
    */
  def decodeTileIr(tileIr: Array[Byte]): (Array[Any], Boolean) = {
    val tileRecord = tileAvroCodec.decode(tileIr)
    val irRecord = tileRecord.get("collapsedIr").asInstanceOf[GenericData.Record]

    // convert the Avro record into a Chronon row, then undo the normalization applied on write
    val normalizedIr = AvroConversions.toChrononRow(irRecord, windowedIrSchema).asInstanceOf[Array[Any]]
    val restoredIr = rowAggregator.denormalize(normalizedIr)
    val completeFlag = tileRecord.get("isComplete").asInstanceOf[Boolean]
    (restoredIr, completeFlag)
  }
}
89 changes: 89 additions & 0 deletions
89
online/src/test/scala/ai/chronon/online/TileCodecTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package ai.chronon.online | ||
|
||
import ai.chronon.api.{Aggregation, Builders, FloatType, IntType, ListType, LongType, Operation, Row, StringType, TimeUnit, Window} | ||
import org.junit.Assert.assertEquals | ||
import org.junit.Test | ||
import scala.collection.JavaConverters._ | ||
|
||
/**
  * Round-trip test for [[TileCodec]]: aggregates a few rows into an intermediate result,
  * serializes it as a tile, deserializes it, finalizes it, and compares the finalized values
  * against the expected value of every aggregation.
  */
class TileCodecTest {
  private val histogram = Map[String, Int]("A" -> 3, "B" -> 2).asJava

  // Each aggregation is paired with the finalized value expected for the rows built in the test.
  private val aggregationsAndExpected: Array[(Aggregation, Any)] = Array(
    Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> 16.0,
    Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 4.0,

    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 12.0f,
    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(7, TimeUnit.DAYS))) -> 12.0f,

    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,

    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "C",
    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "C",

    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,

    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,

    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "A",
    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "A",

    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,

    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram,
    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram
  )

  // Input schema of the rows produced by createRow (same column order).
  private val schema = List(
    "created" -> LongType,
    "views" -> IntType,
    "rating" -> FloatType,
    "title" -> StringType,
    "hist_input" -> ListType(StringType)
  )

  @Test
  def testTileCodecIrSerRoundTrip(): Unit = {
    val groupByMetadata = Builders.MetaData(name = "my_group_by")
    val (aggregations, expectedVals) = aggregationsAndExpected.unzip
    val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations)
    val rowAggregator = TileCodec.buildRowAggregator(groupBy, schema)
    val rowIR = rowAggregator.init
    val tileCodec = new TileCodec(rowAggregator, groupBy)

    val originalIsComplete = true
    val rows = Seq(
      createRow(1519862399984L, 4, 4.0f, "A", Seq("D", "A", "B", "A")),
      createRow(1519862399984L, 40, 5.0f, "B", Seq()),
      createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
    )
    rows.foreach(row => rowAggregator.update(rowIR, row))
    val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
    assert(bytes.length > 0, "serialized tile must not be empty")

    val (deserPayload, isComplete) = tileCodec.decodeTileIr(bytes)
    assertEquals(originalIsComplete, isComplete)

    // finalize the deserialized intermediate results and verify them one by one
    val finalResults = rowAggregator.finalize(deserPayload)
    val outputNames = rowAggregator.outputSchema.map(_._1)

    // zip silently truncates to the shortest side, so make sure nothing would be skipped
    assert(
      expectedVals.length == finalResults.length,
      s"expected ${expectedVals.length} finalized values but got ${finalResults.length}"
    )
    assert(
      expectedVals.length == outputNames.length,
      s"expected ${expectedVals.length} output columns but got ${outputNames.length}"
    )

    expectedVals.zip(finalResults).zip(outputNames).foreach {
      case ((expected, actual), name) =>
        // carry the column name in the failure message instead of printing on every check
        assertEquals(s"Mismatch for output column: $name", expected, actual)
    }
  }

  /**
    * Builds a test row whose values follow the column order of `schema`.
    *
    * @param ts        value of the `created` column, also passed to the row as its timestamp
    * @param views     value of the `views` column
    * @param rating    value of the `rating` column
    * @param title     value of the `title` column
    * @param histInput value of the `hist_input` column
    * @return an ArrayRow holding the values in schema order
    */
  def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row = {
    // names are kept next to the values purely to document the column order
    val values: Array[(String, Any)] = Array(
      "created" -> ts,
      "views" -> views,
      "rating" -> rating,
      "title" -> title,
      "hist_input" -> histInput
    )
    new ArrayRow(values.map(_._2), ts)
  }
}