diff --git a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala index b02183989..b1847d46c 100644 --- a/spark/src/main/scala/ai/chronon/spark/Analyzer.scala +++ b/spark/src/main/scala/ai/chronon/spark/Analyzer.scala @@ -183,17 +183,17 @@ class Analyzer(tableUtils: TableUtils, val schema = if (groupByConf.isSetBackfillStartDate && groupByConf.hasDerivations) { // handle group by backfill mode for derivations // todo: add the similar logic to join derivations + val keyAndPartitionFields = + groupBy.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)) val sparkSchema = { - StructType( - SparkConversions.fromChrononSchema(groupBy.outputSchema).fields ++ groupBy.keySchema.fields ++ Seq( - org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType))) + StructType(SparkConversions.fromChrononSchema(groupBy.outputSchema).fields ++ keyAndPartitionFields) } val dummyOutputDf = tableUtils.sparkSession .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema) val finalOutputColumns = groupByConf.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*) val columns = SparkConversions.toChrononSchema( - StructType(derivedDummyOutputDf.schema.filterNot(groupBy.keySchema.fields.contains))) + StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains))) api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2))) } else { groupBy.outputSchema diff --git a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala index 83a60a798..02d71ee80 100644 --- a/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala +++ b/spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala @@ -70,18 +70,18 @@ object BootstrapInfo { .toChrononSchema(gb.keySchema) .map(field => StructField(part.rightToLeft(field._1), field._2)) + val keyAndPartitionFields = + gb.keySchema.fields ++ Seq(org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType)) val outputSchema = if (part.groupBy.hasDerivations) { val sparkSchema = { - StructType( - SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ gb.keySchema.fields ++ Seq( - org.apache.spark.sql.types.StructField(tableUtils.partitionColumn, StringType))) + StructType(SparkConversions.fromChrononSchema(gb.outputSchema).fields ++ keyAndPartitionFields) } val dummyOutputDf = tableUtils.sparkSession .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema) val finalOutputColumns = part.groupBy.derivationsScala.finalOutputColumn(dummyOutputDf.columns).toSeq val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*) val columns = SparkConversions.toChrononSchema( - StructType(derivedDummyOutputDf.schema.filterNot(gb.keySchema.fields.contains))) + StructType(derivedDummyOutputDf.schema.filterNot(keyAndPartitionFields.contains))) api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2))) } else { gb.outputSchema diff --git a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala index 3d4f4afe0..b8230dcd3 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala @@ -29,20 +29,20 @@ class DerivationTest { val namespace = "test_derivations" spark.sql(s"CREATE DATABASE IF NOT EXISTS $namespace") val groupBy = BootstrapUtils.buildGroupBy(namespace, spark) - .setDerivations(Seq( - Builders.Derivation(name = "user_amount_30d_avg", - expression = "amount_dollars_sum_30d / 30"), - Builders.Derivation( - name = "*" - ), - ).toJava) + val groupByWithDerivation = groupBy.setDerivations(Seq( + Builders.Derivation(name = "user_amount_30d_avg", + expression = "amount_dollars_sum_30d / 30"), + Builders.Derivation( + name = "*" + ), + ).toJava) val queryTable = BootstrapUtils.buildQuery(namespace, spark) val baseJoin = Builders.Join( left = Builders.Source.events( table = queryTable, query = Builders.Query() ), - joinParts = Seq(Builders.JoinPart(groupBy = groupBy)), + joinParts = Seq(Builders.JoinPart(groupBy = groupByWithDerivation)), rowIds = Seq("request_id"), externalParts = Seq( Builders.ExternalPart( @@ -117,7 +117,7 @@ class DerivationTest { val outputDf = runner.computeJoin() assertTrue( - outputDf.columns sameElements Array( + outputDf.columns.toSet == Set( "user", "request_id", "ts",