use key schema and partition column
pengyu-hou committed Oct 3, 2023
1 parent 7080d35 commit 87548f7
Showing 2 changed files with 17 additions and 7 deletions.
spark/src/main/scala/ai/chronon/spark/Analyzer.scala (9 additions, 4 deletions)
@@ -10,7 +10,7 @@ import com.yahoo.sketches.ArrayOfStringsSerDe
 import com.yahoo.sketches.frequencies.{ErrorType, ItemsSketch}
 import org.apache.spark.sql.{DataFrame, Row, types}
 import org.apache.spark.sql.functions.{col, from_unixtime, lit}
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{StringType, StructType}
 import ai.chronon.aggregator.row.StatsGenerator
 import ai.chronon.api.DataModel.{DataModel, Entities, Events}
 

@@ -183,12 +183,17 @@ class Analyzer(tableUtils: TableUtils,
     val schema = if (groupByConf.isSetBackfillStartDate && groupByConf.hasDerivations) {
       // handle group by backfill mode for derivations
       // todo: add the similar logic to join derivations
-      val sparkSchema = SparkConversions.fromChrononSchema(groupBy.outputSchema)
+      val sparkSchema = {
+        StructType(
+          SparkConversions.fromChrononSchema(groupBy.outputSchema).fields ++ groupBy.keySchema.fields ++ Seq(
+            org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)))
+      }
       val dummyOutputDf = tableUtils.sparkSession
         .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
-      val finalOutputColumns = groupByConf.derivations.toScala.finalOutputColumn(dummyOutputDf.columns).toSeq
+      val finalOutputColumns = groupByConf.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq
       val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
-      val columns = SparkConversions.toChrononSchema(derivedDummyOutputDf.schema)
+      val columns = SparkConversions.toChrononSchema(
+        StructType(derivedDummyOutputDf.schema.filterNot(groupBy.keySchema.fields.contains)))
       api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2)))
     } else {
       groupBy.outputSchema
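
The change builds the dummy DataFrame's schema from the GroupBy's value fields plus its key fields and the partition column, runs the derivation expressions against that empty frame to learn the derived schema, and then filters the key fields back out. The sketch below reproduces that pattern in isolation; it is runnable in spark-shell, and the column names amount_sum, user_id, and ds are hypothetical stand-ins for groupBy.outputSchema, groupBy.keySchema, and tableUtils.partitionColumn, with a hand-written select standing in for finalOutputColumn.

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Hypothetical stand-ins for the Chronon schemas used in the diff.
val valueSchema = StructType(Seq(StructField("amount_sum", DoubleType)))
val keySchema = StructType(Seq(StructField("user_id", StringType)))
val partitionColumn = "ds"

// Union value fields, key fields, and the partition column, as the diff does.
val sparkSchema = StructType(
  valueSchema.fields ++ keySchema.fields ++ Seq(StructField(partitionColumn, StringType)))

// An empty DataFrame suffices: only the schema produced by the derivations matters.
val dummyOutputDf = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq.empty[Row]), sparkSchema)

// Stand-in for finalOutputColumn: one derived column plus the passthrough columns.
val finalOutputColumns = Seq(
  expr("amount_sum * 2 AS amount_sum_doubled"), expr("user_id"), expr(partitionColumn))
val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)

// Filter the key fields back out of the derived schema, as the diff does.
val outputSchema = StructType(
  derivedDummyOutputDf.schema.filterNot(keySchema.fields.contains))
outputSchema.foreach(f => println(s"${f.name}: ${f.dataType}"))
// Prints amount_sum_doubled: DoubleType and ds: StringType; user_id is dropped.
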
spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala (8 additions, 3 deletions)
@@ -7,7 +7,7 @@ import ai.chronon.online.SparkConversions
 import ai.chronon.spark.Extensions._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.expr
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 
 import scala.collection.{Seq, immutable, mutable}
 import scala.util.ScalaJavaConversions.ListOps

@@ -71,12 +71,17 @@ object BootstrapInfo {
       .map(field => StructField(part.rightToLeft(field._1), field._2))
 
     val outputSchema = if (part.groupBy.hasDerivations) {
-      val sparkSchema = SparkConversions.fromChrononSchema(gb.outputSchema)
+      val sparkSchema = {
+        StructType(
+          SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ gb.keySchema.fields ++ Seq(
+            org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)))
+      }
       val dummyOutputDf = tableUtils.sparkSession
         .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
       val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq
       val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
-      val columns = SparkConversions.toChrononSchema(derivedDummyOutputDf.schema)
+      val columns = SparkConversions.toChrononSchema(
+        StructType(derivedDummyOutputDf.schema.filterNot(gb.keySchema.fields.contains)))
       api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2)))
     } else {
       gb.outputSchema
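
This is the same pattern as in Analyzer.scala. As for why the key schema and partition column are appended at all (the motivation implied by the commit title): a derivation expression that references a key column or the partition column cannot be resolved against a dummy DataFrame built from the value schema alone. A small spark-shell illustration, again with hypothetical column names:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

// Value schema only, as before this commit.
val valueOnlyDf = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq.empty[Row]),
  StructType(Seq(StructField("amount_sum", DoubleType))))

// A derivation referencing key or partition columns fails to resolve:
// valueOnlyDf.select(expr("concat(user_id, '_', ds)"))  // AnalysisException

// With the key and partition fields appended, the same expression resolves.
val fullDf = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq.empty[Row]),
  StructType(valueOnlyDf.schema.fields ++ Seq(
    StructField("user_id", StringType), StructField("ds", StringType))))
fullDf.select(expr("concat(user_id, '_', ds) AS user_ds")).printSchema()
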
