Skip to content

Commit

Permalink
Add option to save statistics only
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Sep 13, 2023
1 parent 511b74d commit 7e96d99
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
1 change: 1 addition & 0 deletions DIFF.md
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ Input and output
--ignore <name> ignore column name
--save-mode <save-mode> save mode for writing output (Append, Overwrite, ErrorIfExists, Ignore, default ErrorIfExists)
--filter <filter> Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)
--statistics Only output statistics on how many rows exist per diff action (see 'Diffing options' section)
Diffing options
--diff-column <name> column name for diff column (default 'diff')
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/uk/co/gresearch/spark/diff/App.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ object App {
ignore: Seq[String] = Seq.empty,
saveMode: SaveMode = SaveMode.ErrorIfExists,
filter: Set[String] = Set.empty,
statistics: Boolean = false,
diffOptions: DiffOptions = DiffOptions.default)

// read options from args
Expand Down Expand Up @@ -179,6 +180,10 @@ object App {
.valueName("<filter>")
.action((x, c) => c.copy(filter = c.filter + x))
.text(s"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)")
// Value-less flag: when given, sets `statistics = true` in the parsed options so that only
// the number of rows per diff action is written to the output, not the diff rows themselves.
opt[Unit]("statistics")
.optional()
.action((_, c) => c.copy(statistics = true))
.text(s"Only output statistics on how many rows exist per diff action (see 'Diffing options' section)")

note("")
note("Diffing options")
Expand Down Expand Up @@ -244,8 +249,9 @@ object App {
.when(schema.isDefined).call(_.schema(schema.get))
.when(format.isDefined).either(_.load(path)).or(_.table(path))

def write(df: DataFrame, format: Option[String], path: String, options: Map[String, String], saveMode: SaveMode, filter: Set[String], diffOptions: DiffOptions): Unit =
def write(df: DataFrame, format: Option[String], path: String, options: Map[String, String], saveMode: SaveMode, filter: Set[String], saveStats: Boolean, diffOptions: DiffOptions): Unit =
df.when(filter.nonEmpty).call(_.where(col(diffOptions.diffColumn).isInCollection(filter)))
.when(saveStats).call(_.groupBy(diffOptions.diffColumn).count)
.write
.when(format.isDefined).call(_.format(format.get))
.options(options)
Expand All @@ -270,6 +276,6 @@ object App {
val left = read(spark, options.leftFormat, options.leftPath.get, options.leftSchema, options.leftOptions)
val right = read(spark, options.rightFormat, options.rightPath.get, options.rightSchema, options.rightOptions)
val diff = left.diff(right, options.diffOptions, options.ids, options.ignore)
write(diff, options.outputFormat, options.outputPath.get, options.outputOptions, options.saveMode, options.filter, options.diffOptions)
write(diff, options.outputFormat, options.outputPath.get, options.outputOptions, options.saveMode, options.filter, options.statistics, options.diffOptions)
}
}
28 changes: 28 additions & 0 deletions src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,32 @@ class AppSuite extends AnyFunSuite with SparkTestSession {
}
}
}

test("run app writing stats") {
  withTempPath { path =>
    // write left dataframe as parquet
    val leftPath = new File(path, "left.parquet").getAbsolutePath
    DiffSuite.left(spark).write.parquet(leftPath)

    // write right dataframe as parquet
    val rightPath = new File(path, "right.parquet").getAbsolutePath
    DiffSuite.right(spark).write.parquet(rightPath)

    // launch app with --statistics, so only the row count per diff action is written
    val outputPath = new File(path, "diff.parquet").getAbsolutePath
    App.main(Array(
      "--format", "parquet",
      "--statistics",
      "--id", "id",
      leftPath,
      rightPath,
      outputPath
    ))

    // assert written statistics: one (diff action, row count) pair per diff action
    val actual: Map[String, Long] =
      spark.read.parquet(outputPath).as[(String, Long)].collect().toMap

    // NOTE(review): assumes the first column of DiffSuite.expectedDiff rows holds the diff action
    // counts are widened to Long to match the type read from the output, so the comparison
    // does not rely on Int/Long cooperative equality
    val expected: Map[String, Long] =
      DiffSuite.expectedDiff
        .groupBy(row => row.getString(0))
        .map { case (action, rows) => action -> rows.length.toLong }

    assert(actual === expected)
  }
}
}

0 comments on commit 7e96d99

Please sign in to comment.