diff --git a/README.md b/README.md
index e88fa62..cce26db 100644
--- a/README.md
+++ b/README.md
@@ -438,6 +438,45 @@ DeltaHelpers.deltaNumRecordDistribution(path)
 +------------------------------------------------+--------------------+------------------+--------------------+--------------------+------------------+---------------------------------------------------------+
 ```
+## Number of Shuffle Files in Merge & Other Filter Conditions
+
+The function `getNumShuffleFiles` gets the number of shuffle files (think of part files in parquet) that will be pulled into memory for a given filter condition. This is particularly useful for estimating the memory requirements of a Delta Merge operation, where the number of shuffle files can be a bottleneck.
+To better tune your jobs, you can use this function to get the number of shuffle files for different kinds of filter conditions and then perform operations like merge, zorder, compaction etc. to see if you reach the desired number of shuffle files.
+
+
+For example, if the condition is "country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%' " and country is the partition column,
+```scala
+DeltaHelpers.getNumShuffleFiles(path, "country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%' ")
+```
+
+then the output might look like the following (each key in the `Map` describes a part of the condition and its value is the corresponding file count)
+```scala
+Map(
+  // number of files that will be pulled into memory for the entire provided condition
+  "OVERALL RESOLVED CONDITION => [ (country = 'GBR') and (age >= 30) and" +
+    " (age <= 40) and firstname LIKE '%Jo%' ]" -> 18,
+  // number of files signifying the greater than/less than part => "age >= 30 and age <= 40"
+  "GREATER THAN / LESS THAN PART => [ (age >= 30) and (age <= 40) ]" -> 100,
+  // number of files signifying the equals part => "country = 'GBR'"
+  "EQUALS/EQUALS NULL SAFE PART => [ (country = 'GBR') ]" -> 300,
+  // number of files signifying the like (or any other) part => "firstname like '%Jo%' "
+  "LEFT OVER PART => [ firstname LIKE '%Jo%' ]" -> 600,
+  // number of files signifying any other part. This is mostly a failsafe
+  // 1. to capture any other condition that might have been missed
+  // 2. if wrong attribute names or conditions are provided, like snapshot.id = update.id (usually found in merge conditions)
+  "UNRESOLVED PART => [ Empty ]" -> 800,
+  // Total number of files in the Delta Table
+  "TOTAL_NUM_FILES_IN_DELTA_TABLE =>" -> 800,
+  // List of unresolved columns/attributes in the provided condition.
+  // Will be empty if all columns are resolved.
+  "UNRESOLVED_COLUMNS =>" -> List())
+```
+
+Another important use case this method can help with is checking min-max range overlap. Adding a min-max filter on a high-cardinality column like id, say `id >= 900 and id <= 5000`, can actually help reduce the number of shuffle files Delta Lake pulls into memory. However, such an operation is not always guaranteed to help, and you can verify its effect by running this method.
+
+This function works only on the Delta Log and does not scan any data in the Delta Table.
+
+If you want more information about these individual files and their metadata, consider using the `getShuffleFileMetadata` function.
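+
+To see which specific files fall into each of these buckets (rather than just the counts), a minimal sketch using `getShuffleFileMetadata` could look like the following; the value names (`overall`, `minMax`, etc.) are illustrative only:
+```scala
+val (overall, minMax, equals, leftOver, unresolved, allFiles, unresolvedCols) =
+  DeltaHelpers.getShuffleFileMetadata(path, "country = 'GBR' and age >= 30 and age <= 40")
+
+// Each of the first five values is a Seq[AddFile]; for example, list the files
+// that the overall condition would pull into memory along with their sizes.
+overall.foreach(file => println(s"${file.path} -> ${file.size} bytes"))
+```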
 
 ## Change Data Feed Helpers
 
 ### CASE I - When Delta aka Transaction Log gets purged
diff --git a/src/main/scala/mrpowers/jodie/DeltaHelpers.scala b/src/main/scala/mrpowers/jodie/DeltaHelpers.scala
index 2afc5a6..a26d362 100644
--- a/src/main/scala/mrpowers/jodie/DeltaHelpers.scala
+++ b/src/main/scala/mrpowers/jodie/DeltaHelpers.scala
@@ -2,10 +2,13 @@ package mrpowers.jodie
 
 import io.delta.tables._
 import mrpowers.jodie.delta.DeltaConstants._
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.delta.DeltaLog
+import org.apache.spark.sql.delta.actions.AddFile
 import org.apache.spark.sql.expressions.Window.partitionBy
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
 
 import scala.collection.mutable
 
@@ -71,6 +74,125 @@ object DeltaHelpers {
   def deltaNumRecordDistribution(path: String, condition: Option[String] = None): DataFrame =
     getAllPartitionStats(deltaFileStats(path, condition), statsPartitionColumn, numRecordsColumn)
       .toDF(numRecordsDFColumns: _*)
 
+  /**
+   * Gets the number of shuffle files (part files for parquet) that will be pulled into memory for a given filter condition.
+   * This is particularly useful in a Delta Merge operation, where the number of shuffle files can be a bottleneck. Running
+   * the merge condition through this method can give an idea of the amount of memory required to run the merge.
+   *
+   * For example, if the condition is "snapshot.id = update.id and country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%' "
+   * and country is the partition column, then the output might look like =>
+   * Map(
+   *   OVERALL RESOLVED CONDITION => [ (country = 'GBR') and (age >= 30) and (age <= 40) and firstname LIKE '%Jo%' ] -> 18,
+   *   GREATER THAN / LESS THAN PART => [ (age >= 30) and (age <= 40) ] -> 100,
+   *   EQUALS/EQUALS NULL SAFE PART => [ (country = 'GBR') ] -> 300,
+   *   LEFT OVER PART => [ firstname LIKE '%Jo%' ] -> 600,
+   *   UNRESOLVED PART => [ (snapshot.id = update.id) ] -> 800,
+   *   TOTAL_NUM_FILES_IN_DELTA_TABLE => -> 800,
+   *   UNRESOLVED_COLUMNS => -> List(snapshot.id, update.id))
+   *
+   * 18 - number of files that will be pulled into memory for the entire provided condition
+   * 100 - number of files signifying the greater than/less than part => "age >= 30 and age <= 40"
+   * 300 - number of files signifying the equals part => "country = 'GBR'"
+   * 600 - number of files signifying the like (or any other) part => "firstname like '%Jo%' "
+   * 800 - number of files signifying any other part. This is mostly a failsafe
+   *       1. to capture any other condition that might have been missed
+   *       2. if wrong attribute names or conditions are provided, like snapshot.id = update.id (usually found in merge conditions)
+   * 800 - total number of files in the Delta Table without any filter condition or partitions
+   * List(snapshot.id, update.id) - list of unresolved columns/attributes in the provided condition
+   * Note: whenever a resolved condition comes back as Empty, the output will contain the number of files in the entire Delta Table and can be ignored.
+   * This function works only on the Delta Log and does not scan any data in the Delta Table.
+   *
+   * @param path      path of the Delta table
+   * @param condition the filter or merge condition to evaluate against the Delta log
+   * @return a Map from each part of the condition to the number of files that part would pull into memory
+   */
+  def getNumShuffleFiles(path: String, condition: String) = {
+    val (deltaLog, unresolvedColumns, targetOnlyPredicates, minMaxOnlyExpressions, equalOnlyExpressions,
+      otherExpressions, removedPredicates) = getResolvedExpressions(path, condition)
+    deltaLog.withNewTransaction { deltaTxn =>
+      Map(s"$OVERALL [ ${formatSQL(targetOnlyPredicates).getOrElse("Empty")} ]" ->
+        deltaTxn.filterFiles(targetOnlyPredicates).count(a => true),
+        s"$MIN_MAX [ ${formatSQL(minMaxOnlyExpressions).getOrElse("Empty")} ]" ->
+          deltaTxn.filterFiles(minMaxOnlyExpressions).count(a => true),
+        s"$EQUALS [ ${formatSQL(equalOnlyExpressions).getOrElse("Empty")} ]" ->
+          deltaTxn.filterFiles(equalOnlyExpressions).count(a => true),
+        s"$LEFT_OVER [ ${formatSQL(otherExpressions).getOrElse("Empty")} ]" ->
+          deltaTxn.filterFiles(otherExpressions).count(a => true),
+        s"$UNRESOLVED [ ${formatSQL(removedPredicates).getOrElse("Empty")} ]" ->
+          deltaTxn.filterFiles(removedPredicates).count(a => true),
+        TOTAL_NUM_FILES -> deltaLog.snapshot.filesWithStatsForScan(Nil).count(),
+        UNRESOLVED_COLS -> unresolvedColumns)
+    }
+  }
+
+  def getShuffleFileMetadata(path: String, condition: String):
+      (Seq[AddFile], Seq[AddFile], Seq[AddFile], Seq[AddFile], Seq[AddFile], DataFrame, Seq[String]) = {
+    val (deltaLog, unresolvedColumns, targetOnlyPredicates, minMaxOnlyExpressions, equalOnlyExpressions, otherExpressions, removedPredicates) = getResolvedExpressions(path, condition)
+    deltaLog.withNewTransaction { deltaTxn =>
+      (deltaTxn.filterFiles(targetOnlyPredicates),
+        deltaTxn.filterFiles(minMaxOnlyExpressions),
+        deltaTxn.filterFiles(equalOnlyExpressions),
+        deltaTxn.filterFiles(otherExpressions),
+        deltaTxn.filterFiles(removedPredicates),
+        deltaLog.snapshot.filesWithStatsForScan(Nil),
+        unresolvedColumns)
+    }
+  }
+  private def getResolvedExpressions(path: String, condition: String) = {
+    val spark = SparkSession.active
+    val deltaTable = DeltaTable.forPath(path)
+    val deltaLog = DeltaLog.forTable(spark, path)
+
+    val expression = functions.expr(condition).expr
+    val targetPlan = deltaTable.toDF.queryExecution.analyzed
+    val resolvedExpression: Expression = spark.sessionState.analyzer.resolveExpressionByPlanOutput(expression, targetPlan, true)
+    val unresolvedColumns = if (!resolvedExpression.childrenResolved) {
+      resolvedExpression.references.filter(a => a match {
+        case b: UnresolvedAttribute => true
+        case _ => false
+      }).map(a => a.asInstanceOf[UnresolvedAttribute].sql).toSeq
+    } else Seq()
+
+    def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
+      condition match {
+        case And(cond1, cond2) =>
+          splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
+        case other => other :: Nil
+      }
+    }
+
+    val splitExpressions = splitConjunctivePredicates(resolvedExpression)
+    val targetOnlyPredicates = splitExpressions.filter(_.references.subsetOf(targetPlan.outputSet))
+
+    val minMaxOnlyExpressions = targetOnlyPredicates.filter(e => e match {
+      case GreaterThanOrEqual(_, _) => true
+      case LessThanOrEqual(_, _) => true
+      case LessThan(_, _) => true
+      case GreaterThan(_, _) => true
+      case _ => false
+    })
+
+    val equalOnlyExpressions = targetOnlyPredicates.filter(e => e match {
+      case EqualTo(_, _) => true
+      case EqualNullSafe(_, _) => true
+      case _ => false
+    })
+
+    val otherExpressions = targetOnlyPredicates.filter(e => e match {
+      case EqualTo(_, _) => false
+      case EqualNullSafe(_, _) => false
+      case GreaterThanOrEqual(_, _) => false
+      case LessThanOrEqual(_, _) => false
+      case LessThan(_, _) => false
+      case GreaterThan(_, _) => false
+      case _ => true
+    })
+
+    val removedPredicates = splitExpressions.filterNot(_.references.subsetOf(targetPlan.outputSet))
+
+    (deltaLog, unresolvedColumns, targetOnlyPredicates, minMaxOnlyExpressions, equalOnlyExpressions, otherExpressions, removedPredicates)
+  }
+
   private def getAllPartitionStats(filteredDF: DataFrame, groupByCol: String, aggCol: String) = filteredDF
     .groupBy(map_entries(col(groupByCol)))
@@ -92,7 +214,7 @@ object DeltaHelpers {
     val snapshot = tableLog.snapshot
     condition match {
       case None => snapshot.filesWithStatsForScan(Nil)
-      case Some(value) => snapshot.filesWithStatsForScan(Seq(expr(value).expr))
+      case Some(value) => snapshot.filesWithStatsForScan(Seq(functions.expr(value).expr))
     }
   }
 
diff --git a/src/main/scala/mrpowers/jodie/delta/DeltaConstants.scala b/src/main/scala/mrpowers/jodie/delta/DeltaConstants.scala
index 110e27d..5265cf5 100644
--- a/src/main/scala/mrpowers/jodie/delta/DeltaConstants.scala
+++ b/src/main/scala/mrpowers/jodie/delta/DeltaConstants.scala
@@ -1,5 +1,7 @@
 package mrpowers.jodie.delta
 
+import org.apache.spark.sql.catalyst.expressions.Expression
+
 object DeltaConstants {
   val sizeColumn = "size"
   val numRecordsColumn = "stats.numRecords"
@@ -27,4 +29,16 @@ object DeltaConstants {
     "max_file_size",
     percentileCol
   )
+  val OVERALL = "OVERALL RESOLVED CONDITION =>"
+  val MIN_MAX = "GREATER THAN / LESS THAN PART =>"
+  val EQUALS = "EQUALS/EQUALS NULL SAFE PART =>"
+  val LEFT_OVER = "LEFT OVER PART =>"
+  val UNRESOLVED = "UNRESOLVED PART =>"
+  val TOTAL_NUM_FILES = "TOTAL_NUM_FILES_IN_DELTA_TABLE =>"
+  val UNRESOLVED_COLS = "UNRESOLVED_COLUMNS =>"
+
+  def formatSQL(expressions: Seq[Expression]) = expressions.isEmpty match {
+    case true => None
+    case false => Some(expressions.map(a => a.sql).reduce(_ + " and " + _))
+  }
 }
diff --git a/src/test/scala/mrpowers/jodie/DeltaHelperSpec.scala b/src/test/scala/mrpowers/jodie/DeltaHelperSpec.scala
index 7085235..3b95674 100644
--- a/src/test/scala/mrpowers/jodie/DeltaHelperSpec.scala
+++ b/src/test/scala/mrpowers/jodie/DeltaHelperSpec.scala
@@ -3,6 +3,7 @@ package mrpowers.jodie
 import com.github.mrpowers.spark.daria.sql.SparkSessionExt.SparkSessionMethods
 import com.github.mrpowers.spark.fast.tests.DataFrameComparer
 import io.delta.tables.DeltaTable
+import mrpowers.jodie.delta.DeltaConstants._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.{IntegerType, StringType}
 import org.scalatest.BeforeAndAfterEach
@@ -27,10 +28,10 @@ class DeltaHelperSpec
 
   describe("When Delta table is queried for file sizes") {
     it("should provide delta file sizes successfully") {
      val path = (os.pwd / "tmp" / "delta-table").toString()
-      createBaseDeltaTable(path)
+      createBaseDeltaTable(path, rows)
       val deltaTable = DeltaTable.forPath(path)
-      val actual = DeltaHelpers.deltaFileSizes(deltaTable)
+      val actual = DeltaHelpers.deltaFileSizes(deltaTable)
 
       actual("size_in_bytes") should equal(1088L)
       actual("number_of_files") should equal(1L)
@@ -38,7 +39,8 @@ class DeltaHelperSpec
     }
 
     it("should not fail if the table is empty") {
-      val emptyDeltaTable = DeltaTable.create(spark)
+      val emptyDeltaTable = DeltaTable
+        .create(spark)
         .tableName("delta_empty_table")
         .addColumn("id", dataType = "INT")
         .addColumn("firstname", dataType = "STRING")
@@ -53,7 +55,7 @@ class DeltaHelperSpec
   describe("remove duplicate records from delta table") {
     it("should remove duplicates successful") {
       val path = (os.pwd / "tmp" /
"delta-duplicate").toString() - createBaseDeltaTable(path) + createBaseDeltaTable(path, rows) val deltaTable = DeltaTable.forPath(path) val duplicateColumns = Seq("firstname", "lastname") @@ -150,7 +152,7 @@ class DeltaHelperSpec describe("remove duplicate records from delta table using primary key") { it("should remove duplicates given a primary key and duplicate columns") { val path = (os.pwd / "tmp" / "delta-duplicate-pk").toString() - createBaseDeltaTable(path) + createBaseDeltaTable(path, rows) val deltaTable = DeltaTable.forPath(path) val duplicateColumns = Seq("lastname") @@ -173,7 +175,7 @@ class DeltaHelperSpec it("should fail to remove duplicates when not duplicate columns is provided") { val path = (os.pwd / "tmp" / "delta-pk-not-duplicate-columns").toString() - createBaseDeltaTable(path) + createBaseDeltaTable(path, rows) val deltaTable = DeltaTable.forPath(path) val primaryKey = "id" @@ -220,7 +222,7 @@ class DeltaHelperSpec it("should fail to remove duplicate when not primary key is provided") { val path = (os.pwd / "tmp" / "delta-duplicate-no-pk").toString() - createBaseDeltaTable(path) + createBaseDeltaTable(path, rows) val deltaTable = DeltaTable.forPath(path) val primaryKey = "" @@ -233,7 +235,7 @@ class DeltaHelperSpec it("should fail to remove duplicate when duplicateColumns does not exist in table") { val path = (os.pwd / "tmp" / "delta-duplicate-cols-no-exists").toString() - createBaseDeltaTable(path) + createBaseDeltaTable(path, rows) val deltaTable = DeltaTable.forPath(path) val primaryKey = "id" @@ -379,7 +381,7 @@ class DeltaHelperSpec it("should create a new delta table from an existing one using path") { val path = (os.pwd / "tmp" / "delta-copy-from-existing-path").toString() - val df = createBaseDeltaTableWithPartitions(path, Seq("lastname", "firstname")) + val df = createBaseDeltaTableWithPartitions(path, Seq("lastname", "firstname"), partRows) val deltaTable = DeltaTable.forPath(path) val targetPath = (os.pwd / "tmp" / "delta-copy-from-existing-target-path").toString() DeltaHelpers.copyTable(deltaTable, targetPath = Some(targetPath)) @@ -395,9 +397,9 @@ class DeltaHelperSpec it("should copy table from existing one using table name") { val path = (os.pwd / "tmp" / "delta-copy-from-existing-tb-name").toString() - val df: DataFrame = createBaseDeltaTableWithPartitions(path,Seq("lastname")) - val deltaTable = DeltaTable.forPath(path) - val tableName = "students" + val df: DataFrame = createBaseDeltaTableWithPartitions(path, Seq("lastname"), partRows) + val deltaTable = DeltaTable.forPath(path) + val tableName = "students" DeltaHelpers.copyTable(deltaTable, targetTableName = Some(tableName)) assertSmallDataFrameEquality( DeltaTable.forName(spark, tableName).toDF, @@ -410,8 +412,8 @@ class DeltaHelperSpec it("should fail to copy when no table name or target path is set") { val path = (os.pwd / "tmp" / "delta-copy-non-destination").toString() - val df: DataFrame = createBaseDeltaTableWithPartitions(path,Seq("lastname")) - val deltaTable = DeltaTable.forPath(path) + val df: DataFrame = createBaseDeltaTableWithPartitions(path, Seq("lastname"), partRows) + val deltaTable = DeltaTable.forPath(path) val exceptionMessage = intercept[JodieValidationError] { DeltaHelpers.copyTable(deltaTable) }.getMessage @@ -420,11 +422,11 @@ class DeltaHelperSpec } it("should fail to copy when both table name and target path are set") { - val path = (os.pwd / "tmp" / "delta-copy-two-destination").toString() - val df: DataFrame = createBaseDeltaTableWithPartitions(path,Seq("lastname")) - 
val deltaTable = DeltaTable.forPath(path) - val tableName = "students" - val tablePath = (os.pwd / "tmp" / "delta-copy-from-existing-target-path").toString() + val path = (os.pwd / "tmp" / "delta-copy-two-destination").toString() + val df: DataFrame = createBaseDeltaTableWithPartitions(path, Seq("lastname"), partRows) + val deltaTable = DeltaTable.forPath(path) + val tableName = "students" + val tablePath = (os.pwd / "tmp" / "delta-copy-from-existing-target-path").toString() val exceptionMessage = intercept[JodieValidationError] { DeltaHelpers.copyTable(deltaTable, Some(tablePath), Some(tableName)) }.getMessage @@ -552,7 +554,7 @@ class DeltaHelperSpec .save(path) val deltaTable = DeltaTable.forPath(path) - val result = DeltaHelpers.findCompositeKeyCandidate(deltaTable, Seq("id")) + val result = DeltaHelpers.findCompositeKeyCandidate(deltaTable, Seq("id")) assertResult(Nil)(result) } @@ -571,8 +573,8 @@ class DeltaHelperSpec .mode("overwrite") .save(path) val deltaTable = DeltaTable.forPath(path) - val result = DeltaHelpers.findCompositeKeyCandidate(deltaTable, Seq("id")) - val expected = Seq("firstname", "lastname") + val result = DeltaHelpers.findCompositeKeyCandidate(deltaTable, Seq("id")) + val expected = Seq("firstname", "lastname") assertResult(expected)(result) } @@ -591,8 +593,8 @@ class DeltaHelperSpec .option("delta.logRetentionDuration", "interval 30 days") .save(path) val deltaTable = DeltaTable.forPath(path) - val result = DeltaHelpers.findCompositeKeyCandidate(deltaTable) - val expected = Seq("id") + val result = DeltaHelpers.findCompositeKeyCandidate(deltaTable) + val expected = Seq("id") assertResult(expected)(result) } } @@ -644,67 +646,297 @@ class DeltaHelperSpec (3, "Jose", "Travolta", "1f1ac7f74f43eff911a92f7e28069271") ).toDF("id", "firstname", "lastname", "unique_id") - assertSmallDataFrameEquality( actualDF = resultDF, expectedDF = expectedDF, ignoreNullable = true, - orderedComparison = false) + orderedComparison = false + ) } } describe("Generate metrics for optimize functions on Delta Table") { it("should return valid file sizes and num records for non partitioned tables") { val path = (os.pwd / "tmp" / "delta-table-non-partitioned").toString() - createBaseDeltaTable(path) - val fileSizeDF = DeltaHelpers.deltaFileSizeDistribution(path) + createBaseDeltaTable(path, rows) + val fileSizeDF = DeltaHelpers.deltaFileSizeDistribution(path) val numRecordsDF = DeltaHelpers.deltaNumRecordDistribution(path) - fileSizeDF.count() should equal(1l) - assertDistributionCount(fileSizeDF, (0, 1l, 1088.0, null, 1088l, 1088l, Array(1088, 1088, 1088, 1088, 1088, 1088))) - numRecordsDF.count() should equal(1l) - assertDistributionCount(numRecordsDF, (0, 1l, 7.0, null, 7l, 7l, Array(7, 7, 7, 7, 7, 7))) + fileSizeDF.count() should equal(1L) + assertDistributionCount( + fileSizeDF, + (0, 1L, 1088.0, null, 1088L, 1088L, Array(1088, 1088, 1088, 1088, 1088, 1088)) + ) + numRecordsDF.count() should equal(1L) + assertDistributionCount(numRecordsDF, (0, 1L, 7.0, null, 7L, 7L, Array(7, 7, 7, 7, 7, 7))) } it("should return valid file sizes and num records for single partitioned tables") { val path = (os.pwd / "tmp" / "delta-table-single-partition").toString() - createBaseDeltaTableWithPartitions(path, Seq("lastname")) - val fileSizeDF = DeltaHelpers.deltaFileSizeDistribution(path, Some("lastname='Travolta'")) + createBaseDeltaTableWithPartitions(path, Seq("lastname"), partRows) + val fileSizeDF = DeltaHelpers.deltaFileSizeDistribution(path, Some("lastname='Travolta'")) val numRecordsDF = 
DeltaHelpers.deltaNumRecordDistribution(path, Some("lastname='Travolta'")) - fileSizeDF.count() should equal(1l) - assertDistributionCount(fileSizeDF, (1, 1l, 756.0, null, 756, 756, Array(756, 756, 756, 756, 756, 756))) - numRecordsDF.count() should equal(1l) - assertDistributionCount(numRecordsDF, (1, 1l, 3.0, null, 3, 3, Array(3, 3, 3, 3, 3, 3))) + fileSizeDF.count() should equal(1L) + assertDistributionCount( + fileSizeDF, + (1, 1L, 756.0, null, 756, 756, Array(756, 756, 756, 756, 756, 756)) + ) + numRecordsDF.count() should equal(1L) + assertDistributionCount(numRecordsDF, (1, 1L, 3.0, null, 3, 3, Array(3, 3, 3, 3, 3, 3))) } it("should return valid file sizes and num records for multiple partitioned tables") { val path = (os.pwd / "tmp" / "delta-table-multi-partition").toString() - createBaseDeltaTableWithPartitions(path, Seq("lastname", "firstname")) - val fileSizeDF = DeltaHelpers.deltaFileSizeDistribution(path, Some("lastname='Travolta' and firstname='Jose'")) - val numRecordsDF = DeltaHelpers.deltaNumRecordDistribution(path, Some("lastname='Travolta' and firstname='Jose'")) - fileSizeDF.count() should equal(1l) - assertDistributionCount(fileSizeDF, (2, 1l, 456.0, null, 456, 456, Array(456, 456, 456, 456, 456, 456))) - numRecordsDF.count() should equal(1l) - assertDistributionCount(numRecordsDF, (2, 1l, 2.0, null, 2, 2, Array(2, 2, 2, 2, 2, 2))) + createBaseDeltaTableWithPartitions(path, Seq("lastname", "firstname"), partRows) + val fileSizeDF = DeltaHelpers.deltaFileSizeDistribution( + path, + Some("lastname='Travolta' and firstname='Jose'") + ) + val numRecordsDF = DeltaHelpers.deltaNumRecordDistribution( + path, + Some("lastname='Travolta' and firstname='Jose'") + ) + fileSizeDF.count() should equal(1L) + assertDistributionCount( + fileSizeDF, + (2, 1L, 456.0, null, 456, 456, Array(456, 456, 456, 456, 456, 456)) + ) + numRecordsDF.count() should equal(1L) + assertDistributionCount(numRecordsDF, (2, 1L, 2.0, null, 2, 2, Array(2, 2, 2, 2, 2, 2))) } - it("should return valid file sizes in megabytes"){ + it("should return valid file sizes in megabytes") { val path = (os.pwd / "tmp" / "delta-table-multi-files").toString() - def getDF(partition:String) = { - (1 to 10000).toDF("id") + + def getDF(partition: String) = { + (1 to 10000) + .toDF("id") .collect() .map(_.getInt(0)) .map(id => (id, partition, id + 10)) .toSeq } + (getDF("dog") ++ getDF("cat") ++ getDF("bird")) - .toDF("id", "animal", "age").write.mode("overwrite") - .format("delta").partitionBy("animal").save(path) + .toDF("id", "animal", "age") + .write + .mode("overwrite") + .format("delta") + .partitionBy("animal") + .save(path) val fileSizeDF = DeltaHelpers.deltaFileSizeDistributionInMB(path) - val size = 0.07698249816894531 + val size = 0.07698249816894531 fileSizeDF.count() should equal(3) - assertDistributionCount(fileSizeDF, (1, 1l, size, null, size, size, Array(size, size, size, size, size, size))) + assertDistributionCount( + fileSizeDF, + (1, 1L, size, null, size, size, Array(size, size, size, size, size, size)) + ) } + + describe("Generate file overlap metrics on running filter queries") { + it("should return valid metrics for partitioned tables") { + val path = (os.pwd / "tmp" / "delta-table-min-max-part").toString() + spark.conf.set("spark.sql.files.maxRecordsPerFile", "4") + createBaseDeltaTableWithPartitions(path, Seq("lastname"), minMaxRows) + + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "lastname = 'Travolta'"), + (3, 7, 3, 7, 7, 7, List()) + ) + // Min Max Query + assertNumFilesCount( + 
DeltaHelpers.getNumShuffleFiles( + path, + "lastname = 'Travolta' " + + "and id >= 10 and id <= 12" + ), + (2, 5, 3, 7, 7, 7, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "id <= 10 "), + (4, 4, 7, 7, 7, 7, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "id >= 12"), + (5, 5, 7, 7, 7, 7, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles( + path, + "snapshot.id = update.id and " + + "lastname = 'Travolta' and id >= 10 and id <= 12 and firstname like '%Joh%'" + ), + (2, 5, 3, 7, 7, 7, List("snapshot.id", "update.id")) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "lastname = 'Travolta' and id <= 10 and id >= 12"), + (0, 2, 3, 7, 7, 7, List()) + ) + } + it("should return valid metrics for non partitioned tables") { + val path = (os.pwd / "tmp" / "delta-table-min-max-no-part").toString() + spark.conf.set("spark.sql.files.maxRecordsPerFile", "4") + createBaseDeltaTable(path, minMaxRows) + + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "id <= 10 "), + (3, 3, 6, 6, 6, 6, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "id >= 12"), + (4, 4, 6, 6, 6, 6, List()) + ) + // Min Max Query + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles(path, "id >= 10 and id <= 12"), + (1, 1, 6, 6, 6, 6, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles( + path, + "id >= 10 and id <= 12 " + + "and firstname like '%Joh%'" + ), + (1, 1, 6, 6, 6, 6, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles( + path, + "lastname = 'Travolta' " + + "and id <= 10 and id >= 12" + ), + (1, 1, 6, 6, 6, 6, List()) + ) + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles( + path, + "id >= 10 and id <= 12 " + + "and (firstname = 'John' or firstname = 'Maria')" + ), + (1, 1, 6, 6, 6, 6, List()) + ) + } + it("should return valid metrics even when unresolved column names are present in the query") { + val path = (os.pwd / "tmp" / "delta-table-min-max-with-part").toString() + spark.conf.set("spark.sql.files.maxRecordsPerFile", "4") + createBaseDeltaTableWithPartitions(path, Seq("lastname"), minMaxRows) + + assertNumFilesCount( + DeltaHelpers.getNumShuffleFiles( + path, + "snapshot.id = update.id and " + + "lastname = 'Travolta' and id >= 10 and id <= 12" + ), + (2, 5, 3, 7, 7, 7, List("snapshot.id", "update.id")) + ) + } + spark.conf.set("spark.sql.files.maxRecordsPerFile", 0) + } + it("when rows are not in any order"){ + val path = (os.pwd / "tmp" / "delta-table-min-max-without-order").toString() + spark.conf.set("spark.sql.files.maxRecordsPerFile", "4") + createBaseDeltaTable(path, noOrderRows) + + val actualBefore = DeltaHelpers.getNumShuffleFiles( + path, + "snapshot.id = update.id and " + + " id >= 10 and id <= 12 " + ) + DeltaTable.forPath(spark, path).optimize().executeZOrderBy("id") + val actualAfter = DeltaHelpers.getNumShuffleFiles( + path, + "snapshot.id = update.id and " + + " id >= 10 and id <= 12 " + ) + } + } + + val rows = Seq( + (1, "Benito", "Jackson"), + (2, "Maria", "Willis"), + (3, "Jose", "Travolta"), + (4, "Benito", "Jackson"), + (5, "Jose", "Travolta"), + (6, "Jose", "Travolta"), + (7, "Maria", "Pitt") + ) + + val partRows = Seq( + (1, "Benito", "Jackson"), + (2, "Maria", "Willis"), + (3, "Jose", "Travolta"), + (4, "Patricia", "Jackson"), + (5, "Jose", "Travolta"), + (6, "Gabriela", "Travolta"), + (7, "Maria", "Pitt") + ) + val minMaxRows: Seq[(Int, String, String)] = rows ++ Seq( + (8, "Benito", "Jackson"), + (9, "Maria", "Willis"), 
+ (10, "Jose", "Travolta"), + (11, "Benito", "Jackson"), + (12, "Jose", "Travolta"), + (13, "Jose", "Travolta"), + (14, "Maria", "Pitt"), + (15, "Jose", "Travolta"), + (16, "Jose", "Travolta"), + (17, "Maria", "Pitt"), + (18, "Benito", "Jackson"), + (19, "Maria", "Willis"), + (20, "Jose", "Travolta"), + (21, "Benito", "Jackson"), + (22, "Jose", "Travolta"), + (23, "Jose", "Travolta"), + (24, "Maria", "Pitt") + ) + + val noOrderRows = Seq( + (11, "Benito", "Jackson"), + (1, "Benito", "Jackson"), + (21, "Benito", "Jackson"), + (3, "Jose", "Travolta"), + (18, "Benito", "Jackson"), + (14, "Maria", "Pitt"), + (5, "Jose", "Travolta"), + (23, "Jose", "Travolta"), + (7, "Maria", "Pitt"), + (8, "Benito", "Jackson"), + (9, "Maria", "Willis"), + (10, "Jose", "Travolta"), + (17, "Maria", "Pitt"), + (12, "Jose", "Travolta"), + (16, "Jose", "Travolta"), + (4, "Patricia", "Jackson"), + (6, "Gabriela", "Travolta"), + (19, "Maria", "Willis"), + (20, "Jose", "Travolta"), + (15, "Jose", "Travolta"), + (22, "Jose", "Travolta"), + (13, "Jose", "Travolta"), + (24, "Maria", "Pitt"), + (2, "Maria", "Willis") + ) + + private def createBaseDeltaTable(path: String, data: Seq[(Int, String, String)]): Unit = { + data.toDF("id", "firstname", "lastname").write.format("delta").mode("overwrite").save(path) + } + + private def createBaseDeltaTableWithPartitions( + path: String, + partitionBy: Seq[String], + data: Seq[(Int, String, String)] + ) = { + val df = data.toDF("id", "firstname", "lastname") + df.write + .format("delta") + .mode("overwrite") + .partitionBy(partitionBy: _*) + .option("delta.logRetentionDuration", "interval 30 days") + .save(path) + df } - private def assertDistributionCount(df: DataFrame, expected: (Int, Long, Double, Any, Any, Any, Array[Double])) = { + private def assertDistributionCount( + df: DataFrame, + expected: (Int, Long, Double, Any, Any, Any, Array[Double]) + ) = { val actual = df.take(1)(0) actual.getAs[mutable.WrappedArray[(String, String)]](0).length should equal(expected._1) actual.getAs[Long](1) should equal(expected._2) @@ -715,34 +947,28 @@ class DeltaHelperSpec actual.getAs[Array[Double]](6) should equal(expected._7) } - private def createBaseDeltaTable(path: String): Unit = { - val df = Seq( - (1, "Benito", "Jackson"), - (2, "Maria", "Willis"), - (3, "Jose", "Travolta"), - (4, "Benito", "Jackson"), - (5, "Jose", "Travolta"), - (6, "Jose", "Travolta"), - (7, "Maria", "Pitt") - ).toDF("id", "firstname", "lastname") - df.write.format("delta").mode("overwrite").save(path) - } - private def createBaseDeltaTableWithPartitions(path: String, partitionBy: Seq[String]) = { - val df = Seq( - (1, "Benito", "Jackson"), - (2, "Maria", "Willis"), - (3, "Jose", "Travolta"), - (4, "Patricia", "Jackson"), - (5, "Jose", "Travolta"), - (6, "Gabriela", "Travolta"), - (7, "Maria", "Pitt") - ).toDF("id", "firstname", "lastname") - df.write - .format("delta") - .mode("overwrite") - .partitionBy(partitionBy: _*) - .option("delta.logRetentionDuration", "interval 30 days") - .save(path) - df + private def assertNumFilesCount( + actual: Map[String, Any], + expected: Tuple7[Int, Int, Int, Int, Int, Long, Seq[String]] + ) = { + actual.map { a => + if (a._1.contains(OVERALL)) { + a._2 should equal(expected._1) + } else if (a._1.contains(MIN_MAX)) { + a._2 should equal(expected._2) + } else if (a._1.contains(EQUALS)) { + a._2 should equal(expected._3) + } else if (a._1.contains(LEFT_OVER)) { + a._2 should equal(expected._4) + } else if (a._1.contains(UNRESOLVED)) { + a._2 should equal(expected._5) + } else if 
(a._1.contains(TOTAL_NUM_FILES)) { + a._2 should equal(expected._6) + } else if (a._1.contains(UNRESOLVED_COLS)) { + a._2 should equal(expected._7) + } else { + fail("Unexpected key") + } + } } }
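
The last test above captures `actualBefore` and `actualAfter` around a Z-order but does not assert on them. Outside the test suite, a rough sketch of that before/after comparison (assuming an active `SparkSession` named `spark`, a Delta table at `path`, and a high-cardinality `id` column, as in the tests) might look like:

```scala
import io.delta.tables.DeltaTable
import mrpowers.jodie.DeltaHelpers

val condition = "id >= 10 and id <= 12" // a min-max filter on the high-cardinality column

// Count the shuffle files the condition would pull into memory before optimizing.
val before = DeltaHelpers.getNumShuffleFiles(path, condition)

// Z-order by the column used in the min-max filter, then measure again.
DeltaTable.forPath(spark, path).optimize().executeZOrderBy("id")
val after = DeltaHelpers.getNumShuffleFiles(path, condition)

// Comparing the OVERALL and GREATER THAN / LESS THAN entries shows whether
// the Z-order reduced the number of files matched by the same condition.
before.foreach { case (key, count) => println(s"before: $key -> $count") }
after.foreach { case (key, count) => println(s"after:  $key -> $count") }
```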