Num Shuffle Files: Gets the count of shuffle files pulled into memory for a filter condition #76

Merged
24 changes: 24 additions & 0 deletions README.md
@@ -438,6 +438,30 @@ DeltaHelpers.deltaNumRecordDistribution(path)

## Number of Shuffle Files in Merge & Other Filter Conditions
The function `getNumShuffleFiles` returns the number of shuffle files (Parquet part files) that will be pulled into memory for a given filter condition. This is particularly useful for a Delta merge operation, where the number of shuffle files can be a bottleneck: running the merge condition through this method gives an idea of the memory required to run the merge.
For example, if the condition is `"country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%'"` and `country` is the partition column,
```scala
DeltaHelpers.getNumShuffleFiles(path, "country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%' ")
```

then the output might look like
```scala
(18, 100, 300, 600, 800, 800, List())
// 18  - number of files that will be pulled into memory for the entire provided condition
// 100 - number of files matching the greater-than/less-than part => "age >= 30 and age <= 40"
// 300 - number of files matching the equals part => "country = 'GBR'"
// 600 - number of files matching the like (or any other) part => "firstname like '%Jo%'"
// 800 - number of files matching any remaining part. This is mostly a failsafe
//       1. to capture any condition that might have been missed
//       2. to catch wrong attribute names or conditions such as snapshot.id = source.id (usually found in merge conditions)
// 800 - total number of files in the Delta Table
// List() - list of unresolved columns/attributes in the provided condition
```

**Reviewer:** Should this output be a map that indicates what each value represents?

**Author:** I have changed the output to a MAP and updated the method comments. Please check if it looks good, then I will update README.md @MrPowers

This function works only on the Delta Log and does not scan any data in the Delta Table.
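
A minimal sketch of consuming this result (the variable names are illustrative, and `path` is assumed to point at the Delta Table):

```scala
val condition = "country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%'"
val (forCondition, minMax, equalsOnly, others, residual, totalFiles, unresolved) =
  DeltaHelpers.getNumShuffleFiles(path, condition)

if (unresolved.nonEmpty)
  println(s"Unresolved attributes in condition: ${unresolved.mkString(", ")}")
else
  println(s"$forCondition of $totalFiles files would be pulled into memory for this condition")
```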

## Change Data Feed Helpers

### CASE I - When Delta aka Transaction Log gets purged
111 changes: 109 additions & 2 deletions src/main/scala/mrpowers/jodie/DeltaHelpers.scala
@@ -2,10 +2,12 @@ package mrpowers.jodie

import io.delta.tables._
import mrpowers.jodie.delta.DeltaConstants._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.actions.AddFile
import org.apache.spark.sql.expressions.Window.partitionBy
import org.apache.spark.sql.functions._

import scala.collection.mutable

@@ -71,6 +73,111 @@ object DeltaHelpers {
def deltaNumRecordDistribution(path: String, condition: Option[String] = None): DataFrame =
getAllPartitionStats(deltaFileStats(path, condition), statsPartitionColumn, numRecordsColumn).toDF(numRecordsDFColumns: _*)

/**
* Gets the number of shuffle files (part files for parquet) that will be pulled into memory for a given filter condition.
* This is particularly useful in a Delta Merge operation where the number of shuffle files can be a bottleneck. Running
* the merge condition through this method can give an idea about the amount of memory resources required to run the merge.
*
* For example, if the condition is "country = 'GBR' and age >= 30 and age <= 40 and firstname like '%Jo%'" and country
* is the partition column, then the output might look like => (18, 100, 300, 600, 800, 800, List())
* 18  - number of files that will be pulled into memory for the entire provided condition
* 100 - number of files matching the greater-than/less-than part => "age >= 30 and age <= 40"
* 300 - number of files matching the equals part => "country = 'GBR'"
* 600 - number of files matching the like (or any other) part => "firstname like '%Jo%'"
* 800 - number of files matching any remaining part. This is mostly a failsafe
*       1. to capture any condition that might have been missed
*       2. to catch wrong attribute names or conditions such as snapshot.id = source.id (usually found in merge conditions)
* 800 - total number of files in the Delta Table
* List() - list of unresolved columns/attributes in the provided condition
*
* This function works only on the Delta Log and does not scan any data in the Delta Table.
*
* @param path      path to the Delta Table
* @param condition filter condition as a SQL string, e.g. a merge condition
* @return a 7-tuple holding the five file counts described above, the total file count, and the unresolved attributes
*/
def getNumShuffleFiles(path: String, condition: String): (Int, Int, Int, Int, Int, Long, Seq[String]) = {
  val (deltaLog, unresolvedColumns, targetOnlyPredicates, minMaxOnlyExpressions, equalOnlyExpressions,
    otherExpressions, removedPredicates) = getResolvedExpressions(path, condition)
  deltaLog.withNewTransaction { deltaTxn =>
    (deltaTxn.filterFiles(targetOnlyPredicates).size,
      deltaTxn.filterFiles(minMaxOnlyExpressions).size,
      deltaTxn.filterFiles(equalOnlyExpressions).size,
      deltaTxn.filterFiles(otherExpressions).size,
      deltaTxn.filterFiles(removedPredicates).size,
      deltaLog.snapshot.filesWithStatsForScan(Nil).count(),
      unresolvedColumns)
  }
}

def getShuffleFileMetadata(path: String, condition: String):
    (Seq[AddFile], Seq[AddFile], Seq[AddFile], Seq[AddFile], Seq[AddFile], DataFrame, Seq[String]) = {
  val (deltaLog, unresolvedColumns, targetOnlyPredicates, minMaxOnlyExpressions, equalOnlyExpressions,
    otherExpressions, removedPredicates) = getResolvedExpressions(path, condition)
  deltaLog.withNewTransaction { deltaTxn =>
    (deltaTxn.filterFiles(targetOnlyPredicates),
      deltaTxn.filterFiles(minMaxOnlyExpressions),
      deltaTxn.filterFiles(equalOnlyExpressions),
      deltaTxn.filterFiles(otherExpressions),
      deltaTxn.filterFiles(removedPredicates),
      deltaLog.snapshot.filesWithStatsForScan(Nil),
      unresolvedColumns)
  }
}

**Author:** I did not add any test cases for this method, as the file metadata it returns comes directly from the Delta Table and I am not making any mutations to it. Additionally, please review whether this method would be useful; I can remove it as well.

**Reviewer:** I don't have a use case in mind, but I think it may help someone investigating the behavior of a query. My only advice would be to explicitly declare the return type in the method signature, so it can serve as a reference for anyone who wants to use this method.
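
A rough sketch of inspecting the returned metadata, assuming `filterFiles` returns Delta's `AddFile` actions as in current Delta releases (variable names are illustrative):

```scala
val (forCondition, _, _, _, _, allFilesWithStats, unresolved) =
  DeltaHelpers.getShuffleFileMetadata(path, "country = 'GBR'")

// Each AddFile carries the file's relative path, size in bytes, and per-file stats.
forCondition.foreach(f => println(s"${f.path} (${f.size} bytes)"))

// The full snapshot file listing comes back as a DataFrame.
allFilesWithStats.show(5, truncate = false)
```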

private def getResolvedExpressions(path: String, condition: String) = {
val spark = SparkSession.active
val deltaTable = DeltaTable.forPath(path)
val deltaLog = DeltaLog.forTable(spark, path)

val expression = functions.expr(condition).expr
val targetPlan = deltaTable.toDF.queryExecution.analyzed
val resolvedExpression: Expression = spark.sessionState.analyzer.resolveExpressionByPlanOutput(expression, targetPlan, true)
val unresolvedColumns = if (!resolvedExpression.childrenResolved) {
  resolvedExpression.references.collect { case a: UnresolvedAttribute => a.sql }.toSeq
} else Seq.empty[String]

def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
condition match {
case And(cond1, cond2) =>
splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
case other => other :: Nil
}
}
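// For example, "a = 1 and b > 2 and c like 'x%'" is flattened into
// Seq(a = 1, b > 2, c LIKE 'x%'); nested And nodes are split recursively.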

val splitExpressions = splitConjunctivePredicates(resolvedExpression)
val targetOnlyPredicates = splitExpressions.filter(_.references.subsetOf(targetPlan.outputSet))

val minMaxOnlyExpressions = targetOnlyPredicates.filter {
  case _: GreaterThanOrEqual | _: LessThanOrEqual | _: LessThan | _: GreaterThan => true
  case _ => false
}

val equalOnlyExpressions = targetOnlyPredicates.filter {
  case _: EqualTo | _: EqualNullSafe => true
  case _ => false
}

val otherExpressions = targetOnlyPredicates.filter {
  case _: EqualTo | _: EqualNullSafe | _: GreaterThanOrEqual | _: LessThanOrEqual |
       _: LessThan | _: GreaterThan => false
  case _ => true
}

val removedPredicates = splitExpressions.filterNot(_.references.subsetOf(targetPlan.outputSet))

(deltaLog, unresolvedColumns, targetOnlyPredicates, minMaxOnlyExpressions, equalOnlyExpressions, otherExpressions, removedPredicates)
}
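
For intuition, here is how a hypothetical condition would be bucketed by `getResolvedExpressions` (a sketch based on the README example, not output produced by the code):

```scala
// condition: "country = 'GBR' and age >= 30 and firstname like '%Jo%' and snapshot.id = source.id"
// equalOnlyExpressions  -> country = 'GBR'
// minMaxOnlyExpressions -> age >= 30
// otherExpressions      -> firstname LIKE '%Jo%'
// removedPredicates     -> snapshot.id = source.id (references columns outside the target table)
```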


private def getAllPartitionStats(filteredDF: DataFrame, groupByCol: String, aggCol: String) = filteredDF
.groupBy(map_entries(col(groupByCol)))
@@ -92,7 +199,7 @@ object DeltaHelpers {
val snapshot = tableLog.snapshot
condition match {
case None => snapshot.filesWithStatsForScan(Nil)
case Some(value) => snapshot.filesWithStatsForScan(Seq(functions.expr(value).expr))
}
}
