Skip to content

Commit

Permalink
Windowing Support for the Dask Runner (#32941)
Browse files Browse the repository at this point in the history
Windowing Support for the Dask Runner

---------

Co-authored-by: Pablo E <[email protected]>
Co-authored-by: Pablo <[email protected]>
Co-authored-by: Charles Stern <[email protected]>
  • Loading branch information
4 people authored Nov 18, 2024
1 parent bff2cbb commit e939be3
Show file tree
Hide file tree
Showing 9 changed files with 558 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/dask_runner_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:
run: pip install tox
- name: Install SDK with dask
working-directory: ./sdks/python
run: pip install setuptools --upgrade && pip install -e .[gcp,dask,test]
run: pip install setuptools --upgrade && pip install -e .[dask,test,dataframes]
- name: Run tests basic unix
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
working-directory: ./sdks/python
Expand Down
59 changes: 47 additions & 12 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,22 @@
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import DaskBagWindowedIterator
from apache_beam.runners.dask.transform_evaluator import Flatten
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms.sideinputs import SideInputMap
from apache_beam.utils.interactive_utils import is_in_notebook

try:
# Added to try to prevent threading related issues, see
# https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
import dask.distributed as ddist
except ImportError:
ddist = {}


class DaskOptions(PipelineOptions):
@staticmethod
Expand Down Expand Up @@ -86,10 +96,9 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:

@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
from dask import distributed

client: distributed.Client
futures: t.Sequence[distributed.Future]
client: ddist.Client
futures: t.Sequence[ddist.Future]

def __post_init__(self):
super().__init__(PipelineState.RUNNING)
Expand All @@ -99,8 +108,16 @@ def wait_until_finish(self, duration=None) -> str:
if duration is not None:
# Convert milliseconds to seconds
duration /= 1000
self.client.wait_for_workers(timeout=duration)
self.client.gather(self.futures, errors='raise')
for _ in ddist.as_completed(self.futures,
timeout=duration,
with_results=True):
# Without gathering results, worker errors are not raised on the client; see
# https://distributed.dask.org/en/stable/resilience.html#user-code-failures
# We therefore gather results so that errors are raised client-side, even
# though the results themselves are not needed here (hence the `pass`).
# To gather, we use the iterative `as_completed(..., with_results=True)`
# instead of the aggregate `client.gather`, to minimize the memory
# footprint of the results.
pass
self._state = PipelineState.DONE
except: # pylint: disable=broad-except
self._state = PipelineState.FAILED
Expand Down Expand Up @@ -133,6 +150,7 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
op = op_class(transform_node)

op_kws = {"input_bag": None, "side_inputs": None}
inputs = list(transform_node.inputs)
if inputs:
bag_inputs = []
Expand All @@ -144,13 +162,28 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
if prev_op in self.bags:
bag_inputs.append(self.bags[prev_op])

if len(bag_inputs) == 1:
self.bags[transform_node] = op.apply(bag_inputs[0])
# Input to `Flatten` could be of length 1, e.g. a single-element
# tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
# an iterable, because `Flatten.apply` always takes an iterable.
if len(bag_inputs) == 1 and not isinstance(op, Flatten):
op_kws["input_bag"] = bag_inputs[0]
else:
self.bags[transform_node] = op.apply(bag_inputs)
op_kws["input_bag"] = bag_inputs

side_inputs = list(transform_node.side_inputs)
if side_inputs:
bag_side_inputs = []
for si in side_inputs:
si_asbag = self.bags.get(si.pvalue.producer)
bag_side_inputs.append(
SideInputMap(
type(si),
si._view_options(),
DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))

op_kws["side_inputs"] = bag_side_inputs

else:
self.bags[transform_node] = op.apply(None)
self.bags[transform_node] = op.apply(**op_kws)

return DaskBagVisitor()

Expand All @@ -159,6 +192,8 @@ def is_fnapi_compatible():
return False

def run_pipeline(self, pipeline, options):
import dask

# TODO(alxr): Create interactive notebook support.
if is_in_notebook():
raise NotImplementedError('interactive support will come later!')
Expand All @@ -177,6 +212,6 @@ def run_pipeline(self, pipeline, options):

dask_visitor = self.to_dask_bag_visitor()
pipeline.visit(dask_visitor)

futures = client.compute(list(dask_visitor.bags.values()))
opt_graph = dask.optimize(*list(dask_visitor.bags.values()))
futures = client.compute(opt_graph)
return DaskRunnerResult(client, futures)
Loading

0 comments on commit e939be3

Please sign in to comment.