Inject a repartition layer for inner merges #335

Open
wants to merge 20 commits into base: main
7 changes: 6 additions & 1 deletion dask_expr/_expr.py
@@ -417,7 +417,12 @@ def combine_similar(

if changed_dependency:
expr = type(expr)(*new_operands)
changed = True
if isinstance(expr, Projection):
# We might have introduced stacked Projections (e.g. from a merge),
# so collapse them again here.
expr_simplify_down = expr._simplify_down()
if expr_simplify_down is not None:
expr = expr_simplify_down
if update_root:
root = expr
continue
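Note: the Projection handling above exists because combine_similar can leave stacked Projections behind when it rebuilds a merge's operands. A minimal sketch of the collapse that _simplify_down() performs, written against the public dask_expr API (illustrative only, not code from this PR):

import pandas as pd
from dask_expr import from_pandas

pdf = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
df = from_pandas(pdf, npartitions=2)

stacked = df[["a", "b"]]["a"]   # a Projection stacked on another Projection
collapsed = stacked.simplify()  # rewritten into the single projection df["a"]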
224 changes: 165 additions & 59 deletions dask_expr/_merge.py
@@ -2,6 +2,7 @@

from dask.core import flatten
from dask.dataframe.dispatch import make_meta, meta_nonempty
from dask.dataframe.shuffle import partitioning_index
from dask.utils import M, apply, get_default_shuffle_method

from dask_expr._expr import (
@@ -13,12 +14,21 @@
Projection,
)
from dask_expr._repartition import Repartition
from dask_expr._shuffle import AssignPartitioningIndex, Shuffle, _contains_index_name
from dask_expr._shuffle import (
AssignPartitioningIndex,
Shuffle,
_contains_index_name,
_select_columns_or_index,
)
from dask_expr._util import _convert_to_list

_HASH_COLUMN_NAME = "__hash_partition"


def _partition_reducer(x):
return max(x // 2, 1)
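# Annotation: halving at each repartition shrinks partition counts
# geometrically across chained inner merges (8 -> 4 -> 2 -> 1), while
# max(..., 1) guarantees at least one partition remains.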


class Merge(Expr):
"""Merge / join two dataframes

@@ -55,6 +65,10 @@ class Merge(Expr):
"shuffle_backend": None,
}

# Expression types used by the combine_similar machinery below
_skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
_remove_ops = (Projection,)

def __str__(self):
return f"Merge({self._name[-7:]})"

@@ -156,12 +170,10 @@ def _lower(self):
or shuffle_backend is None
and get_default_shuffle_method() == "p2p"
):
left = AssignPartitioningIndex(
left, shuffle_left_on, _HASH_COLUMN_NAME, self.npartitions
)
right = AssignPartitioningIndex(
right, shuffle_right_on, _HASH_COLUMN_NAME, self.npartitions
)
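# An inner join usually produces less data than its inputs, so halve
# the partition count on both sides up front; the join's output
# partition count is derived from its inputs.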
if self.how == "inner":
left = Repartition(left, _partition_reducer)
right = Repartition(right, _partition_reducer)

return HashJoinP2P(
left,
right,
@@ -171,6 +183,8 @@
indicator=self.indicator,
left_index=left_index,
right_index=right_index,
shuffle_left_on=shuffle_left_on,
shuffle_right_on=shuffle_right_on,
)

if shuffle_left_on:
@@ -252,55 +266,123 @@ def _simplify_up(self, parent):
return type(parent)(result)
return result[parent_columns]

def _validate_same_operations(self, common, op, remove_ops, skip_ops):
def _validate_same_operations(self, common, op, remove="both"):
# Traverse left and right to check whether the same operation appears
# more than once. We have to account for potential Projections on both sides.
name = common._name
if name == op._name:
return True
op_left, _ = self._remove_operations(op.left, remove_ops, skip_ops)
op_right, _ = self._remove_operations(op.right, remove_ops, skip_ops)
return type(op)(op_left, op_right, *op.operands[2:])._name == name
return True, op.left.columns, op.right.columns

columns_left, columns_right = None, None
op_left, op_right = op.left, op.right
if remove in ("both", "left"):
op_left, columns_left = self._remove_operations(
op.left, self._remove_ops, self._skip_ops
)
if remove in ("both", "right"):
op_right, columns_right = self._remove_operations(
op.right, self._remove_ops, self._skip_ops
)

return (
type(op)(op_left, op_right, *op.operands[2:])._name == name,
columns_left,
columns_right,
)

@staticmethod
def _flatten_columns(expr, columns, side):
if len(columns) == 0:
return getattr(expr, side).columns
else:
return list(set(flatten(columns)))

def _combine_similar(self, root: Expr):
# Push projections back up to avoid performing the same merge multiple times
skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
remove_ops = (Projection,)

def _flatten_columns(columns, side):
if len(columns) == 0:
return getattr(self, side).columns
else:
return list(set(flatten(columns)))

left, columns_left = self._remove_operations(self.left, remove_ops, skip_ops)
columns_left = _flatten_columns(columns_left, "left")
right, columns_right = self._remove_operations(self.right, remove_ops, skip_ops)
columns_right = _flatten_columns(columns_right, "right")
left, columns_left = self._remove_operations(
self.left, self._remove_ops, self._skip_ops
)
columns_left = self._flatten_columns(self, columns_left, "left")
right, columns_right = self._remove_operations(
self.right, self._remove_ops, self._skip_ops
)
columns_right = self._flatten_columns(self, columns_right, "right")

if left._name == self.left._name and right._name == self.right._name:
# There aren't any ops we can remove, so bail
return

common = type(self)(left, right, *self.operands[2:])
# We cannot remove Projections on both sides at once, because only one
# side might need the push-back-up step. So first try removing the
# Projections on each side separately before removing both at once.

common_left = type(self)(self.left, right, *self.operands[2:])
common_right = type(self)(left, self.right, *self.operands[2:])
common_both = type(self)(left, right, *self.operands[2:])

push_up_op = False
for op in self._find_similar_operations(root, ignore=self._parameters):
if self._validate_same_operations(common, op, remove_ops, skip_ops):
push_up_op = True
columns, left_sub, right_sub = None, None, None

for op in self._find_similar_operations(root, ignore=["left", "right"]):
if op._name in (common_right._name, common_left._name, common_both._name):
continue

validation = self._validate_same_operations(common_right, op, "left")
if validation[0]:
left_sub = self._flatten_columns(op, validation[1], side="left")
columns = self.right.columns.copy()
columns += [col for col in self.left.columns if col not in columns]
break

validation = self._validate_same_operations(common_left, op, "right")
if validation[0]:
right_sub = self._flatten_columns(op, validation[2], side="right")
columns = self.left.columns.copy()
columns += [col for col in self.right.columns if col not in columns]
break

validation = self._validate_same_operations(common_both, op)
if validation[0]:
left_sub = self._flatten_columns(op, validation[1], side="left")
right_sub = self._flatten_columns(op, validation[2], side="right")
columns = columns_left.copy()
columns += [col for col in columns_right if col not in columns_left]
break

if push_up_op:
columns = columns_left.copy()
columns += [col for col in columns_right if col not in columns_left]
if _HASH_COLUMN_NAME in columns:
# Don't select _HASH_COLUMN_NAME; the p2p merge removes it
columns.remove(_HASH_COLUMN_NAME)
if sorted(common.columns) != sorted(columns):
common = common[columns]
c = common._simplify_down()
common = c if c is not None else common
return common
if columns is not None:
expr = self

if left_sub is not None:
left_sub.extend([col for col in columns_left if col not in left_sub])
left = self._replace_projections(self.left, left_sub)
expr = expr.substitute(self.left, left)

if right_sub is not None:
right_sub.extend([col for col in columns_right if col not in right_sub])
right = self._replace_projections(self.right, right_sub)
expr = expr.substitute(self.right, right)

if sorted(expr.columns) != sorted(columns):
expr = expr[columns]
return expr

def _replace_projections(self, frame, new_columns):
# This branch might contain Projections that differ from our new
# columns; replace them with Projections onto the new column set.

operations = []
while isinstance(frame, self._remove_ops + self._skip_ops):
if isinstance(frame, self._remove_ops):
# Drop the Projection when new_columns == frame.frame.columns,
# since it would be a no-op
if sorted(new_columns) != sorted(frame.frame.columns):
operations.append((type(frame), [new_columns]))
else:
operations.append((type(frame), frame.operands[1:]))
frame = frame.frame

for op_type, operands in reversed(operations):
frame = op_type(frame, *operands)
return frame


class HashJoinP2P(Merge, PartitionsFiltered):
@@ -315,6 +397,8 @@ class HashJoinP2P(Merge, PartitionsFiltered):
"suffixes",
"indicator",
"_partitions",
"shuffle_left_on",
"shuffle_right_on",
]
_defaults = {
"how": "inner",
@@ -325,54 +409,47 @@
"suffixes": ("_x", "_y"),
"indicator": False,
"_partitions": None,
"shuffle_left_on": None,
"shuffle_right_on": None,
}

def _lower(self):
return None

@functools.cached_property
def _meta(self):
left = self.left._meta.drop(columns=_HASH_COLUMN_NAME)
right = self.right._meta.drop(columns=_HASH_COLUMN_NAME)
return left.merge(
right,
left_on=self.left_on,
right_on=self.right_on,
indicator=self.indicator,
suffixes=self.suffixes,
left_index=self.left_index,
right_index=self.right_index,
)

def _layer(self) -> dict:
from distributed.shuffle._core import ShuffleId, barrier_key
from distributed.shuffle._merge import merge_transfer, merge_unpack
from distributed.shuffle._merge import merge_unpack
from distributed.shuffle._shuffle import shuffle_barrier

dsk = {}
name_left = "hash-join-transfer-" + self.left._name
name_right = "hash-join-transfer-" + self.right._name
transfer_keys_left = list()
transfer_keys_right = list()
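# The transfer tasks assign the hash partitioning column themselves
# (see create_assign_index_merge_transfer below) instead of relying on
# a separate AssignPartitioningIndex layer in the graph.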
func = create_assign_index_merge_transfer()
for i in range(self.left.npartitions):
transfer_keys_left.append((name_left, i))
dsk[(name_left, i)] = (
merge_transfer,
func,
(self.left._name, i),
self.shuffle_left_on,
_HASH_COLUMN_NAME,
self.npartitions,
self.left._name,
i,
self.npartitions,
self.left._meta,
self._partitions,
)
for i in range(self.right.npartitions):
transfer_keys_right.append((name_right, i))
dsk[(name_right, i)] = (
merge_transfer,
func,
(self.right._name, i),
self.shuffle_right_on,
_HASH_COLUMN_NAME,
self.npartitions,
self.right._name,
i,
self.npartitions,
self.right._meta,
self._partitions,
)
@@ -408,6 +485,35 @@ def _simplify_up(self, parent):
return


def create_assign_index_merge_transfer():
import pandas as pd
from distributed.shuffle._core import ShuffleId
from distributed.shuffle._merge import merge_transfer

def assign_index_merge_transfer(
df,
index,
name,
npartitions,
id: ShuffleId,
input_partition: int,
meta: pd.DataFrame,
parts_out: set[int],
):
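# Resolve the join key(s), then hash them with partitioning_index to
# map every row to a target partition id in [0, npartitions).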
index = _select_columns_or_index(df, index)
if isinstance(index, (str, list, tuple)):
# Assume column selection from df
index = [index] if isinstance(index, str) else list(index)
index = partitioning_index(df[index], npartitions)
else:
index = partitioning_index(index, npartitions)
df[name] = index
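# Mirror the hash column on the meta so the schema the shuffle sees
# matches the real partitions.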
meta[name] = 0
return merge_transfer(df, id, input_partition, npartitions, meta, parts_out)

return assign_index_merge_transfer


class BlockwiseMerge(Merge, Blockwise):
"""Merge two dataframes with aligned partitions

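To make the _combine_similar changes concrete: they target graphs in which the same merge is consumed through different projections, and which should instead compute one merge with the projections pushed back up. A hand-written sketch against the public dask_expr API (illustrative only, not part of this diff):

import pandas as pd
from dask_expr import from_pandas

left = from_pandas(pd.DataFrame({"k": [1, 2], "a": [3, 4]}), npartitions=2)
right = from_pandas(pd.DataFrame({"k": [1, 2], "b": [5, 6]}), npartitions=2)

joined = left.merge(right, on="k")
result = joined["a"] + joined["b"]   # two Projections over the same Merge
optimized = result.optimize()        # the merge is computed once, not twice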
2 changes: 1 addition & 1 deletion dask_expr/_repartition.py
@@ -135,7 +135,7 @@ def _divisions(self):
def _partitions_boundaries(self):
npartitions = self.new_partitions
npartitions_input = self.frame.npartitions
assert npartitions_input > npartitions
assert npartitions_input >= npartitions
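# >= rather than >: _partition_reducer may leave the partition count
# unchanged, e.g. max(1 // 2, 1) == 1 for a single-partition input.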

npartitions_ratio = npartitions_input / npartitions
new_partitions_boundaries = [
3 changes: 3 additions & 0 deletions dask_expr/io/io.py
@@ -137,6 +137,9 @@ def _meta(self):
def npartitions(self):
return len(self._fusion_buckets)

def _broadcast_dep(self, dep: Expr):
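# A single-partition dependency can be broadcast to every fusion
# bucket instead of being aligned bucket-by-bucket.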
return dep.npartitions == 1

def _divisions(self):
divisions = self.operand("expr")._divisions()
new_divisions = [divisions[b[0]] for b in self._fusion_buckets]