Skip to content

Commit

Permalink
Fix unit tests (UT)
Browse files Browse the repository at this point in the history
  • Loading branch information
pengyu-hou committed Oct 3, 2023
1 parent 87548f7 commit 3c7fd33
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,17 @@ class Analyzer(tableUtils: TableUtils,
val schema = if (groupByConf.isSetBackfillStartDate && groupByConf.hasDerivations) {
// handle group by backfill mode for derivations
// todo: add the similar logic to join derivations
val keyAndPartitionFields =
groupBy.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType))
val sparkSchema = {
StructType(
SparkConversions.fromChrononSchema(groupBy.outputSchema).fields ++ groupBy.keySchema.fields ++ Seq(
org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)))
StructType(SparkConversions.fromChrononSchema(groupBy.outputSchema).fields ++ keyAndPartitionFields)
}
val dummyOutputDf = tableUtils.sparkSession
.createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
val finalOutputColumns = groupByConf.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq
val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
val columns = SparkConversions.toChrononSchema(
StructType(derivedDummyOutputDf.schema.filterNot(groupBy.keySchema.fields.contains)))
StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains)))
api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2)))
} else {
groupBy.outputSchema
Expand Down
8 changes: 4 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ object BootstrapInfo {
.toChrononSchema(gb.keySchema)
.map(field => StructField(part.rightToLeft(field._1), field._2))

val keyAndPartitionFields =
gb.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType))
val outputSchema = if (part.groupBy.hasDerivations) {
val sparkSchema = {
StructType(
SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ gb.keySchema.fields ++ Seq(
org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)))
StructType(SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ keyAndPartitionFields)
}
val dummyOutputDf = tableUtils.sparkSession
.createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq
val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
val columns = SparkConversions.toChrononSchema(
StructType(derivedDummyOutputDf.schema.filterNot(gb.keySchema.fields.contains)))
StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains)))
api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2)))
} else {
gb.outputSchema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,20 @@ class DerivationTest {
val namespace = "test_derivations"
spark.sql(s"CREATE DATABASE IF NOT EXISTS $namespace")
val groupBy = BootstrapUtils.buildGroupBy(namespace, spark)
.setDerivations(Seq(
Builders.Derivation(name = "user_amount_30d_avg",
expression = "amount_dollars_sum_30d / 30"),
Builders.Derivation(
name = "*"
),
).toJava)
val groupByWithDerivation = groupBy.setDerivations(Seq(
Builders.Derivation(name = "user_amount_30d_avg",
expression = "amount_dollars_sum_30d / 30"),
Builders.Derivation(
name = "*"
),
).toJava)
val queryTable = BootstrapUtils.buildQuery(namespace, spark)
val baseJoin = Builders.Join(
left = Builders.Source.events(
table = queryTable,
query = Builders.Query()
),
joinParts = Seq(Builders.JoinPart(groupBy = groupBy)),
joinParts = Seq(Builders.JoinPart(groupBy = groupByWithDerivation)),
rowIds = Seq("request_id"),
externalParts = Seq(
Builders.ExternalPart(
Expand Down Expand Up @@ -117,7 +117,7 @@ class DerivationTest {
val outputDf = runner.computeJoin()

assertTrue(
outputDf.columns sameElements Array(
outputDf.columns.toSet == Set(
"user",
"request_id",
"ts",
Expand Down

0 comments on commit 3c7fd33

Please sign in to comment.