From 1c5f066d0355638c92a21326e556f4f44ed55bb4 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 20 May 2024 23:38:20 +0200 Subject: [PATCH] Adding more functions - min/max/variance/histogram (#10) ### What changes were proposed in this pull request? Adds few more functions that are simple forwards to the RDD implementation. ### How was this patch tested? Added new UT --- README.md | 14 +++++++------- congruity/rdd_adapter.py | 18 ++++++++++++++++++ tests/test_rdd_adapter.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 073d2f8..173e87e 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ that open a pull request and we will review it. | RDD | API | Comment | |-----------------------------------|--------------------|-------------------------------------------------------------------| -| aggregate | :x: | | +| aggregate | :white_check_mark: | | | aggregateByKey | :x: | | | barrier | :x: | | | cache | :x: | | @@ -84,7 +84,7 @@ that open a pull request and we will review it. | countByKey | :x: | | | countByValue | :x: | | | distinct | :x: | | -| filter | :x: | | +| filter | :white_check_mark: | | | first | :white_check_mark: | | | flatMap | :x: | | | fold | :white_check_mark: | First version | @@ -99,7 +99,7 @@ that open a pull request and we will review it. | groupBy | :x: | | | groupByKey | :x: | | | groupWith | :x: | | -| histogram | :x: | | +| histogram | :white_check_mark: | | | id | :x: | | | intersection | :x: | | | isCheckpointed | :x: | | @@ -116,10 +116,10 @@ that open a pull request and we will review it. | mapPartitionsWithIndex | :x: | | | mapPartitionsWithSplit | :x: | | | mapValues | :x: | | -| max | :x: | | -| mean | :x: | | +| max | :white_check_mark: | | +| mean | :white_check_mark: | | | meanApprox | :x: | | -| min | :x: | | +| min | :white_check_mark: | | | name | :x: | | | partitionBy | :x: | | | persist | :x: | | @@ -152,7 +152,7 @@ that open a pull request and we will review it. | take | :white_check_mark: | Ordering might not be guaranteed in the same way as it is in RDD. | | takeOrdered | :x: | | | takeSample | :x: | | -| toDF | :x: | | +| toDF | :white_check_mark: | | | toDebugString | :x: | | | toLocalIterator | :x: | | | top | :x: | | diff --git a/congruity/rdd_adapter.py b/congruity/rdd_adapter.py index 3f1f80c..ec0fc1a 100644 --- a/congruity/rdd_adapter.py +++ b/congruity/rdd_adapter.py @@ -290,6 +290,24 @@ def func(iterator: Iterable[T]) -> Iterable[U]: aggregate = RDD.aggregate aggregate.__doc__ = RDD.aggregate.__doc__ + max = RDD.max + max.__doc__ = RDD.max.__doc__ + + min = RDD.min + min.__doc__ = RDD.min.__doc__ + + filter = RDD.filter + filter.__doc__ = RDD.filter.__doc__ + + histogram = RDD.histogram + histogram.__doc__ = RDD.histogram.__doc__ + + mean = RDD.mean + mean.__doc__ = RDD.mean.__doc__ + + variance = RDD.variance + variance.__doc__ = RDD.variance.__doc__ + class WrappedIterator(Iterable): """This is a helper class that wraps the iterator of RecordBatches as returned by mapInArrow and converts it into an iterator of the underlaying values.""" diff --git a/tests/test_rdd_adapter.py b/tests/test_rdd_adapter.py index b03fa35..56333c1 100644 --- a/tests/test_rdd_adapter.py +++ b/tests/test_rdd_adapter.py @@ -246,3 +246,39 @@ def test_rdd_aggregate(spark_session: "SparkSession"): # TODO empty # res = spark_session.sparkContext.parallelize([]).aggregate((0, 0), seqOp, combOp) # assert res == (0, 0) + + +def test_rdd_min(spark_session: "SparkSession"): + monkey_patch_spark() + df = spark_session.range(10).repartition(1) + assert df.rdd.map(lambda x: x[0]).min() == 0 + + +def test_rdd_max(spark_session: "SparkSession"): + monkey_patch_spark() + df = spark_session.range(10).repartition(1) + assert df.rdd.map(lambda x: x[0]).max() == 9 + + +def test_rdd_histogram(spark_session: "SparkSession"): + monkey_patch_spark() + rdd = spark_session.sparkContext.parallelize(range(10)) + assert rdd.histogram(3) == ([0, 3, 6, 9], [3, 3, 4]) + + +def test_rdd_filter(spark_session: "SparkSession"): + monkey_patch_spark() + rdd = spark_session.sparkContext.parallelize(range(10)) + assert rdd.filter(lambda x: x % 2 == 0).collect() == [0, 2, 4, 6, 8] + + +def test_rdd_mean(spark_session: "SparkSession"): + monkey_patch_spark() + rdd = spark_session.sparkContext.parallelize(range(10)) + assert rdd.mean() == 4.5 + + +def test_rdd_variance(spark_session: "SparkSession"): + monkey_patch_spark() + rdd = spark_session.sparkContext.parallelize(range(10)) + assert rdd.variance() == 8.25