diff --git a/README.md b/README.md
index 92e4706..4b2128c 100644
--- a/README.md
+++ b/README.md
@@ -481,6 +481,73 @@ ChangeDataFeedHelper(writePath, 9, 13).dryRun().readCDF
 ```
 If no error found, it will return a similar Spark Dataframe with CDF between given versions.
+## Operation Metric Helpers
+
+### Count Metrics on Delta Table between 2 versions
+This function displays all count metrics stored in the Delta Log across versions for the entire Delta Table. It skips versions which do not record
+these count metrics and presents a unified view. It shows the growth of a Delta Table by providing the record counts -
+**deleted**, **updated** and **inserted** - against each **version**. For a **merge** operation, the source DataFrame can additionally be tallied
+against, since **source rows = (deleted + updated + inserted) rows**. Please note that you need enough Driver Memory
+to process the Delta Log at the driver level.
+```scala
+OperationMetricHelper(path, 0, 6).getCountMetricsAsDF()
+```
+The result will be the following:
+```scala
++-------+-------+--------+-------+-----------+
+|version|deleted|inserted|updated|source_rows|
++-------+-------+--------+-------+-----------+
+|6      |0      |108     |0      |108        |
+|5      |12     |0       |0      |0          |
+|4      |0      |0       |300    |300        |
+|3      |0      |100     |0      |100        |
+|2      |0      |150     |190    |340        |
+|1      |0      |0       |200    |200        |
+|0      |0      |400     |0      |400        |
++-------+-------+--------+-------+-----------+
+```
+### Count Metrics at partition level of Delta Table
+This function provides the same count metrics as the function above, but at a partition level. If operations
+like **MERGE, DELETE** and **UPDATE** are executed **at a partition level**, then this function can help in visualizing count
+metrics for such a partition. However, **it will not provide correct count metrics if these operations are performed
+across partitions**. This is because the Delta Log does not store this information at a log level, and hence it needs to be
+implemented separately (we intend to take this up in the future). Please note that you need enough Driver Memory
+to process the Delta Log at the driver level.
+```scala
+OperationMetricHelper(path).getCountMetricsAsDF(
+  Some(" country = 'USA' and gender = 'Female'"))
+
+// The same metrics can be obtained without using a Spark DataFrame
+def getCountMetrics(partitionCondition: Option[String] = None)
+    : Seq[(Long, Long, Long, Long, Long)]
+```
+The result will be the following:
+```scala
++-------+-------+--------+--------+-----------+
+|version|deleted|inserted| updated|source_rows|
++-------+-------+--------+--------+-----------+
+|     27|      0|       0|20635530|   20635524|
+|     14|      0|       0| 1429460|    1429460|
+|     13|      0|       0| 4670450|    4670450|
+|     12|      0|       0|20635530|   20635524|
+|     11|      0|       0| 5181821|    5181821|
+|     10|      0|       0| 1562046|    1562046|
+|      9|      0|       0| 1562046|    1562046|
+|      6|      0|       0|20635518|   20635512|
+|      3|      0|       0| 5181821|    5181821|
+|      0|      0|56287990|       0|   56287990|
++-------+-------+--------+--------+-----------+
+```
+Supported partition condition types:
+```scala
+// Single partition
+Some(" country = 'USA'")
+// Multiple partitions combined with AND. OR is not supported.
+Some(" country = 'USA' and gender = 'Female'")
+// Without single quotes
+Some(" country = USA and gender = Female")
+```
+
 ## How to contribute
 We welcome contributions to this project, to contribute checkout our [CONTRIBUTING.md](CONTRIBUTING.md) file.
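Editor's note on the README section above: it documents the `getCountMetrics` signature that returns a plain `Seq`, but never shows a consumer of it. The following is a minimal sketch (not part of this patch; the table path, object name and report logic are illustrative) of how the tuple-based variant might be used, assuming a local SparkSession:

```scala
import mrpowers.jodie.OperationMetricHelper
import org.apache.spark.sql.SparkSession

object CountMetricReport {
  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession = SparkSession
      .builder()
      .appName("count-metric-report")
      .master("local[*]")
      .getOrCreate()

    // Hypothetical Delta table path -- replace with a real table location.
    val path = "/tmp/delta/events"

    // Each tuple is (version, deleted, inserted, updated, source_rows).
    val metrics: Seq[(Long, Long, Long, Long, Long)] =
      OperationMetricHelper(path).getCountMetrics()

    // Net row growth contributed by each version: inserts minus deletes
    // (updates leave the row count unchanged).
    metrics
      .sortBy { case (version, _, _, _, _) => version }
      .foreach { case (version, deleted, inserted, _, _) =>
        println(s"version=$version netRows=${inserted - deleted}")
      }
  }
}
```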
diff --git a/src/main/scala/mrpowers/jodie/OperationMetricHelper.scala b/src/main/scala/mrpowers/jodie/OperationMetricHelper.scala
new file mode 100644
index 0000000..91c2cbb
--- /dev/null
+++ b/src/main/scala/mrpowers/jodie/OperationMetricHelper.scala
@@ -0,0 +1,262 @@
+package mrpowers.jodie
+
+import mrpowers.jodie.delta._
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.delta.util.FileNames
+import org.apache.spark.sql.delta.{DeltaHistory, DeltaLog}
+import org.apache.spark.sql.functions.{col, from_json}
+import org.apache.spark.sql.types.{LongType, MapType, StringType, StructType}
+import org.apache.spark.sql.{Encoders, SparkSession}
+
+case class OperationMetricHelper(
+    path: String,
+    startingVersion: Long = 0,
+    endingVersion: Option[Long] = None
+)(implicit spark: SparkSession) {
+  private val deltaLog = DeltaLog.forTable(spark, path)
+  private val metricColumns = Seq("version", "deleted", "inserted", "updated", "source_rows")
+
+  /**
+   * Returns the operation count metrics, either for a provided partition condition or, when no
+   * condition is given, for the entire Delta Table, as a Spark DataFrame.
+   * +-------+-------+--------+-------+-----------+
+   * |version|deleted|inserted|updated|source_rows|
+   * +-------+-------+--------+-------+-----------+
+   * |6      |0      |108     |0      |108        |
+   * |5      |12     |0       |0      |0          |
+   * |4      |0      |0       |300    |300        |
+   * |3      |0      |100     |0      |100        |
+   * |2      |0      |150     |190    |340        |
+   * |1      |0      |0       |200    |200        |
+   * |0      |0      |400     |0      |400        |
+   * +-------+-------+--------+-------+-----------+
+   *
+   * @param partitionCondition
+   * @return
+   *   [[org.apache.spark.sql.DataFrame]]
+   */
+  def getCountMetricsAsDF(partitionCondition: Option[String] = None) = {
+    import spark.implicits._
+    getCountMetrics(partitionCondition).toDF(metricColumns: _*)
+  }
+
+  /**
+   * Returns the operation count metrics, either for a provided partition condition or, when no
+   * condition is given, for the entire Delta Table, as a Seq[(Long, Long, Long, Long, Long)].
+   *
+   * @param partitionCondition
+   * @return
+   *   [[Seq]] of [[Tuple5]], where each element is a [[Long]]
+   */
+  def getCountMetrics(
+      partitionCondition: Option[String] = None
+  ): Seq[(Long, Long, Long, Long, Long)] = {
+    val histories = partitionCondition match {
+      case None => deltaLog.history.getHistory(startingVersion, endingVersion)
+      case Some(condition) =>
+        deltaLog.history
+          .getHistory(startingVersion, endingVersion)
+          .filter(x => filterHistoryByPartition(x, condition))
+    }
+    transformMetric(generateMetric(histories, partitionCondition))
+  }
+
+  /**
+   * Returns the inserted-record count for a WRITE/APPEND operation on a Delta Table partition
+   * for a given version. It inspects the Delta Log (transaction log) to obtain this metric.
+   * @param partitionCondition
+   * @param version
+   * @return
+   *   [[Long]]
+   */
+  def getWriteMetricByPartition(
+      partitionCondition: String,
+      version: Long
+  ): Long = {
+    val conditions = splitConditionTo(partitionCondition).map(x => s"${x._1}=${x._2}")
+    val jsonSchema = new StructType()
+      .add("numRecords", LongType)
+      .add("minValues", MapType(StringType, StringType))
+      .add("maxValues", MapType(StringType, StringType))
+      .add("nullCount", MapType(StringType, StringType))
+    spark.read
+      .json(FileNames.deltaFile(deltaLog.logPath, version).toString)
+      .withColumn("stats", from_json(col("add.stats"), jsonSchema))
+      .select("add.path", "stats")
+      .map(x => {
+        val path = x.getAs[String]("path")
+        conditions.map(x => path != null && path.contains(x)).reduceOption(_ && _) match {
+          case None => 0L
+          case Some(bool) =>
+            if (bool)
+              x.getAs[String]("stats").asInstanceOf[GenericRowWithSchema].getAs[Long]("numRecords")
+            else 0L
+        }
+      })(Encoders.scalaLong)
+      .reduce(_ + _)
+  }
+
+  /**
+   * Filters and maps the operations relevant for the count metrics: MERGE, WRITE, DELETE and
+   * UPDATE.
+   *
+   * @param metric
+   * @return
+   */
+  private def transformMetric(
+      metric: Seq[(Long, OperationMetrics)]
+  ): Seq[(Long, Long, Long, Long, Long)] = metric.flatMap { case (version, opMetric) =>
+    opMetric match {
+      case MergeMetric(_, deleted, _, _, inserted, _, updated, _, _, sourceRows, _, _) =>
+        Seq((version, deleted, inserted, updated, sourceRows))
+      case WriteMetric(_, inserted, _) => Seq((version, 0L, inserted, 0L, inserted))
+      case DeleteMetric(deleted, _, _, _, _, _, _, _, _, _) =>
+        Seq((version, deleted, 0L, 0L, 0L))
+      case UpdateMetric(_, _, _, _, _, _, updated, _) => Seq((version, 0L, 0L, updated, 0L))
+      case _ => Seq.empty
+    }
+  }
+
+  /**
+   * Given a [[DeltaHistory]] and a partition condition string, returns whether the condition
+   * matches the partition predicate recorded for operations like DELETE, UPDATE and MERGE.
+   * @param x
+   * @param y
+   * @return
+   */
+  def parseDeltaLogToValidatePartitionCondition(x: DeltaHistory, y: String): Boolean = {
+    val inputConditions: Map[String, String] = splitConditionTo(y.toLowerCase)
+    val opParamInDeltaLog: Map[String, String] = splitConditionTo(
+      // targets a delta log delete string that looks like ["(((country = 'USA') AND (gender = 'Female')) AND (id = 2))"]
+      x.operationParameters("predicate")
+        .toLowerCase
+        .replaceAll("[()]", " ")
+        .replaceAll("[\\[\\]]", " ")
+        .replaceAll("\\\"", " ")
+    )
+    inputConditions
+      .map(x => if (opParamInDeltaLog.contains(x._1)) opParamInDeltaLog(x._1) == x._2 else false)
+      .reduceOption(_ && _) match {
+      case None => false
+      case Some(b) => b
+    }
+  }
+
+  /**
+   * Breaks down a string condition into a [[Map]] of [[String]] -> [[String]]. Handles the
+   * idiosyncrasies of Delta Log recorded predicate strings for operations like DELETE, UPDATE and
+   * MERGE.
+   * @param partitionCondition
+   * @return
+   */
+  def splitConditionTo(partitionCondition: String): Map[String, String] = {
+    val trimmed = partitionCondition.trim
+    val splitCondition =
+      if (trimmed.contains(" and "))
+        trimmed.split(" and ").toSeq
+      else Seq(trimmed)
+    splitCondition
+      .map(x => {
+        val kv = x.split("=")
+        assert(kv.size == 2)
+        if (kv.head.contains("#")) {
+          // targets an update string that looks like (((country#590 = USA) AND (gender#588 = Female)) AND (id#587 = 4))
+          kv.head.split("#")(0).trim -> kv.tail.head.trim.stripPrefix("\'").stripSuffix("\'")
+        } else if (kv.head.contains("."))
+          // targets a merge string that looks like
+          // (((multi_partitioned_snapshot.id = source.id) AND (multi_partitioned_snapshot.country = 'IND')) AND
+          // (multi_partitioned_snapshot.gender = 'Male'))
+          kv.head.split("\\.")(1).trim -> kv.tail.head.trim.stripPrefix("\'").stripSuffix("\'")
+        else
+          kv.head.trim -> kv.tail.head.trim.stripPrefix("\'").stripSuffix("\'")
+      })
+      .toMap
+  }
+
+  private def filterHistoryByPartition(x: DeltaHistory, partitionCondition: String): Boolean =
+    x.operation match {
+      case "WRITE" => true
+      case "DELETE" | "MERGE" | "UPDATE" =>
+        if (
+          x.operationParameters
+            .contains("predicate") && x.operationParameters.get("predicate") != None
+        ) {
+          parseDeltaLogToValidatePartitionCondition(x, partitionCondition)
+        } else {
+          false
+        }
+      case _ => false
+    }
+
+  private def generateMetric(
+      deltaHistories: Seq[DeltaHistory],
+      partitionCondition: Option[String]
+  ): Seq[(Long, OperationMetrics)] =
+    deltaHistories
+      .map(dh => {
+        (
+          dh.version.get, {
+            val metrics = dh.operationMetrics.get
+            dh.operation match {
+              case "MERGE" =>
+                MergeMetric(
+                  numTargetRowsCopied = metrics("numTargetRowsCopied").toLong,
+                  numTargetRowsDeleted = metrics("numTargetRowsDeleted").toLong,
+                  numTargetFilesAdded = metrics("numTargetFilesAdded").toLong,
+                  executionTimeMs = metrics("executionTimeMs").toLong,
+                  numTargetRowsInserted = metrics("numTargetRowsInserted").toLong,
+                  scanTimeMs = metrics("scanTimeMs").toLong,
+                  numTargetRowsUpdated = metrics("numTargetRowsUpdated").toLong,
+                  numOutputRows = metrics("numOutputRows").toLong,
+                  numTargetChangeFilesAdded = metrics("numTargetChangeFilesAdded").toLong,
+                  numSourceRows = metrics("numSourceRows").toLong,
+                  numTargetFilesRemoved = metrics("numTargetFilesRemoved").toLong,
+                  rewriteTimeMs = metrics("rewriteTimeMs").toLong
+                )
+              case "WRITE" =>
+                partitionCondition match {
+                  case None =>
+                    WriteMetric(
+                      numFiles = metrics("numFiles").toLong,
+                      numOutputRows = metrics("numOutputRows").toLong,
+                      numOutputBytes = metrics("numOutputBytes").toLong
+                    )
+                  case Some(condition) =>
+                    WriteMetric(0L, getWriteMetricByPartition(condition, dh.version.get), 0L)
+                }
+              case "DELETE" =>
+                DeleteMetric(
+                  numDeletedRows = whenContains(metrics, "numDeletedRows"),
+                  numAddedFiles = whenContains(metrics, "numAddedFiles"),
+                  numCopiedRows = whenContains(metrics, "numCopiedRows"),
+                  numRemovedFiles = whenContains(metrics, "numRemovedFiles"),
+                  numAddedChangeFiles = whenContains(metrics, "numAddedChangeFiles"),
+                  numRemovedBytes = whenContains(metrics, "numRemovedBytes"),
+                  numAddedBytes = whenContains(metrics, "numAddedBytes"),
+                  executionTimeMs = whenContains(metrics, "executionTimeMs"),
+                  scanTimeMs = whenContains(metrics, "scanTimeMs"),
+                  rewriteTimeMs = whenContains(metrics, "rewriteTimeMs")
+                )
+              case "UPDATE" =>
+                UpdateMetric(
+                  numRemovedFiles = whenContains(metrics, "numRemovedFiles"),
+                  numCopiedRows = whenContains(metrics, "numCopiedRows"),
+                  numAddedChangeFiles = whenContains(metrics, "numAddedChangeFiles"),
+                  executionTimeMs = whenContains(metrics, "executionTimeMs"),
+                  scanTimeMs = whenContains(metrics, "scanTimeMs"),
+                  numAddedFiles = whenContains(metrics, "numAddedFiles"),
+                  numUpdatedRows = whenContains(metrics, "numUpdatedRows"),
+                  rewriteTimeMs = whenContains(metrics, "rewriteTimeMs")
+                )
+              case _ => null
+            }
+          }
+        )
+      })
+      .filter(x => x._2 != null)
+
+  private def whenContains(map: Map[String, String], key: String) =
+    if (map.contains(key)) map(key).toLong else 0L
+}
diff --git a/src/main/scala/mrpowers/jodie/delta/OperationMetric.scala b/src/main/scala/mrpowers/jodie/delta/OperationMetric.scala
new file mode 100644
index 0000000..1e0a16e
--- /dev/null
+++ b/src/main/scala/mrpowers/jodie/delta/OperationMetric.scala
@@ -0,0 +1,43 @@
+package mrpowers.jodie.delta
+
+sealed trait OperationMetrics
+case class DeleteMetric(
+    numDeletedRows: Long,
+    numAddedFiles: Long,
+    numCopiedRows: Long,
+    numRemovedFiles: Long,
+    numAddedChangeFiles: Long,
+    numRemovedBytes: Long,
+    numAddedBytes: Long,
+    executionTimeMs: Long,
+    scanTimeMs: Long,
+    rewriteTimeMs: Long
+) extends OperationMetrics
+
+case class UpdateMetric(
+    numRemovedFiles: Long,
+    numCopiedRows: Long,
+    numAddedChangeFiles: Long,
+    executionTimeMs: Long,
+    scanTimeMs: Long,
+    numAddedFiles: Long,
+    numUpdatedRows: Long,
+    rewriteTimeMs: Long
+) extends OperationMetrics
+case class WriteMetric(numFiles: Long, numOutputRows: Long, numOutputBytes: Long)
+    extends OperationMetrics
+
+case class MergeMetric(
+    numTargetRowsCopied: Long,
+    numTargetRowsDeleted: Long,
+    numTargetFilesAdded: Long,
+    executionTimeMs: Long,
+    numTargetRowsInserted: Long,
+    scanTimeMs: Long,
+    numTargetRowsUpdated: Long,
+    numOutputRows: Long,
+    numTargetChangeFilesAdded: Long,
+    numSourceRows: Long,
+    numTargetFilesRemoved: Long,
+    rewriteTimeMs: Long
+) extends OperationMetrics
diff --git a/src/test/scala/mrpowers/jodie/ChangeDataFeedHelperSpec.scala b/src/test/scala/mrpowers/jodie/ChangeDataFeedHelperSpec.scala
index 752e2ba..7882be4 100644
--- a/src/test/scala/mrpowers/jodie/ChangeDataFeedHelperSpec.scala
+++ b/src/test/scala/mrpowers/jodie/ChangeDataFeedHelperSpec.scala
@@ -2,6 +2,7 @@ package mrpowers.jodie
 
 import com.github.mrpowers.spark.fast.tests.DataFrameComparer
 import io.delta.tables.DeltaTable
+import mrpowers.jodie.DeltaTestUtils.executeMergeFor
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.delta.util.FileNames
@@ -245,18 +246,5 @@ class ChangeDataFeedHelperSpec extends AnyFunSpec
     table
   }
 
-  def executeMergeFor(tableName: String, deltaTable: DeltaTable, updates: List[(Int, String, Int)]) = {
-    import spark.implicits._
-    updates.foreach(row => {
-      val dataFrame = Seq(row).toDF("id", "gender", "age")
-      deltaTable.as(tableName)
-        .merge(dataFrame.as("source"), s"${tableName}.id = source.id")
-        .whenMatched
-        .updateAll()
-        .whenNotMatched()
-        .insertAll()
-        .execute()
-    })
-    deltaTable
-  }
+
 }
diff --git a/src/test/scala/mrpowers/jodie/DeltaTestUtils.scala b/src/test/scala/mrpowers/jodie/DeltaTestUtils.scala
new file mode 100644
index 0000000..6c2ff31
--- /dev/null
+++ b/src/test/scala/mrpowers/jodie/DeltaTestUtils.scala
@@ -0,0 +1,35 @@
+package mrpowers.jodie
+
+import io.delta.tables.DeltaTable
+
+object DeltaTestUtils extends SparkSessionTestWrapper {
+  def executeMergeFor(tableName: String, deltaTable: DeltaTable, updates: List[(Int, String, Int)]) = {
+    import spark.implicits._
+    updates.foreach(row => {
+      val dataFrame = Seq(row).toDF("id", "gender", "age")
+      deltaTable.as(tableName)
+        .merge(dataFrame.as("source"), s"${tableName}.id = source.id")
+        .whenMatched
+        .updateAll()
+        .whenNotMatched()
+        .insertAll()
+        .execute()
+    })
+    deltaTable
+  }
+
+  def executeMergeWithReducedSearchSpace(tableName: String, deltaTable: DeltaTable, updates: List[(Int, String, Int, String)], condition: String) = {
+    import spark.implicits._
+    updates.foreach(row => {
+      val dataFrame = Seq(row).toDF("id", "gender", "age", "country")
+      deltaTable.as(tableName)
+        .merge(dataFrame.as("source"), s"${tableName}.id = source.id and $condition")
+        .whenMatched
+        .updateAll()
+        .whenNotMatched()
+        .insertAll()
+        .execute()
+    })
+    deltaTable
+  }
+}
diff --git a/src/test/scala/mrpowers/jodie/OperationMetricHelperSpec.scala b/src/test/scala/mrpowers/jodie/OperationMetricHelperSpec.scala
new file mode 100644
index 0000000..c7501d8
--- /dev/null
+++ b/src/test/scala/mrpowers/jodie/OperationMetricHelperSpec.scala
@@ -0,0 +1,292 @@
+package mrpowers.jodie
+
+import com.github.mrpowers.spark.fast.tests.{DataFrameComparer, DatasetContentMismatch}
+import io.delta.tables.DeltaTable
+import org.apache.spark.sql.functions.desc
+import org.apache.spark.sql.{DataFrame, Row}
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funspec.AnyFunSpec
+
+class OperationMetricHelperSpec
+    extends AnyFunSpec
+    with SparkSessionTestWrapper
+    with DataFrameComparer
+    with BeforeAndAfterEach {
+  var writePath = ""
+  override def afterEach(): Unit = {
+    val tmpDir = os.pwd / "tmp" / "delta-opm"
+    os.remove.all(tmpDir)
+  }
+
+  describe("When Delta Table has relevant operation metric") {
+    val rows = Seq((1, "Male", 25), (2, "Female", 35), (3, "Female", 45), (4, "Male", 18))
+    val updates = List((1, "Male", 35), (2, "Male", 100), (5, "Male", 101), (4, "Female", 18))
+    val path = (os.pwd / "tmp" / "delta-opm").toString()
+    import spark.implicits._
+    val snapshotDF = rows.toDF("id", "gender", "age")
+    implicit val sparkSession = spark
+    it("should return valid count metric") {
+      val name = "snapshot"
+      val deltaTable =
+        DeltaTestUtils.executeMergeFor(
+          name,
+          createDeltaTable(name, snapshotDF, path, None),
+          updates
+        )
+      deltaTable.delete("id == 5")
+      Seq((10, "Female", 35))
+        .toDF("id", "gender", "age")
+        .write
+        .format("delta")
+        .mode("append")
+        .save(writePath)
+      val actualDF = OperationMetricHelper(writePath).getCountMetricsAsDF()
+      val expected = toVersionDF(
+        Seq(
+          (6L, 0L, 1L, 0L, 1L),
+          (5L, 1L, 0L, 0L, 0L),
+          (4L, 0L, 0L, 1L, 1L),
+          (3L, 0L, 1L, 0L, 1L),
+          (2L, 0L, 0L, 1L, 1L),
+          (1L, 0L, 0L, 1L, 1L),
+          (0L, 0L, 4L, 0L, 4L)
+        )
+      )
+      assertSmallDataFrameEquality(actualDF, expected)
+    }
+    it("should return valid metric for single partition column") {
+      val deltaTable = partitionWithMerge("partitioned_snapshot", rows, updates, path)
+      val actual = OperationMetricHelper(writePath).getCountMetricsAsDF(Some(" country = 'USA'"))
+      val versions: Array[Row] = getCountryVersions(deltaTable)
+      val expected =
+        toVersionDF(
+          Seq(
+            (versions.head.getAs[Long]("version"), 0L, 0L, 1L, 1L),
+            (versions.tail.head.getAs[Long]("version"), 0L, 0L, 1L, 1L),
+            (0L, 0L, 2L, 0L, 2L)
+          )
+        )
+      assertSmallDataFrameEquality(actual, expected)
+    }
+    it("should return valid metric for single partition column containing deletes and appends") {
+      val deltaTable = partitionWithMerge("single_partitioned_snapshot", rows, updates, path)
+      deltaTable.delete("country == 'USA' and age == 100")
+      Seq((10, "Female", 35, "USA"))
+        .toDF("id", "gender", "age", "country")
+        .write
+        .format("delta")
+        .mode("append")
+        .partitionBy("country")
+        .save(writePath)
+      val condition = " country = 'USA'"
+      val actual = OperationMetricHelper(writePath).getCountMetricsAsDF(Some(" country = 'USA'"))
+      val versions: Array[Row] = getCountryVersions(deltaTable)
+      val expected =
+        toVersionDF(
+          Seq(
+            (6L, 0L, 1L, 0L, 1L),
+            (5L, 1L, 0L, 0L, 0L),
+            (versions.head.getAs[Long]("version"), 0L, 0L, 1L, 1L),
+            (versions.tail.head.getAs[Long]("version"), 0L, 0L, 1L, 1L),
+            (0L, 0L, 2L, 0L, 2L)
+          )
+        )
+      assertSmallDataFrameEquality(actual, expected)
+      // Query works without single quotes
+      val actualWithoutSingleQuote =
+        OperationMetricHelper(writePath).getCountMetricsAsDF(Some(" country = USA"))
+      assertSmallDataFrameEquality(actualWithoutSingleQuote, expected)
+      intercept[DatasetContentMismatch] {
+        // This query does not work because partition has country=USA and query passed is country=usa
+        // Query types that would work but return wrong count metric, more precisely 0L(zero) as count for write metric
+        val actualDoesNotMatchPartitionCase =
+          OperationMetricHelper(writePath).getCountMetricsAsDF(Some(" country = usa"))
+        assertSmallDataFrameEquality(actualDoesNotMatchPartitionCase, expected)
+      }
+    }
+    it("should return valid metric for multiple partition columns") {
+      val deltaTable = multiplePartitionWithMerge("multi_partitioned_snapshot", rows, updates, path)
+      deltaTable.delete("country == 'USA' and gender = 'Female' and id == 2")
+      Seq((10, "Female", 35, "USA"))
+        .toDF("id", "gender", "age", "country")
+        .write
+        .format("delta")
+        .mode("append")
+        .partitionBy("country", "gender")
+        .save(writePath)
+      val actual = OperationMetricHelper(writePath).getCountMetricsAsDF(
+        Some(" country = 'USA' and gender = 'Female'")
+      )
+      val version = getMergeVersionForPartition(deltaTable)
+      val expected =
+        toVersionDF(
+          Seq(
+            (6L, 0L, 1L, 0L, 1L),
+            (5L, 1L, 0L, 0L, 0L),
+            (version, 0L, 1L, 0L, 1L),
+            (0L, 0L, 1L, 0L, 1L)
+          )
+        )
+      assertSmallDataFrameEquality(actual, expected)
+      val actualWithoutSingleQuote = OperationMetricHelper(writePath).getCountMetricsAsDF(
+        Some(" country = USA and gender = Female")
+      )
+      assertSmallDataFrameEquality(actualWithoutSingleQuote, expected)
+      intercept[DatasetContentMismatch] {
+        val actualDoesNotMatchPartitionCase = OperationMetricHelper(writePath).getCountMetricsAsDF(
+          Some(" country = usa and gender = female ")
+        )
+        assertSmallDataFrameEquality(actualDoesNotMatchPartitionCase, expected)
+      }
+    }
+    it("should return valid metric for multiple partition columns containing updates") {
+      val deltaTable =
+        multiplePartitionWithMerge("multi_partitioned_snapshot_updated", rows, updates, path)
+      deltaTable.delete("country == 'USA' and gender = 'Female' and id == 2")
+      Seq((10, "Female", 35, "USA"))
+        .toDF("id", "gender", "age", "country")
+        .write
+        .format("delta")
+        .mode("append")
+        .partitionBy("country", "gender")
+        .save(writePath)
+      spark.sql(s"ALTER TABLE default.multi_partitioned_snapshot_updated SET TBLPROPERTIES (delta.enableChangeDataFeed = true)")
+
+      deltaTable.updateExpr(
+        "country == 'USA' and gender = 'Female' and id == 4",
+        Map("age" -> "533")
+      )
+      deltaTable.optimize().where("country='USA' and gender='Female'").executeCompaction()
+      val actual = OperationMetricHelper(writePath).getCountMetricsAsDF(
+        Some(" country = 'USA' and gender = 'Female'")
+      )
+      val version = getMergeVersionForPartition(deltaTable)
+      val expected =
+        toVersionDF(
+          Seq(
+            (8L, 0L, 0L, 1L, 0L),
+            (6L, 0L, 1L, 0L, 1L),
+            (5L, 1L, 0L, 0L, 0L),
+            (version, 0L, 1L, 0L, 1L),
+            (0L, 0L, 1L, 0L, 1L)
+          )
+        )
+      assertSmallDataFrameEquality(actual, expected)
+    }
+  }
+
+  private def getMergeVersionForPartition(deltaTable: DeltaTable): Long = {
+    val versionDF = deltaTable
+      .history()
+      .filter(" version > 0 and version < 5")
+      .select("version", "operationParameters.predicate")
+      .filter("predicate like '%USA%' and predicate like '%Female%'")
+    assert(versionDF.count() == 1)
+    versionDF.take(1).head.getAs[Long]("version")
+  }
+
+  private def getCountryVersions(deltaTable: DeltaTable) = {
+    val versionDF = deltaTable
+      .history()
+      .filter("operation == 'MERGE'")
+      .select("version", "operationParameters.predicate")
+      .filter("predicate like '%USA%'")
+      .orderBy(desc("version"))
+    assert(versionDF.count() == 2)
+    val versions = versionDF.take(2)
+    versions
+  }
+
+  private def partitionWithMerge(
+      tableName: String,
+      rows: Seq[(Int, String, Int)],
+      updates: List[(Int, String, Int)],
+      path: String
+  ): DeltaTable = {
+    import spark.implicits._
+    val rowsWithCountry = rows.map(x => appendCountry(x))
+    val deltaTable = createDeltaTable(
+      tableName,
+      rowsWithCountry.toDF("id", "gender", "age", "country"),
+      path,
+      Some(Seq("country"))
+    )
+    val upsertCandidates = updates.map(x => appendCountry(x))
+    upsertCandidates
+      .groupBy(x => x._4)
+      .foreach(y => {
+        DeltaTestUtils.executeMergeWithReducedSearchSpace(
+          tableName,
+          deltaTable,
+          y._2,
+          s" ${tableName}.country == '${y._1}'"
+        )
+        ()
+      })
+    deltaTable
+  }
+
+  private def multiplePartitionWithMerge(
+      tableName: String,
+      rows: Seq[(Int, String, Int)],
+      updates: List[(Int, String, Int)],
+      path: String
+  ): DeltaTable = {
+    import spark.implicits._
+    val rowsWithCountry = rows.map(x => appendCountry(x))
+    val deltaTable = createDeltaTable(
+      tableName,
+      rowsWithCountry.toDF("id", "gender", "age", "country"),
+      path,
+      Some(Seq("country", "gender"))
+    )
+    val upsertCandidates = updates.map(x => appendCountry(x))
+    upsertCandidates
+      .groupBy(x => (x._4, x._2))
+      .foreach(y => {
+        DeltaTestUtils.executeMergeWithReducedSearchSpace(
+          tableName,
+          deltaTable,
+          y._2,
+          s" ${tableName}.country == '${y._1._1}' and ${tableName}.gender == '${y._1._2}'"
+        )
+        ()
+      })
+    deltaTable
+  }
+
+  private def toVersionDF(s: Seq[(Long, Long, Long, Long, Long)]): DataFrame = {
+    import spark.implicits._
+    s.toDF(
+      "version",
+      "deleted",
+      "inserted",
+      "updated",
+      "source_rows"
+    )
+  }
+
+  private def appendCountry(x: (Int, String, Int)) = {
+    if (x._1 % 2 == 0) (x._1, x._2, x._3, "USA") else (x._1, x._2, x._3, "IND")
+  }
+
+  def createDeltaTable(
+      tableName: String,
+      snapshotDF: DataFrame,
+      path: String,
+      partitionColumn: Option[Seq[String]]
+  ) = {
+    writePath = path + "/" + tableName
+    partitionColumn match {
+      case None => snapshotDF.write.format("delta").save(writePath)
+      case Some(pc) =>
+        snapshotDF.write
+          .format("delta")
+          .partitionBy(pc: _*)
+          .save(writePath)
+    }
+    spark.sql(s"CREATE TABLE default.${tableName} USING DELTA LOCATION '${writePath}' ")
+    DeltaTable.forPath(writePath)
+  }
+
+}
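Editor's note: the new specs above exercise `getCountMetricsAsDF` but not `getWriteMetricByPartition` directly. A focused test of that method could look like the sketch below. It is illustrative only (the suite name, path and expected count are assumptions, not part of this PR), and it relies on the repo's existing `SparkSessionTestWrapper` trait, the os-lib dependency already used by the other specs, and on Delta writing `numRecords` stats for add files, which it does by default:

```scala
package mrpowers.jodie

import org.apache.spark.sql.SparkSession
import org.scalatest.funspec.AnyFunSpec

class WriteMetricByPartitionSketch extends AnyFunSpec with SparkSessionTestWrapper {
  implicit val session: SparkSession = spark
  import spark.implicits._

  it("counts appended rows for a single partition via the add-file stats") {
    val tmpDir = os.pwd / "tmp" / "delta-write-metric-sketch"
    val path   = tmpDir.toString()

    // One row lands in country=USA, one in country=IND.
    Seq((1, "Male", 25, "USA"), (2, "Female", 35, "IND"))
      .toDF("id", "gender", "age", "country")
      .write
      .format("delta")
      .partitionBy("country")
      .save(path)

    // Version 0 is the initial write; only the row in the USA partition should be counted.
    val usaRows = OperationMetricHelper(path).getWriteMetricByPartition(" country = 'USA'", 0L)
    assert(usaRows == 1L)

    os.remove.all(tmpDir)
  }
}
```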