From 38cef666fcb9f524ab12629f183c9a4990bad988 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Thu, 12 Oct 2023 17:38:25 +0100
Subject: [PATCH 01/18] Add repartition for benchmarks

---
 dask_expr/_merge.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 88c7bab4..0ef4180a 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -162,6 +162,9 @@ def _lower(self):
             right = AssignPartitioningIndex(
                 right, shuffle_right_on, _HASH_COLUMN_NAME, self.npartitions
             )
+            left = Repartition(left, lambda x: x // 2)
+            right = Repartition(right, lambda x: x // 2)
+
             return HashJoinP2P(
                 left,
                 right,
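The second operand handed to Repartition in patch 01 is a callable mapping the current partition count to the target count. A standalone sketch of that reducer in plain Python (illustrative only, not part of the series); note that x // 2 collapses to 0 for a single-partition frame, a degenerate case a later patch guards with max(..., 1):

    reducer = lambda x: x // 2  # halve the number of partitions before the shuffle

    for npartitions in (1, 2, 8, 100):
        # prints: 1 -> 0 (degenerate), 2 -> 1, 8 -> 4, 100 -> 50
        print(npartitions, "->", reducer(npartitions))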
""" - _parameters = ["frame", "partitioning_index", "index_name", "npartitions_out"] + _parameters = [ + "frame", + "partitioning_index", + "index_name", + "npartitions_out", + "other", + ] + _defaults = {"other": None} + + def dependencies(self): + return [self.frame] + + @functools.cached_property + def _args(self) -> list: + ops = self.operands.copy() + npart = ops[-2] + if isinstance(npart, Callable): + ops[-2] = npart(max(self.frame.npartitions, self.other.npartitions)) + ops.pop(-1) + return ops @staticmethod def operation(df, index, name: str, npartitions: int): From a77685b8af58b76c8e20600ec925ce51aaf9dccc Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 13 Oct 2023 00:00:34 +0100 Subject: [PATCH 03/18] Fix --- dask_expr/_merge.py | 79 +++++++++++++++++++++++++++---------------- dask_expr/_shuffle.py | 22 +----------- 2 files changed, 50 insertions(+), 51 deletions(-) diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 68517ff7..461bb3c9 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -1,7 +1,9 @@ import functools +import pandas as pd from dask.core import flatten from dask.dataframe.dispatch import make_meta, meta_nonempty +from dask.dataframe.shuffle import partitioning_index from dask.utils import M, apply, get_default_shuffle_method from dask_expr._expr import ( @@ -13,7 +15,7 @@ Projection, ) from dask_expr._repartition import Repartition -from dask_expr._shuffle import AssignPartitioningIndex, Shuffle, _contains_index_name +from dask_expr._shuffle import Shuffle, _contains_index_name, _select_columns_or_index from dask_expr._util import _convert_to_list _HASH_COLUMN_NAME = "__hash_partition" @@ -156,12 +158,8 @@ def _lower(self): or shuffle_backend is None and get_default_shuffle_method() == "p2p" ): - left = AssignPartitioningIndex( - left, shuffle_left_on, _HASH_COLUMN_NAME, lambda x: x, right - ) - right = AssignPartitioningIndex( - right, shuffle_right_on, _HASH_COLUMN_NAME, lambda x: x, left - ) + # left = Repartition(left, lambda x: x // 2) + # right = Repartition(right, lambda x: x // 2) return HashJoinP2P( left, @@ -172,6 +170,8 @@ def _lower(self): indicator=self.indicator, left_index=left_index, right_index=right_index, + shuffle_left_on=shuffle_left_on, + shuffle_right_on=shuffle_right_on, ) if shuffle_left_on: @@ -265,7 +265,7 @@ def _validate_same_operations(self, common, op, remove_ops, skip_ops): def _combine_similar(self, root: Expr): # Push projections back up to avoid performing the same merge multiple times - skip_ops = (Filter, AssignPartitioningIndex, Shuffle) + skip_ops = (Filter, Shuffle) remove_ops = (Projection,) def _flatten_columns(columns, side): @@ -294,9 +294,6 @@ def _flatten_columns(columns, side): if push_up_op: columns = columns_left.copy() columns += [col for col in columns_right if col not in columns_left] - if _HASH_COLUMN_NAME in columns: - # Don't filter for hash_column_name which is removed in p2p merge - columns.remove(_HASH_COLUMN_NAME) if sorted(common.columns) != sorted(columns): common = common[columns] c = common._simplify_down() @@ -316,6 +313,8 @@ class HashJoinP2P(Merge, PartitionsFiltered): "suffixes", "indicator", "_partitions", + "shuffle_left_on", + "shuffle_right_on", ] _defaults = { "how": "inner", @@ -326,28 +325,16 @@ class HashJoinP2P(Merge, PartitionsFiltered): "suffixes": ("_x", "_y"), "indicator": False, "_partitions": None, + "shuffle_left_on": None, + "shuffle_right_on": None, } def _lower(self): return None - @functools.cached_property - 
-    def _meta(self):
-        left = self.left._meta.drop(columns=_HASH_COLUMN_NAME)
-        right = self.right._meta.drop(columns=_HASH_COLUMN_NAME)
-        return left.merge(
-            right,
-            left_on=self.left_on,
-            right_on=self.right_on,
-            indicator=self.indicator,
-            suffixes=self.suffixes,
-            left_index=self.left_index,
-            right_index=self.right_index,
-        )
-
     def _layer(self) -> dict:
         from distributed.shuffle._core import ShuffleId, barrier_key
-        from distributed.shuffle._merge import merge_transfer, merge_unpack
+        from distributed.shuffle._merge import merge_unpack
         from distributed.shuffle._shuffle import shuffle_barrier
 
         dsk = {}
         name_left = "hash-join-transfer-" + self.left._name
         name_right = "hash-join-transfer-" + self.right._name
         transfer_keys_left = list()
         transfer_keys_right = list()
+        func = create_assign_index_merge_transfer()
         for i in range(self.left.npartitions):
             transfer_keys_left.append((name_left, i))
             dsk[(name_left, i)] = (
-                merge_transfer,
+                func,
                 (self.left._name, i),
+                self.shuffle_left_on,
+                _HASH_COLUMN_NAME,
+                self.npartitions,
                 self.left._name,
                 i,
-                self.npartitions,
                 self.left._meta,
                 self._partitions,
             )
         for i in range(self.right.npartitions):
             transfer_keys_right.append((name_right, i))
             dsk[(name_right, i)] = (
-                merge_transfer,
+                func,
                 (self.right._name, i),
+                self.shuffle_left_on,
+                _HASH_COLUMN_NAME,
+                self.npartitions,
                 self.right._name,
                 i,
-                self.npartitions,
                 self.right._meta,
                 self._partitions,
             )
@@ -409,6 +401,33 @@ def _simplify_up(self, parent):
         return
 
 
+def create_assign_index_merge_transfer():
+    from distributed.shuffle._core import ShuffleId
+    from distributed.shuffle._merge import merge_transfer
+
+    def assign_index_merge_transfer(
+        df,
+        index,
+        name,
+        npartitions,
+        id: ShuffleId,
+        input_partition: int,
+        meta: pd.DataFrame,
+        parts_out: set[int],
+    ):
+        index = _select_columns_or_index(df, index)
+        if isinstance(index, (str, list, tuple)):
+            # Assume column selection from df
+            index = [index] if isinstance(index, str) else list(index)
+            index = partitioning_index(df[index], npartitions)
+        else:
+            index = partitioning_index(index, npartitions)
+        df[name] = index
+        return merge_transfer(df, id, input_partition, npartitions, meta, parts_out)
+
+    return assign_index_merge_transfer
+
+
 class BlockwiseMerge(Merge, Blockwise):
     """Merge two dataframes with aligned partitions
 
diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py
index bdfc9790..ccdaa3d9 100644
--- a/dask_expr/_shuffle.py
+++ b/dask_expr/_shuffle.py
@@ -2,7 +2,6 @@
 import math
 import operator
 import uuid
-from typing import Callable
 
 import numpy as np
 import pandas as pd
@@ -606,26 +605,7 @@ class AssignPartitioningIndex(Blockwise):
         Number of partitions after repartitioning is finished.
     """
 
-    _parameters = [
-        "frame",
-        "partitioning_index",
-        "index_name",
-        "npartitions_out",
-        "other",
-    ]
-    _defaults = {"other": None}
-
-    def dependencies(self):
-        return [self.frame]
-
-    @functools.cached_property
-    def _args(self) -> list:
-        ops = self.operands.copy()
-        npart = ops[-2]
-        if isinstance(npart, Callable):
-            ops[-2] = npart(max(self.frame.npartitions, self.other.npartitions))
-        ops.pop(-1)
-        return ops
+    _parameters = ["frame", "partitioning_index", "index_name", "npartitions_out"]
 
     @staticmethod
     def operation(df, index, name: str, npartitions: int):
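Patch 03 moves the hash-column assignment out of the expression graph and into the transfer task, built on partitioning_index from dask.dataframe.shuffle, the helper the patch imports. A minimal standalone sketch of what that assignment computes, with column name and data invented for illustration:

    import pandas as pd
    from dask.dataframe.shuffle import partitioning_index

    df = pd.DataFrame({"a": [1, 2, 3, 4], "b": 1})
    npartitions = 2
    # as in assign_index_merge_transfer: hash the join columns and reduce
    # modulo npartitions; each row now carries its output partition id
    df["__hash_partition"] = partitioning_index(df[["a"]], npartitions)
    print(df)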
""" - _parameters = [ - "frame", - "partitioning_index", - "index_name", - "npartitions_out", - "other", - ] - _defaults = {"other": None} - - def dependencies(self): - return [self.frame] - - @functools.cached_property - def _args(self) -> list: - ops = self.operands.copy() - npart = ops[-2] - if isinstance(npart, Callable): - ops[-2] = npart(max(self.frame.npartitions, self.other.npartitions)) - ops.pop(-1) - return ops + _parameters = ["frame", "partitioning_index", "index_name", "npartitions_out"] @staticmethod def operation(df, index, name: str, npartitions: int): From 7d4d37212a6ac64bc548a4771d37aae92fb61235 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 13 Oct 2023 00:35:30 +0100 Subject: [PATCH 04/18] Fix --- dask_expr/_merge.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 461bb3c9..f250190a 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -423,6 +423,7 @@ def assign_index_merge_transfer( else: index = partitioning_index(index, npartitions) df[name] = index + meta[name] = 0 return merge_transfer(df, id, input_partition, npartitions, meta, parts_out) return assign_index_merge_transfer From 4f5d78fbcd4f705bfe0d4b4cc33ed33ae3696e51 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 13 Oct 2023 10:25:21 +0100 Subject: [PATCH 05/18] Fix --- dask_expr/_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index f250190a..904e9d6d 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -361,7 +361,7 @@ def _layer(self) -> dict: dsk[(name_right, i)] = ( func, (self.right._name, i), - self.shuffle_left_on, + self.shuffle_right_on, _HASH_COLUMN_NAME, self.npartitions, self.right._name, From 63fa4b6a77a2cb16799aaad9de3ad9f55f014635 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 13 Oct 2023 10:29:07 +0100 Subject: [PATCH 06/18] Improve combine_similar for merge --- dask_expr/_expr.py | 7 ++- dask_expr/_merge.py | 133 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 114 insertions(+), 26 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 59dbbc1a..5236e661 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -417,7 +417,12 @@ def combine_similar( if changed_dependency: expr = type(expr)(*new_operands) - changed = True + if isinstance(expr, Projection): + # We might introduce stacked Projections (merge for example). 
+                # So get rid of them here again
+                expr_simplify_down = expr._simplify_down()
+                if expr_simplify_down is not None:
+                    expr = expr_simplify_down
             if update_root:
                 root = expr
             continue
diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 904e9d6d..82ca843b 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -15,7 +15,12 @@
     Projection,
 )
 from dask_expr._repartition import Repartition
-from dask_expr._shuffle import Shuffle, _contains_index_name, _select_columns_or_index
+from dask_expr._shuffle import (
+    AssignPartitioningIndex,
+    Shuffle,
+    _contains_index_name,
+    _select_columns_or_index,
+)
 from dask_expr._util import _convert_to_list
 
 _HASH_COLUMN_NAME = "__hash_partition"
@@ -57,6 +62,10 @@ class Merge(Expr):
         "shuffle_backend": None,
     }
 
+    # combine similar variables
+    _skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
+    _remove_ops = (Projection,)
+
     def __str__(self):
         return f"Merge({self._name[-7:]})"
 
@@ -253,52 +262,126 @@ def _simplify_up(self, parent):
                 return type(parent)(result)
             return result[parent_columns]
 
-    def _validate_same_operations(self, common, op, remove_ops, skip_ops):
+    def _validate_same_operations(self, common, op, remove="both"):
         # Travers left and right to check if we can find the same operation
         # more than once. We have to account for potential projections on both sides
         name = common._name
         if name == op._name:
-            return True
-        op_left, _ = self._remove_operations(op.left, remove_ops, skip_ops)
-        op_right, _ = self._remove_operations(op.right, remove_ops, skip_ops)
-        return type(op)(op_left, op_right, *op.operands[2:])._name == name
+            return True, op.left.columns, op.right.columns
+
+        columns_left, columns_right = None, None
+        op_left, op_right = op.left, op.right
+        if remove in ("both", "left"):
+            op_left, columns_left = self._remove_operations(
+                op.left, self._remove_ops, self._skip_ops
+            )
+        if remove in ("both", "right"):
+            op_right, columns_right = self._remove_operations(
+                op.right, self._remove_ops, self._skip_ops
+            )
+
+        return (
+            type(op)(op_left, op_right, *op.operands[2:])._name == name,
+            columns_left,
+            columns_right,
+        )
 
     def _combine_similar(self, root: Expr):
         # Push projections back up to avoid performing the same merge multiple times
-        skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
-        remove_ops = (Projection,)
 
-        def _flatten_columns(columns, side):
-            if len(columns) == 0:
+        def _flatten_columns(columns, side=None):
+            if len(columns) == 0 and side is not None:
                 return getattr(self, side).columns
             else:
                 return list(set(flatten(columns)))
 
-        left, columns_left = self._remove_operations(self.left, remove_ops, skip_ops)
+        left, columns_left = self._remove_operations(
+            self.left, self._remove_ops, self._skip_ops
+        )
         columns_left = _flatten_columns(columns_left, "left")
-        right, columns_right = self._remove_operations(self.right, remove_ops, skip_ops)
+        right, columns_right = self._remove_operations(
+            self.right, self._remove_ops, self._skip_ops
+        )
         columns_right = _flatten_columns(columns_right, "right")
 
         if left._name == self.left._name and right._name == self.right._name:
             # There aren't any ops we can remove, so bail
             return
 
-        common = type(self)(left, right, *self.operands[2:])
+        # We can not remove Projections on both sides at once, because only
+        # one side might need the push back up step. So try if removing Projections
+        # on either side works before removing them on both sides at once.
-        push_up_op = False
-        for op in self._find_similar_operations(root, ignore=self._parameters):
-            if self._validate_same_operations(common, op, remove_ops, skip_ops):
-                push_up_op = True
+        common_left = type(self)(self.left, right, *self.operands[2:])
+        common_right = type(self)(left, self.right, *self.operands[2:])
+        common_both = type(self)(left, right, *self.operands[2:])
+
+        columns, left_sub, right_sub = None, None, None
+
+        for op in self._find_similar_operations(root, ignore=["left", "right"]):
+            if op._name in (common_right._name, common_left._name, common_both._name):
+                continue
+
+            validation = self._validate_same_operations(common_right, op, "left")
+            if validation[0]:
+                left_sub = _flatten_columns(validation[1])
+                columns = self.right.columns.copy()
+                columns += [col for col in self.left.columns if col not in columns]
                 break
 
+            validation = self._validate_same_operations(common_left, op, "right")
+            if validation[0]:
+                right_sub = _flatten_columns(validation[2])
+                columns = self.left.columns.copy()
+                columns += [col for col in self.right.columns if col not in columns]
+                break
-        if push_up_op:
-            columns = columns_left.copy()
-            columns += [col for col in columns_right if col not in columns_left]
-            if sorted(common.columns) != sorted(columns):
-                common = common[columns]
-                c = common._simplify_down()
-                common = c if c is not None else common
-            return common
+
+            validation = self._validate_same_operations(common_both, op)
+            if validation[0]:
+                left_sub = _flatten_columns(validation[1])
+                right_sub = _flatten_columns(validation[2])
+                columns = columns_left.copy()
+                columns += [col for col in columns_right if col not in columns_left]
+                break
+
+        if columns is not None:
+            if _HASH_COLUMN_NAME in columns:
+                # Don't filter for hash_column_name which is removed in p2p merge
+                columns.remove(_HASH_COLUMN_NAME)
+
+            expr = self
+
+            if left_sub is not None:
+                left_sub.extend([col for col in columns_left if col not in left_sub])
+                left = self._replace_projections(self.left, left_sub)
+                expr = expr.substitute(self.left, left)
+
+            if right_sub is not None:
+                right_sub.extend([col for col in columns_right if col not in right_sub])
+                right = self._replace_projections(self.right, right_sub)
+                expr = expr.substitute(self.right, right)
+
+            if sorted(expr.columns) != sorted(columns):
+                expr = expr[columns]
+            return expr
+
+    def _replace_projections(self, frame, new_columns):
+        # This branch might have a number of Projections that differ from our
+        # new columns. We replace those projections appropriately
+
+        operations = []
+        while isinstance(frame, self._remove_ops + self._skip_ops):
+            if isinstance(frame, self._remove_ops):
+                # Ignore Projection if new_columns = frame.frame.columns
+                if sorted(new_columns) != sorted(frame.frame.columns):
+                    operations.append((type(frame), [new_columns]))
+            else:
+                operations.append((type(frame), frame.operands[1:]))
+            frame = frame.frame
+
+        for op_type, operands in reversed(operations):
+            frame = op_type(frame, *operands)
+        return frame
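The intent behind the reworked _combine_similar is easiest to see at the dataframe level: two differently-projected uses of the same merge can share one merge on the union of their columns, with the projections pushed back above it. A pandas-level illustration of that equivalence (the optimizer works on expressions, not dataframes, so this is only an analogy):

    import pandas as pd

    left = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    right = pd.DataFrame({"a": [1, 2], "x": [5, 6]})

    first = left.merge(right)[["a", "x"]]
    second = left.merge(right)[["a", "b"]]

    # merge once on the union of the requested columns, project afterwards
    common = left.merge(right)[["a", "x", "b"]]
    pd.testing.assert_frame_equal(first, common[["a", "x"]])
    pd.testing.assert_frame_equal(second, common[["a", "b"]])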
From 9938db6a781ed11490282aafcb596db48085f822 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 10:46:36 +0100
Subject: [PATCH 07/18] Add repartition layer

---
 dask_expr/_merge.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 82ca843b..f6439954 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -167,8 +167,8 @@ def _lower(self):
             or shuffle_backend is None
             and get_default_shuffle_method() == "p2p"
         ):
-            # left = Repartition(left, lambda x: x // 2)
-            # right = Repartition(right, lambda x: x // 2)
+            left = Repartition(left, lambda x: x // 3)
+            right = Repartition(right, lambda x: x // 3)
 
             return HashJoinP2P(
                 left,
                 right,

From 8b16b4f500fd9550f26c70c0d7487769b16780d0 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 11:44:14 +0100
Subject: [PATCH 08/18] Fix repartition layer

---
 dask_expr/_merge.py       | 4 ++--
 dask_expr/_repartition.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index f6439954..5f5eebf1 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -167,8 +167,8 @@ def _lower(self):
             or shuffle_backend is None
             and get_default_shuffle_method() == "p2p"
         ):
-            left = Repartition(left, lambda x: x // 3)
-            right = Repartition(right, lambda x: x // 3)
+            left = Repartition(left, lambda x: max(x // 3, 1))
+            right = Repartition(right, lambda x: max(x // 3, 1))
 
             return HashJoinP2P(
                 left,
                 right,
diff --git a/dask_expr/_repartition.py b/dask_expr/_repartition.py
index 48115349..08cbc76b 100644
--- a/dask_expr/_repartition.py
+++ b/dask_expr/_repartition.py
@@ -135,7 +135,7 @@ def _divisions(self):
     def _partitions_boundaries(self):
         npartitions = self.new_partitions
         npartitions_input = self.frame.npartitions
-        assert npartitions_input > npartitions
+        assert npartitions_input >= npartitions
 
         npartitions_ratio = npartitions_input / npartitions
         new_partitions_boundaries = [
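The relaxed assertion in _partitions_boundaries admits npartitions_input == npartitions, which the max(x // 3, 1) reducer produces for small inputs. A sketch of the ratio-based grouping the method performs; the exact rounding used by dask-expr is not shown in this hunk, so the boundary formula here is an assumption:

    def partition_boundaries(npartitions_input: int, npartitions: int) -> list[int]:
        assert npartitions_input >= npartitions  # the relaxed check from this patch
        npartitions_ratio = npartitions_input / npartitions
        return [round(npartitions_ratio * i) for i in range(npartitions + 1)]

    print(partition_boundaries(10, 5))  # [0, 2, 4, 6, 8, 10]: pairs of input partitions
    print(partition_boundaries(4, 4))   # [0, 1, 2, 3, 4]: the new equality case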
From 63d175bbc65a9780132ac65866a657b90b58027c Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 14:16:29 +0100
Subject: [PATCH 09/18] Fix fusion

---
 dask_expr/_merge.py | 33 +++++++++++++++++++--------------
 dask_expr/io/io.py  |  3 +++
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 5f5eebf1..a0f24f44 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -26,6 +26,10 @@
 _HASH_COLUMN_NAME = "__hash_partition"
 
 
+def _partition_reducer(x):
+    return max(x // 3, 1)
+
+
 class Merge(Expr):
     """Merge / join two dataframes
 
@@ -167,8 +171,8 @@ def _lower(self):
             or shuffle_backend is None
             and get_default_shuffle_method() == "p2p"
         ):
-            left = Repartition(left, lambda x: max(x // 3, 1))
-            right = Repartition(right, lambda x: max(x // 3, 1))
+            left = Repartition(left, _partition_reducer)
+            right = Repartition(right, _partition_reducer)
 
             return HashJoinP2P(
                 left,
@@ -286,23 +290,24 @@ def _validate_same_operations(self, common, op, remove="both"):
             columns_right,
         )
 
+    @staticmethod
+    def _flatten_columns(expr, columns, side):
+        if len(columns) == 0:
+            return getattr(expr, side).columns
+        else:
+            return list(set(flatten(columns)))
+
     def _combine_similar(self, root: Expr):
         # Push projections back up to avoid performing the same merge multiple times
 
-        def _flatten_columns(columns, side=None):
-            if len(columns) == 0 and side is not None:
-                return getattr(self, side).columns
-            else:
-                return list(set(flatten(columns)))
-
         left, columns_left = self._remove_operations(
             self.left, self._remove_ops, self._skip_ops
         )
-        columns_left = _flatten_columns(columns_left, "left")
+        columns_left = self._flatten_columns(self, columns_left, "left")
         right, columns_right = self._remove_operations(
             self.right, self._remove_ops, self._skip_ops
         )
-        columns_right = _flatten_columns(columns_right, "right")
+        columns_right = self._flatten_columns(self, columns_right, "right")
 
         if left._name == self.left._name and right._name == self.right._name:
             # There aren't any ops we can remove, so bail
             return
@@ -324,22 +329,22 @@ def _combine_similar(self, root: Expr):
             validation = self._validate_same_operations(common_right, op, "left")
             if validation[0]:
-                left_sub = _flatten_columns(validation[1])
+                left_sub = self._flatten_columns(op, validation[1], side="left")
                 columns = self.right.columns.copy()
                 columns += [col for col in self.left.columns if col not in columns]
                 break
 
             validation = self._validate_same_operations(common_left, op, "right")
             if validation[0]:
-                right_sub = _flatten_columns(validation[2])
+                right_sub = self._flatten_columns(op, validation[2], side="right")
                 columns = self.left.columns.copy()
                 columns += [col for col in self.right.columns if col not in columns]
                 break
 
             validation = self._validate_same_operations(common_both, op)
             if validation[0]:
-                left_sub = _flatten_columns(validation[1])
-                right_sub = _flatten_columns(validation[2])
+                left_sub = self._flatten_columns(op, validation[1], side="left")
+                right_sub = self._flatten_columns(op, validation[2], side="right")
                 columns = columns_left.copy()
                 columns += [col for col in columns_right if col not in columns_left]
                 break
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 876a2278..915f13eb 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -137,6 +137,9 @@ def _meta(self):
     def npartitions(self):
         return len(self._fusion_buckets)
 
+    def _broadcast_dep(self, dep: Expr):
+        return dep.npartitions == 1
+
     def _divisions(self):
         divisions = self.operand("expr")._divisions()
         new_divisions = [divisions[b[0]] for b in self._fusion_buckets]

From a94493ef1a0fa5abefbe0182453c27542e113509 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 14:24:03 +0100
Subject: [PATCH 10/18] Reduce repartition ratio

---
 dask_expr/_merge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index a0f24f44..59f06705 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -27,7 +27,7 @@
 
 
 def _partition_reducer(x):
-    return max(x // 3, 1)
+    return max(x // 2, 1)
 
 
 class Merge(Expr):

From 9e8f546290ab9bf895d68768f18933767e151a15 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:11:18 +0100
Subject: [PATCH 11/18] Remove

---
 dask_expr/_expr.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 59dbbc1a..47bba418 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -2298,9 +2298,8 @@ def _fusion_pass(expr):
             dependents[next._name] = set()
             expr_mapping[next._name] = next
 
-        next_deps = [dep._name for dep in next.dependencies()]
         for operand in next.operands:
-            if isinstance(operand, Expr) and operand._name in next_deps:
+            if isinstance(operand, Expr):
                 stack.append(operand)
                 if isinstance(operand, Blockwise):
                     if next._name in dependencies:
From ec5247bf9d5bd6d211846aae8b3b4f8458553668 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:12:30 +0100
Subject: [PATCH 12/18] Add assignpartitioning index

---
 dask_expr/_merge.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 904e9d6d..9871a5c3 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -15,7 +15,12 @@
     Projection,
 )
 from dask_expr._repartition import Repartition
-from dask_expr._shuffle import Shuffle, _contains_index_name, _select_columns_or_index
+from dask_expr._shuffle import (
+    AssignPartitioningIndex,
+    Shuffle,
+    _contains_index_name,
+    _select_columns_or_index,
+)
 from dask_expr._util import _convert_to_list
 
 _HASH_COLUMN_NAME = "__hash_partition"
@@ -265,7 +270,7 @@ def _validate_same_operations(self, common, op, remove_ops, skip_ops):
 
     def _combine_similar(self, root: Expr):
         # Push projections back up to avoid performing the same merge multiple times
-        skip_ops = (Filter, Shuffle)
+        skip_ops = (Filter, AssignPartitioningIndex, Shuffle)
         remove_ops = (Projection,)

From c9031ee87ed73c8c355b92fe7a310cb9da0e1a6c Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:13:59 +0100
Subject: [PATCH 13/18] Move import

---
 dask_expr/_merge.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 9871a5c3..60ada836 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -1,6 +1,5 @@
 import functools
 
-import pandas as pd
 from dask.core import flatten
 from dask.dataframe.dispatch import make_meta, meta_nonempty
 from dask.dataframe.shuffle import partitioning_index
@@ -407,6 +406,7 @@ def _simplify_up(self, parent):
         return
 
 
 def create_assign_index_merge_transfer():
+    import pandas as pd
     from distributed.shuffle._core import ShuffleId
     from distributed.shuffle._merge import merge_transfer

From 1751a112dbfab38866678b67722a9888d249c8bb Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:21:02 +0100
Subject: [PATCH 14/18] Add test

---
 dask_expr/io/tests/test_distributed.py | 31 ++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 dask_expr/io/tests/test_distributed.py

diff --git a/dask_expr/io/tests/test_distributed.py b/dask_expr/io/tests/test_distributed.py
new file mode 100644
index 00000000..ef6895df
--- /dev/null
+++ b/dask_expr/io/tests/test_distributed.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import pytest
+
+from dask_expr.tests._util import _backend_library
+
+distributed = pytest.importorskip("distributed")
+
+from distributed import Client, LocalCluster
+from distributed.utils_test import client as c  # noqa F401
+
+import dask_expr as dx
+
+lib = _backend_library()
+
+
+def test_io_fusion_merge(tmpdir):
+    pdf = lib.DataFrame({c: range(100) for c in "abcdefghij"})
+    with LocalCluster(processes=False, n_workers=2) as cluster:
+        with Client(cluster) as client:  # noqa: F841
+            dx.from_pandas(pdf, 10).to_parquet(tmpdir)
+            df = dx.read_parquet(tmpdir).merge(
+                dx.read_parquet(tmpdir).add_suffix("_x"), left_on="a", right_on="a_x"
+            )[["a_x", "b_x", "b"]]
+            out = df.compute()
+            lib.testing.assert_frame_equal(
+                out.sort_values(by="a_x", ignore_index=True),
+                pdf.merge(pdf.add_suffix("_x"), left_on="a", right_on="a_x")[
+                    ["a_x", "b_x", "b"]
+                ],
+            )
From c901743f800153c6ed810e3d7636be1162948fcb Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:32:07 +0100
Subject: [PATCH 15/18] Fix

---
 dask_expr/_merge.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index a9f5bd5c..43432ca2 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -166,9 +166,6 @@ def _lower(self):
             or shuffle_backend is None
             and get_default_shuffle_method() == "p2p"
         ):
-            # left = Repartition(left, lambda x: x // 2)
-            # right = Repartition(right, lambda x: x // 2)
-
             return HashJoinP2P(
                 left,
                 right,
@@ -285,23 +282,24 @@ def _validate_same_operations(self, common, op, remove="both"):
             columns_right,
         )
 
+    @staticmethod
+    def _flatten_columns(expr, columns, side):
+        if len(columns) == 0:
+            return getattr(expr, side).columns
+        else:
+            return list(set(flatten(columns)))
+
     def _combine_similar(self, root: Expr):
         # Push projections back up to avoid performing the same merge multiple times
 
-        def _flatten_columns(columns, side=None):
-            if len(columns) == 0 and side is not None:
-                return getattr(self, side).columns
-            else:
-                return list(set(flatten(columns)))
-
         left, columns_left = self._remove_operations(
             self.left, self._remove_ops, self._skip_ops
         )
-        columns_left = _flatten_columns(columns_left, "left")
+        columns_left = self._flatten_columns(self, columns_left, "left")
         right, columns_right = self._remove_operations(
             self.right, self._remove_ops, self._skip_ops
         )
-        columns_right = _flatten_columns(columns_right, "right")
+        columns_right = self._flatten_columns(self, columns_right, "right")
 
         if left._name == self.left._name and right._name == self.right._name:
             # There aren't any ops we can remove, so bail
             return
@@ -323,22 +321,22 @@ def _combine_similar(self, root: Expr):
             validation = self._validate_same_operations(common_right, op, "left")
             if validation[0]:
-                left_sub = _flatten_columns(validation[1])
+                left_sub = self._flatten_columns(op, validation[1], side="left")
                 columns = self.right.columns.copy()
                 columns += [col for col in self.left.columns if col not in columns]
                 break
 
             validation = self._validate_same_operations(common_left, op, "right")
             if validation[0]:
-                right_sub = _flatten_columns(validation[2])
+                right_sub = self._flatten_columns(op, validation[2], side="right")
                 columns = self.left.columns.copy()
                 columns += [col for col in self.right.columns if col not in columns]
                 break
 
             validation = self._validate_same_operations(common_both, op)
             if validation[0]:
-                left_sub = _flatten_columns(validation[1])
-                right_sub = _flatten_columns(validation[2])
+                left_sub = self._flatten_columns(op, validation[1], side="left")
+                right_sub = self._flatten_columns(op, validation[2], side="right")
                 columns = columns_left.copy()
                 columns += [col for col in columns_right if col not in columns_left]
                 break

From 34b3441feff6af1f5d6d811d3c60a23db88e1b62 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:33:24
+0100
Subject: [PATCH 16/18] Remove

---
 dask_expr/_merge.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index 43432ca2..615cc69f 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -342,10 +342,6 @@ def _combine_similar(self, root: Expr):
                 break
 
         if columns is not None:
-            if _HASH_COLUMN_NAME in columns:
-                # Don't filter for hash_column_name which is removed in p2p merge
-                columns.remove(_HASH_COLUMN_NAME)
-
             expr = self
 
             if left_sub is not None:

From 8f41dc6fa7cc2d5d23b99401ead8d147adc11b38 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:56:42 +0100
Subject: [PATCH 17/18] Add test

---
 dask_expr/tests/test_merge.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py
index ca8013f0..8df294f4 100644
--- a/dask_expr/tests/test_merge.py
+++ b/dask_expr/tests/test_merge.py
@@ -2,6 +2,7 @@
 from dask.dataframe.utils import assert_eq
 
 from dask_expr import from_pandas
+from dask_expr._expr import Projection
 from dask_expr.tests._util import _backend_library
 
 # Set DataFrame backend for this module
@@ -206,3 +207,26 @@ def test_merge_combine_similar(npartitions_left, npartitions_right):
     expected["new"] = expected.b + expected.c
     expected = expected.groupby(["a", "e", "x"]).new.sum()
     assert_eq(query, expected)
+
+
+def test_merge_combine_similar_intermediate_projections():
+    pdf = lib.DataFrame(
+        {
+            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+            "b": 1,
+            "c": 1,
+        }
+    )
+    pdf2 = lib.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "x": 1})
+    pdf3 = lib.DataFrame({"b": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "y": 1})
+
+    df = from_pandas(pdf, npartitions=2)
+    df2 = from_pandas(pdf2, npartitions=3)
+    df3 = from_pandas(pdf3, npartitions=3)
+
+    q = df.merge(df2).merge(df3)[["b", "x", "y"]]
+    result = q.optimize(fuse=False)
+    # Check that we have intermediate projections dropping unnecessary columns
+    assert isinstance(result.expr.left, Projection)
+    assert result.expr.left.operand("columns") == ["b", "x"]
+    assert_eq(result, pdf.merge(pdf2).merge(pdf3)[["b", "x", "y"]], check_index=False)

From babd562f2aec8cf8b58722f95448afd8fdf94f7a Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:58:10 +0100
Subject: [PATCH 18/18] Restrict to inner

---
 dask_expr/_merge.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py
index db23fcf1..870f6314 100644
--- a/dask_expr/_merge.py
+++ b/dask_expr/_merge.py
@@ -170,8 +170,9 @@ def _lower(self):
             or shuffle_backend is None
             and get_default_shuffle_method() == "p2p"
         ):
-            left = Repartition(left, _partition_reducer)
-            right = Repartition(right, _partition_reducer)
+            if self.how == "inner":
+                left = Repartition(left, _partition_reducer)
+                right = Repartition(right, _partition_reducer)
 
             return HashJoinP2P(
                 left,