Adding support for groupBy / groupByKey (#11)
### What changes were proposed in this pull request?
Adding support for `groupBy` and `groupByKey`. This patch adds a new RDD
`mapPartitions` implementation that creates key/value tuples, which is then
used to run `groupBy` on the underlying DataFrame efficiently.

At a high level, the execution flow is as follows:

```python
df.groupBy(bin_col1).applyInPandas(agg_all_values_for_a_key).mapInArrow(merge_k_v_to_tuple)
```

Once `applyInArrow` is available, we can switch the implementation to
that.
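
For illustration, here is a minimal usage sketch of the resulting user-facing API. The `spark` session name and the prior `monkey_patch_spark()` call (as used in the tests below) are assumptions, not part of this patch:

```python
# Hedged usage sketch; assumes a SparkSession named `spark` and that
# monkey_patch_spark() from the test setup has already been applied.
rdd = spark.sparkContext.parallelize(range(10))
res = rdd.groupBy(lambda x: x % 2).collect()
# res contains (key, iterable-of-values) pairs,
# e.g. 0 -> 0, 2, 4, 6, 8 and 1 -> 1, 3, 5, 7, 9

pairs = spark.sparkContext.parallelize([(1, 2), (1, 3), (2, 4)])
grouped = pairs.groupByKey().collect()
# grouped contains (1, <values 2, 3>) and (2, <values 4>)
```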

### How was this patch tested?

Added new UT.
grundprinzip authored May 26, 2024
1 parent 1c5f066 commit b8ff488
Showing 3 changed files with 180 additions and 7 deletions.
13 changes: 10 additions & 3 deletions README.md
@@ -96,8 +96,8 @@ that open a pull request and we will review it.
| getResourceProfile | :x: | |
| getStorageLevel | :x: | |
| glom | :white_check_mark: | |
| groupBy | :x: | |
| groupByKey | :x: | |
| groupBy | :white_check_mark: | |
| groupByKey | :white_check_mark: | |
| groupWith | :x: | |
| histogram | :white_check_mark: | |
| id | :x: | |
@@ -115,7 +115,7 @@ that open a pull request and we will review it.
| mapPartitions | :white_check_mark: | First version, based on mapInArrow. |
| mapPartitionsWithIndex | :x: | |
| mapPartitionsWithSplit | :x: | |
| mapValues | :x: | |
| mapValues | :white_check_mark: | |
| max | :white_check_mark: | |
| mean | :white_check_mark: | |
| meanApprox | :x: | |
@@ -173,3 +173,10 @@ that open a pull request and we will review it.
|-------------|--------------------|---------------------------------|
| parallelize | :white_check_mark: | Does not support numSlices yet. |

## Limitations

* Error handling and checking are fairly limited right now. We try
to emulate the existing behavior, but this is not always possible
because the invariants are not encoded in Python but rather in the
Scala code.
* `numSlices` - we don't emulate this behavior for now.
139 changes: 135 additions & 4 deletions congruity/rdd_adapter.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from collections import defaultdict
from functools import reduce
from typing import (
Any,
@@ -30,7 +31,8 @@
Generic,
)

from pyspark.serializers import CloudPickleSerializer
import pandas
from pyspark.serializers import CloudPickleSerializer, CPickleSerializer, AutoBatchedSerializer
from pyspark.statcounter import StatCounter

T = TypeVar("T")
@@ -112,12 +114,33 @@ class RDDAdapter(Generic[T_co]):
[sqltypes.StructField("__bin_field__", sqltypes.BinaryType(), True, {"serde": "true"})]
)

BIN_TUPLE_SCHEMA = sqltypes.StructType(
[
sqltypes.StructField("__bin_field_k__", sqltypes.BinaryType(), True, {"serde": "true"}),
sqltypes.StructField("__bin_field_v__", sqltypes.BinaryType(), True, {"serde": "true"}),
]
)

PA_SCHEMA = pa.schema([pa.field("__bin_field__", pa.binary(), True, {"serde": "true"})])

PA_TUPLE_SCHEMA = pa.schema(
[
pa.field("__bin_field_k__", pa.binary(), True, {"serde": "true"}),
pa.field("__bin_field_v__", pa.binary(), True, {"serde": "true"}),
]
)
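# Note: the *_TUPLE_SCHEMA variants keep the pickled key and the pickled value in
# separate binary columns so that the key column can be grouped on at the DataFrame
# level; this relies on equal keys serializing to identical bytes.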

def __init__(self, df: "DataFrame", first_field: bool = False):
self._df = df
self._first_field = first_field

# Compatibility attributes
self._jrdd_deserializer = AutoBatchedSerializer(CPickleSerializer())

def _memory_limit(self) -> int:
# Dummy value
return 1024 * 1024 * 512

def collect(self: "RDDAdapter[T]") -> List[T]:
data = self._df.collect()
if self._first_field:
@@ -237,14 +260,14 @@ def take(self: "RDDAdapter[T]", num: int) -> List[T]:
take.__doc__ = RDD.take.__doc__

def map(
self: "RDDAdapter[T]", f: Callable[[T], U], preservePartitioning=None
self: "RDDAdapter[T]", f: Callable[[T], U], preservesPartitioning=None
) -> "RDDAdapter[U]":
def func(iterator: Iterable[T]) -> Iterable[U]:
return map(fail_on_stopiteration(f), iterator)

# This differs from the regular map implementation because we don't have
# access to mapPartitionsWithIndex.
return self.mapPartitions(func, preservePartitioning)
return self.mapPartitions(func, preservesPartitioning)

map.__doc__ = RDD.map.__doc__

@@ -308,6 +331,72 @@ def func(iterator: Iterable[T]) -> Iterable[U]:
variance = RDD.variance
variance.__doc__ = RDD.variance.__doc__

def groupBy(
self: "RDDAdapter[T]",
f: Callable[[T], K],
numPartitions: Optional[int] = None,
partitionFunc: Callable[[K], int] = lambda x: 1,
) -> "RDDAdapter[Tuple[K, Iterable[T]]]":
# First transform the data into a tuple. In contrast to the regular map calls, we're going to
# create the tuple as something that has two columns rather than one serialized value.
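# The identity map below presumably normalizes the input into the adapter's single
# serialized binary column form so the TupleAdapter can deserialize rows uniformly.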
converted = self.map(lambda x: x)
transformer = lambda x: (f(x), x)

def func(iterator: Iterable[T]) -> Iterable[U]:
return map(fail_on_stopiteration(transformer), iterator)

tuple_list = TupleAdapter(converted, func)
return tuple_list.groupByKey(numPartitions, partitionFunc)

groupBy.__doc__ = RDD.groupBy.__doc__

def groupByKey(
self: "RDDAdapter[Tuple[K, V]]",
numPartitions: Optional[int] = None,
partitionFunc: Callable[[K], int] = lambda x: 1,
) -> "RDDAdapter[Tuple[K, Iterable[V]]]":

input_tuples = self
if not isinstance(self, TupleAdapter):

def extractor(x):
# Extracting the k, v using unpacking will automatically raise an exception if the
# number of values does not match.
a, b = x
return (a, b)

def func(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, V]]:
return map(fail_on_stopiteration(extractor), iterator)

input_tuples = TupleAdapter(self, func)

def transformer(df: pandas.DataFrame) -> pandas.DataFrame:
batch = pa.RecordBatch.from_pandas(df, schema=RDDAdapter.PA_TUPLE_SCHEMA)

# Generate the resulting aggregation
result = defaultdict(list)
for r in batch.to_pylist():
# r has two columns k and v
result[loads(r["__bin_field_k__"])].append(loads(r["__bin_field_v__"]))

# Serialize back the result
return pandas.DataFrame(
[dumps(x) for x in result.items()],
columns=[
"__bin_field__",
],
)

df = input_tuples._df.groupBy("__bin_field_k__").applyInPandas(
transformer, schema=RDDAdapter.BIN_SCHEMA
)
return RDDAdapter(df, first_field=True)

groupByKey.__doc__ = RDD.groupByKey.__doc__

mapValues = RDD.mapValues
mapValues.__doc__ = RDD.mapValues.__doc__
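# Reusing the stock RDD.mapValues works because PySpark implements it in terms of
# map(..., preservesPartitioning=True), which is presumably why map() above now uses
# the preservesPartitioning keyword.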

class WrappedIterator(Iterable):
"""This is a helper class that wraps the iterator of RecordBatches as returned by
mapInArrow and converts it into an iterator of the underlying values."""
@@ -326,7 +415,7 @@ def __next__(self):
if self._first_field:
self._current_batch = [loads(x["__bin_field__"]) for x in v.to_pylist()]
else:
self._current_batch = [list(x.values()) for x in v.to_pylist()]
self._current_batch = [Row(**x) for x in v.to_pylist()]
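# Building Row objects (rather than plain value lists) preserves the column names,
# so user lambdas can access fields by attribute, e.g. x.key in the groupBy tests below.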

result = self._current_batch[self._current_idx]
self._current_idx += 1
@@ -345,9 +434,51 @@ def mapPartitions(
mapPartitions.__doc__ = RDD.mapPartitions.__doc__


class TupleAdapter(RDDAdapter):
"""This is a helper class that takes an input RDD and converts it into a key/value
tuple RDD by creating an output DataFrame that has two columns, one for the key and
one for the value. The actual values are still encoded via cloudpickle."""

def __init__(self, input: "RDDAdapter", f):
super().__init__(input._df, input._first_field)
mapper = self._build_mapper(f, input._first_field)
self._df = input._df.mapInArrow(mapper, RDDAdapter.BIN_TUPLE_SCHEMA)
self._first_field = True

def _build_mapper(self, f, needs_conversion):
# Fixed constants for the mapPartitions implementation.
schema = RDDAdapter.PA_TUPLE_SCHEMA
max_rows_per_batch = 1000
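# Accumulated tuples are flushed into an Arrow RecordBatch once more than this many
# rows have been collected, to bound per-batch memory.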

def mapper(iter: Iterable[RecordBatch]):
# The function that is passed to mapPartitions works the same way as this mapper, but
# when next(iter) is called we have to hand over the converted batch instead of the raw
# data.
wrapped = RDDAdapter.WrappedIterator(iter, needs_conversion)
result = []
for kv in f(wrapped):
# kv is a tuple with a key value pair.
assert len(kv) == 2
result.append({"__bin_field_k__": dumps(kv[0]), "__bin_field_v__": dumps(kv[1])})
if len(result) > max_rows_per_batch:
yield RecordBatch.from_pylist(result, schema=schema)
result = []

if len(result) > 0:
yield RecordBatch.from_pylist(result, schema=schema)

return mapper


class Pipeline(RDDAdapter):
"""The pipeline is an extension of the RDDAdapter that allows pipelining multiple
mapPartitions operations. This is useful to avoid the overhead of creating multiple
plan execution nodes. Instead, we can create a single plan node that executes all the
operations in a single pass."""

def __init__(self, input: "RDDAdapter", f):
# TODO check if this is ok
super().__init__(input._df, input._first_field)

if isinstance(input, Pipeline):
source = input._prev_source
35 changes: 35 additions & 0 deletions tests/test_rdd_adapter.py
@@ -282,3 +282,38 @@ def test_rdd_variance(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
assert rdd.variance() == 8.25


def test_rdd_groupBy(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize(range(10))
res = rdd.groupBy(lambda x: x % 2).collect()
assert list(res[0][1]) == [0, 2, 4, 6, 8]
assert list(res[1][1]) == [1, 3, 5, 7, 9]

# Also cover the case where the RDD is created from a DataFrame first.
df = spark_session.createDataFrame([(1, 2), (1, 3), (2, 4)], ["key", "value"])
res = df.rdd.groupBy(lambda x: x.key).collect()
assert list(res[0][1]) == [Row(key=1, value=2), Row(key=1, value=3)]
assert list(res[1][1]) == [Row(key=2, value=4)]


def test_rdd_map_with_rows(spark_session: "SparkSession"):
df = spark_session.createDataFrame([(1, 2), (1, 3), (2, 4)], ["key", "value"])
res = df.rdd.map(lambda x: x).collect()
assert res == [Row(key=1, value=2), Row(key=1, value=3), Row(key=2, value=4)]


def test_rdd_groupByKey(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize([(1, 2), (1, 3), (2, 4)])
res = rdd.groupByKey().collect()
assert list(res[0][1]) == [2, 3]
assert list(res[1][1]) == [4]


def test_rdd_mapValues(spark_session: "SparkSession"):
monkey_patch_spark()
rdd = spark_session.sparkContext.parallelize([(1, 2), (1, 3), (2, 4)])
res = rdd.mapValues(lambda x: x * 2).collect()
assert res == [(1, 4), (1, 6), (2, 8)]
