Adding more functions - min/max/variance/histogram (#10)
### What changes were proposed in this pull request?
Adds a few more functions that are simple forwards to the RDD
implementation.
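
For context, a minimal standalone sketch of this forwarding pattern (`MiniRDD` and `rdd_max` are hypothetical names for illustration, not code from this repository): PySpark's `RDD.max` is written in terms of other RDD methods such as `reduce`, so a class that supplies compatible primitives can borrow the implementation as-is.

```python
import functools
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


class MiniRDD:
    """Toy stand-in for the adapter: only reduce() is implemented natively."""

    def __init__(self, data: Iterable[T]) -> None:
        self._data: List[T] = list(data)

    def reduce(self, f: Callable[[T, T], T]) -> T:
        # Fold all elements into a single value, like RDD.reduce.
        return functools.reduce(f, self._data)


def rdd_max(self):
    # Hypothetical stand-in for pyspark.RDD.max, which is built on reduce().
    return self.reduce(max)


# The forward: attach the borrowed implementation as a class attribute.
MiniRDD.max = rdd_max

assert MiniRDD(range(10)).max() == 9
```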

### How was this patch tested?
Added new unit tests.
grundprinzip authored May 20, 2024
1 parent 8467838 commit 1c5f066
Showing 3 changed files with 61 additions and 7 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -66,7 +66,7 @@ that open a pull request and we will review it.

| RDD | API | Comment |
|-----------------------------------|--------------------|-------------------------------------------------------------------|
-| aggregate | :x: | |
+| aggregate | :white_check_mark: | |
| aggregateByKey | :x: | |
| barrier | :x: | |
| cache | :x: | |
@@ -84,7 +84,7 @@ that open a pull request and we will review it.
| countByKey | :x: | |
| countByValue | :x: | |
| distinct | :x: | |
-| filter | :x: | |
+| filter | :white_check_mark: | |
| first | :white_check_mark: | |
| flatMap | :x: | |
| fold | :white_check_mark: | First version |
@@ -99,7 +99,7 @@ that open a pull request and we will review it.
| groupBy | :x: | |
| groupByKey | :x: | |
| groupWith | :x: | |
-| histogram | :x: | |
+| histogram | :white_check_mark: | |
| id | :x: | |
| intersection | :x: | |
| isCheckpointed | :x: | |
@@ -116,10 +116,10 @@ that open a pull request and we will review it.
| mapPartitionsWithIndex | :x: | |
| mapPartitionsWithSplit | :x: | |
| mapValues | :x: | |
-| max | :x: | |
-| mean | :x: | |
+| max | :white_check_mark: | |
+| mean | :white_check_mark: | |
| meanApprox | :x: | |
-| min | :x: | |
+| min | :white_check_mark: | |
| name | :x: | |
| partitionBy | :x: | |
| persist | :x: | |
@@ -152,7 +152,7 @@ that open a pull request and we will review it.
| take | :white_check_mark: | Ordering might not be guaranteed in the same way as it is in RDD. |
| takeOrdered | :x: | |
| takeSample | :x: | |
-| toDF | :x: | |
+| toDF | :white_check_mark: | |
| toDebugString | :x: | |
| toLocalIterator | :x: | |
| top | :x: | |
18 changes: 18 additions & 0 deletions congruity/rdd_adapter.py
@@ -290,6 +290,24 @@ def func(iterator: Iterable[T]) -> Iterable[U]:
aggregate = RDD.aggregate
aggregate.__doc__ = RDD.aggregate.__doc__

max = RDD.max
max.__doc__ = RDD.max.__doc__

min = RDD.min
min.__doc__ = RDD.min.__doc__

filter = RDD.filter
filter.__doc__ = RDD.filter.__doc__

histogram = RDD.histogram
histogram.__doc__ = RDD.histogram.__doc__

mean = RDD.mean
mean.__doc__ = RDD.mean.__doc__

variance = RDD.variance
variance.__doc__ = RDD.variance.__doc__

class WrappedIterator(Iterable):
"""This is a helper class that wraps the iterator of RecordBatches as returned by
mapInArrow and converts it into an iterator of the underlying values."""
36 changes: 36 additions & 0 deletions tests/test_rdd_adapter.py
@@ -246,3 +246,39 @@ def test_rdd_aggregate(spark_session: "SparkSession"):
# TODO empty
# res = spark_session.sparkContext.parallelize([]).aggregate((0, 0), seqOp, combOp)
# assert res == (0, 0)


def test_rdd_min(spark_session: "SparkSession"):
monkey_patch_spark()
df = spark_session.range(10).repartition(1)
assert df.rdd.map(lambda x: x[0]).min() == 0


def test_rdd_max(spark_session: "SparkSession"):
monkey_patch_spark()
df = spark_session.range(10).repartition(1)
assert df.rdd.map(lambda x: x[0]).max() == 9


def test_rdd_histogram(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.histogram(3) == ([0, 3, 6, 9], [3, 3, 4])


def test_rdd_filter(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.filter(lambda x: x % 2 == 0).collect() == [0, 2, 4, 6, 8]


def test_rdd_mean(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.mean() == 4.5


def test_rdd_variance(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.variance() == 8.25
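
The expected values in the histogram, mean, and variance tests can be reproduced with plain Python, no Spark required. A quick spot check, assuming `RDD.histogram`'s documented convention that buckets are half-open except the last, which also includes its upper bound, and that `variance` is the population (not sample) variance:

```python
data = list(range(10))

# histogram(3) on 0..9 produces evenly spaced boundaries [0, 3, 6, 9];
# count elements per bucket, closing the last bucket on the right.
buckets = [0, 3, 6, 9]
counts = [0] * (len(buckets) - 1)
for x in data:
    for i in range(len(counts)):
        is_last = i == len(counts) - 1
        if buckets[i] <= x < buckets[i + 1] or (is_last and x == buckets[-1]):
            counts[i] += 1
            break
assert counts == [3, 3, 4]

# mean is 45 / 10 == 4.5; population variance is 82.5 / 10 == 8.25.
mean = sum(data) / len(data)
variance = sum((x - mean) ** 2 for x in data) / len(data)
assert (mean, variance) == (4.5, 8.25)
```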
