Skip to content

Commit

Permalink
Add codec support to ser / deser pre-aggregated IR tiles (#523)
Browse files Browse the repository at this point in the history
* Add codec support to ser / deser pre-aggregated IR tiles

* Remove trailing comma for Scala 2.11

* Use unpack instead of UnpackedAggregations
  • Loading branch information
piyushn-stripe authored Jul 25, 2023
1 parent 45b9cae commit 8c464a8
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
57 changes: 57 additions & 0 deletions online/src/main/scala/ai/chronon/online/TileCodec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package ai.chronon.online

import ai.chronon.aggregator.row.RowAggregator
import ai.chronon.api.{BooleanType, DataType, GroupBy, StructType}
import org.apache.avro.generic.GenericData
import ai.chronon.api.Extensions.{AggregationOps, MetadataOps}

import scala.collection.JavaConverters._

object TileCodec {
def buildRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
// a set of Chronon groupBy aggregations needs to be flatted out to get the
// full cross-product of all the feature column aggregations to be computed
val unpackedAggs = groupBy.aggregations.asScala.flatMap(_.unpack)
new RowAggregator(inputSchema, unpackedAggs)
}
}

/**
* TileCodec is a helper class that allows for the creation of pre-aggregated tiles of feature values.
* These pre-aggregated tiles can be used in the serving layer to compute the final feature values along
* with batch pre-aggregates produced by GroupByUploads.
* The pre-aggregated tiles are serialized as Avro and indicate whether the tile is complete or not (partial aggregates)
*/
class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
val windowedIrSchema: StructType = StructType.from("WindowedIr", rowAggregator.irSchema)
val fields: Array[(String, DataType)] = Array(
"collapsedIr" -> windowedIrSchema,
"isComplete" -> BooleanType
)

val tileChrononSchema: StructType =
StructType.from(s"${groupBy.metaData.cleanName}_TILE_IR", fields)
val tileAvroSchema: String = AvroConversions.fromChrononSchema(tileChrononSchema).toString()
val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
private val irToBytesFn = AvroConversions.encodeBytes(tileChrononSchema, null)

def makeTileIr(ir: Array[Any], isComplete: Boolean): Array[Byte] = {
val normalizedIR = rowAggregator.normalize(ir)
val tileIr: Array[Any] = Array(normalizedIR, Boolean.box(isComplete))
irToBytesFn(tileIr)
}

def decodeTileIr(tileIr: Array[Byte]): (Array[Any], Boolean) = {
val decodedTileIr = tileAvroCodec.decode(tileIr)
val collapsedIr = decodedTileIr
.get("collapsedIr")
.asInstanceOf[GenericData.Record]

val ir = AvroConversions
.toChrononRow(collapsedIr, windowedIrSchema)
.asInstanceOf[Array[Any]]
val denormalizedIr = rowAggregator.denormalize(ir)
val isComplete = decodedTileIr.get("isComplete").asInstanceOf[Boolean]
(denormalizedIr, isComplete)
}
}
89 changes: 89 additions & 0 deletions online/src/test/scala/ai/chronon/online/TileCodecTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package ai.chronon.online

import ai.chronon.api.{Aggregation, Builders, FloatType, IntType, ListType, LongType, Operation, Row, StringType, TimeUnit, Window}
import org.junit.Assert.assertEquals
import org.junit.Test
import scala.collection.JavaConverters._

class TileCodecTest {
private val histogram = Map[String, Int]("A" -> 3, "B" -> 2).asJava

private val aggregationsAndExpected: Array[(Aggregation, Any)] = Array(
Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> 16.0,
Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 4.0,

Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 12.0f,
Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(7, TimeUnit.DAYS))) -> 12.0f,

Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,

Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "C",
Builders.Aggregation(Operation.LAST, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "C",

Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,

Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,

Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "A",
Builders.Aggregation(Operation.MIN, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "A",

Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,

Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram,
Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram
)

private val schema = List(
"created" -> LongType,
"views" -> IntType,
"rating" -> FloatType,
"title" -> StringType,
"hist_input" -> ListType(StringType)
)

@Test
def testTileCodecIrSerRoundTrip(): Unit = {
val groupByMetadata = Builders.MetaData(name = "my_group_by")
val (aggregations, expectedVals) = aggregationsAndExpected.unzip
val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations)
val rowAggregator = TileCodec.buildRowAggregator(groupBy, schema)
val rowIR = rowAggregator.init
val tileCodec = new TileCodec(rowAggregator, groupBy)

val originalIsComplete = true
val rows = Seq(
createRow(1519862399984L, 4, 4.0f, "A", Seq("D", "A", "B", "A")),
createRow(1519862399984L, 40, 5.0f, "B", Seq()),
createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
)
rows.foreach(row => rowAggregator.update(rowIR, row))
val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
assert(bytes.length > 0)

val (deserPayload, isComplete) = tileCodec.decodeTileIr(bytes)
assert(isComplete == originalIsComplete)

// lets finalize the payload intermediate results and verify things
val finalResults = rowAggregator.finalize(deserPayload)
expectedVals.zip(finalResults).zip(rowAggregator.outputSchema.map(_._1)).foreach {
case ((expected, actual), name) =>
println(s"Checking: $name")
assertEquals(expected, actual)
}
}

def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row = {
val values: Array[(String, Any)] = Array(
"created" -> ts,
"views" -> views,
"rating" -> rating,
"title" -> title,
"hist_input" -> histInput
)
new ArrayRow(values.map(_._2), ts)
}
}

0 comments on commit 8c464a8

Please sign in to comment.