diff --git a/.github/workflows/dask_runner_tests.yml b/.github/workflows/dask_runner_tests.yml
index f87c70d8b720..0f60c22b6aab 100644
--- a/.github/workflows/dask_runner_tests.yml
+++ b/.github/workflows/dask_runner_tests.yml
@@ -78,7 +78,7 @@ jobs:
         run: pip install tox
       - name: Install SDK with dask
         working-directory: ./sdks/python
-        run: pip install setuptools --upgrade && pip install -e .[gcp,dask,test]
+        run: pip install setuptools --upgrade && pip install -e .[dask,test,dataframes]
       - name: Run tests basic unix
         if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
         working-directory: ./sdks/python
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py
index 109c4379b45d..0f2317074cea 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -31,12 +31,22 @@
 from apache_beam.pipeline import PipelineVisitor
 from apache_beam.runners.dask.overrides import dask_overrides
 from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
+from apache_beam.runners.dask.transform_evaluator import DaskBagWindowedIterator
+from apache_beam.runners.dask.transform_evaluator import Flatten
 from apache_beam.runners.dask.transform_evaluator import NoOp
 from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.runners.runner import PipelineState
+from apache_beam.transforms.sideinputs import SideInputMap
 from apache_beam.utils.interactive_utils import is_in_notebook
 
+try:
+  # Added to try to prevent threading related issues, see
+  # https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
+  import dask.distributed as ddist
+except ImportError:
+  ddist = {}
+
 
 class DaskOptions(PipelineOptions):
   @staticmethod
@@ -86,10 +96,9 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
 
 @dataclasses.dataclass
 class DaskRunnerResult(PipelineResult):
-  from dask import distributed
-  client: distributed.Client
-  futures: t.Sequence[distributed.Future]
+  client: ddist.Client
+  futures: t.Sequence[ddist.Future]
 
   def __post_init__(self):
     super().__init__(PipelineState.RUNNING)
@@ -99,8 +108,16 @@ def wait_until_finish(self, duration=None) -> str:
       if duration is not None:
         # Convert milliseconds to seconds
         duration /= 1000
-      self.client.wait_for_workers(timeout=duration)
-      self.client.gather(self.futures, errors='raise')
+      for _ in ddist.as_completed(self.futures,
+                                  timeout=duration,
+                                  with_results=True):
+        # without gathering results, worker errors are not raised on the client:
+        # https://distributed.dask.org/en/stable/resilience.html#user-code-failures
+        # so we want to gather results to raise errors client-side, but we do
+        # not actually need to use the results here, so we just pass. to gather,
+        # we use the iterative `as_completed(..., with_results=True)`, instead
+        # of aggregate `client.gather`, to minimize memory footprint of results.
+        pass
       self._state = PipelineState.DONE
     except:  # pylint: disable=broad-except
       self._state = PipelineState.FAILED
@@ -133,6 +150,7 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
         op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
         op = op_class(transform_node)
 
+        op_kws = {"input_bag": None, "side_inputs": None}
         inputs = list(transform_node.inputs)
         if inputs:
           bag_inputs = []
@@ -144,13 +162,28 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
             if prev_op in self.bags:
               bag_inputs.append(self.bags[prev_op])
 
-          if len(bag_inputs) == 1:
-            self.bags[transform_node] = op.apply(bag_inputs[0])
+          # Input to `Flatten` could be of length 1, e.g. a single-element
+          # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
+          # an iterable, because `Flatten.apply` always takes an iterable.
+          if len(bag_inputs) == 1 and not isinstance(op, Flatten):
+            op_kws["input_bag"] = bag_inputs[0]
           else:
-            self.bags[transform_node] = op.apply(bag_inputs)
+            op_kws["input_bag"] = bag_inputs
+
+        side_inputs = list(transform_node.side_inputs)
+        if side_inputs:
+          bag_side_inputs = []
+          for si in side_inputs:
+            si_asbag = self.bags.get(si.pvalue.producer)
+            bag_side_inputs.append(
+                SideInputMap(
+                    type(si),
+                    si._view_options(),
+                    DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))
+
+          op_kws["side_inputs"] = bag_side_inputs
 
-        else:
-          self.bags[transform_node] = op.apply(None)
+        self.bags[transform_node] = op.apply(**op_kws)
 
     return DaskBagVisitor()
 
@@ -159,6 +192,8 @@ def is_fnapi_compatible():
     return False
 
   def run_pipeline(self, pipeline, options):
+    import dask
+
     # TODO(alxr): Create interactive notebook support.
     if is_in_notebook():
       raise NotImplementedError('interactive support will come later!')
@@ -177,6 +212,6 @@ def run_pipeline(self, pipeline, options):
     dask_visitor = self.to_dask_bag_visitor()
     pipeline.visit(dask_visitor)
-
-    futures = client.compute(list(dask_visitor.bags.values()))
+    opt_graph = dask.optimize(*list(dask_visitor.bags.values()))
+    futures = client.compute(opt_graph)
     return DaskRunnerResult(client, futures)
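Note: taken together, the runner changes above let transforms consume side inputs and surface worker errors on the client. A minimal usage sketch, not part of the patch, assuming dask and distributed are installed and that a default local cluster is acceptable (it mirrors the tests below):

    import apache_beam as beam
    from apache_beam.runners.dask.dask_runner import DaskRunner

    with beam.Pipeline(runner=DaskRunner()) as p:
      side = p | 'side' >> beam.Create([3])
      main = p | 'main' >> beam.Create([1, 2])
      # The visitor wraps AsSingleton into a SideInputMap backed by a
      # DaskBagWindowedIterator and passes it to ParDo.apply as side_inputs.
      result = main | beam.Map(lambda x, y: x * y, beam.pvalue.AsSingleton(side))
      # result contains [3, 6]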
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index d8b3e17d8a56..66dda4a984f4 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import datetime
 import inspect
+import typing as t
 import unittest
 
 import apache_beam as beam
@@ -22,12 +24,14 @@
 from apache_beam.testing import test_pipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import window
 
 try:
-  from apache_beam.runners.dask.dask_runner import DaskOptions
-  from apache_beam.runners.dask.dask_runner import DaskRunner
   import dask
   import dask.distributed as ddist
+
+  from apache_beam.runners.dask.dask_runner import DaskOptions  # pylint: disable=ungrouped-imports
+  from apache_beam.runners.dask.dask_runner import DaskRunner  # pylint: disable=ungrouped-imports
 except (ImportError, ModuleNotFoundError):
   raise unittest.SkipTest('Dask must be installed to run tests.')
 
@@ -73,6 +77,11 @@ def test_create(self):
       pcoll = p | beam.Create([1])
       assert_that(pcoll, equal_to([1]))
 
+  def test_create_multiple(self):
+    with self.pipeline as p:
+      pcoll = p | beam.Create([1, 2, 3, 4])
+      assert_that(pcoll, equal_to([1, 2, 3, 4]))
+
   def test_create_and_map(self):
     def double(x):
       return x * 2
@@ -81,6 +90,22 @@ def double(x):
       pcoll = p | beam.Create([1]) | beam.Map(double)
       assert_that(pcoll, equal_to([2]))
 
+  def test_create_and_map_multiple(self):
+    def double(x):
+      return x * 2
+
+    with self.pipeline as p:
+      pcoll = p | beam.Create([1, 2]) | beam.Map(double)
+      assert_that(pcoll, equal_to([2, 4]))
+
+  def test_create_and_map_many(self):
+    def double(x):
+      return x * 2
+
+    with self.pipeline as p:
+      pcoll = p | beam.Create(list(range(1, 11))) | beam.Map(double)
+      assert_that(pcoll, equal_to(list(range(2, 21, 2))))
+
   def test_create_map_and_groupby(self):
     def double(x):
       return x * 2, x
@@ -89,6 +114,288 @@ def double(x):
       pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
       assert_that(pcoll, equal_to([(2, [1])]))
 
+  def test_create_map_and_groupby_multiple(self):
+    def double(x):
+      return x * 2, x
+
+    with self.pipeline as p:
+      pcoll = (
+          p
+          | beam.Create([1, 2, 1, 2, 3])
+          | beam.Map(double)
+          | beam.GroupByKey())
+      assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])]))
+
+  def test_map_with_positional_side_input(self):
+    def mult_by(x, y):
+      return x * y
+
+    with self.pipeline as p:
+      side = p | "side" >> beam.Create([3])
+      pcoll = (
+          p
+          | "main" >> beam.Create([1])
+          | beam.Map(mult_by, beam.pvalue.AsSingleton(side)))
+      assert_that(pcoll, equal_to([3]))
+
+  def test_map_with_keyword_side_input(self):
+    def mult_by(x, y):
+      return x * y
+
+    with self.pipeline as p:
+      side = p | "side" >> beam.Create([3])
+      pcoll = (
+          p
+          | "main" >> beam.Create([1])
+          | beam.Map(mult_by, y=beam.pvalue.AsSingleton(side)))
+      assert_that(pcoll, equal_to([3]))
+
+  def test_pardo_side_inputs(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b", "c"])
+      side = p | "side" >> beam.Create(["x", "y"])
+      assert_that(
+          main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)),
+          equal_to([
+              ("a", "x"),
+              ("b", "x"),
+              ("c", "x"),
+              ("a", "y"),
+              ("b", "y"),
+              ("c", "y"),
+          ]),
+      )
+
+  def test_pardo_side_input_dependencies(self):
+    with self.pipeline as p:
+      inputs = [p | beam.Create([None])]
+      for k in range(1, 10):
+        inputs.append(
+            inputs[0]
+            | beam.ParDo(
+                ExpectingSideInputsFn(f"Do{k}"),
+                *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)],
+            ))
+
+  def test_pardo_side_input_sparse_dependencies(self):
+    with self.pipeline as p:
+      inputs = []
+
+      def choose_input(s):
+        return inputs[(389 + s * 5077) % len(inputs)]
+      for k in range(20):
+        num_inputs = int((k * k % 16)**0.5)
+        if num_inputs == 0:
+          inputs.append(p | f"Create{k}" >> beam.Create([f"Create{k}"]))
+        else:
+          inputs.append(
+              choose_input(0)
+              | beam.ParDo(
+                  ExpectingSideInputsFn(f"Do{k}"),
+                  *[
+                      beam.pvalue.AsList(choose_input(s))
+                      for s in range(1, num_inputs)
+                  ],
+              ))
+
+  @unittest.expectedFailure
+  def test_pardo_windowed_side_inputs(self):
+    with self.pipeline as p:
+      # Now with some windowing.
+      pcoll = (
+          p
+          | beam.Create(list(range(10)))
+          | beam.Map(lambda t: window.TimestampedValue(t, t)))
+      # Intentionally choosing non-aligned windows to highlight the transition.
+      main = pcoll | "WindowMain" >> beam.WindowInto(window.FixedWindows(5))
+      side = pcoll | "WindowSide" >> beam.WindowInto(window.FixedWindows(7))
+      res = main | beam.Map(
+          lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side))
+      assert_that(
+          res,
+          equal_to([
+              # The window [0, 5) maps to the window [0, 7).
+              (0, list(range(7))),
+              (1, list(range(7))),
+              (2, list(range(7))),
+              (3, list(range(7))),
+              (4, list(range(7))),
+              # The window [5, 10) maps to the window [7, 14).
+              (5, list(range(7, 10))),
+              (6, list(range(7, 10))),
+              (7, list(range(7, 10))),
+              (8, list(range(7, 10))),
+              (9, list(range(7, 10))),
+          ]),
+          label="windowed",
+      )
+
+  def test_flattened_side_input(self, with_transcoding=True):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create([None])
+      side1 = p | "side1" >> beam.Create([("a", 1)])
+      side2 = p | "side2" >> beam.Create([("b", 2)])
+      if with_transcoding:
+        # Also test non-matching coder types (transcoding required)
+        third_element = [("another_type")]
+      else:
+        third_element = [("b", 3)]
+      side3 = p | "side3" >> beam.Create(third_element)
+      side = (side1, side2) | beam.Flatten()
+      assert_that(
+          main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
+          equal_to([(None, {
+              "a": 1, "b": 2
+          })]),
+          label="CheckFlattenAsSideInput",
+      )
+      assert_that(
+          (side, side3) | "FlattenAfter" >> beam.Flatten(),
+          equal_to([("a", 1), ("b", 2)] + third_element),
+          label="CheckFlattenOfSideInput",
+      )
+
+  def test_gbk_side_input(self):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create([None])
+      side = p | "side" >> beam.Create([("a", 1)]) | beam.GroupByKey()
+      assert_that(
+          main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
+          equal_to([(None, {
+              "a": [1]
+          })]),
+      )
+
+  def test_multimap_side_input(self):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b"])
+      side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)])
+      assert_that(
+          main
+          | beam.Map(
+              lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
+          equal_to([("a", [1, 3]), ("b", [2])]),
+      )
+
+  def test_multimap_multiside_input(self):
+    # A test where two transforms in the same stage consume the same PCollection
+    # twice as side input.
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b"])
+      side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)])
+      assert_that(
+          main
+          | "first map" >> beam.Map(
+              lambda k,
+              d,
+              l: (k, sorted(d[k]), sorted([e[1] for e in l])),
+              beam.pvalue.AsMultiMap(side),
+              beam.pvalue.AsList(side),
+          )
+          | "second map" >> beam.Map(
+              lambda k,
+              d,
+              l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
+              beam.pvalue.AsMultiMap(side),
+              beam.pvalue.AsList(side),
+          ),
+          equal_to([("a", [1, 3], [1, 2, 3]), ("b", [2], [1, 2, 3])]),
+      )
+
+  def test_multimap_side_input_type_coercion(self):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b"])
+      # The type of this side-input is forced to Any (overriding type
+      # inference). Without type coercion to Tuple[Any, Any], the usage of this
+      # side-input in AsMultiMap() below should fail.
+      side = p | "side" >> beam.Create([("a", 1), ("b", 2),
+                                        ("a", 3)]).with_output_types(t.Any)
+      assert_that(
+          main
+          | beam.Map(
+              lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
+          equal_to([("a", [1, 3]), ("b", [2])]),
+      )
+
+  def test_pardo_unfusable_side_inputs__one(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+
+    with self.pipeline as p:
+      pcoll = p | "Create1" >> beam.Create(["a", "b"])
+      assert_that(
+          pcoll |
+          "FlatMap1" >> beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)),
+          equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
+          label="assert_that1",
+      )
+
+  def test_pardo_unfusable_side_inputs__two(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+
+    with self.pipeline as p:
+      pcoll = p | "Create2" >> beam.Create(["a", "b"])
+
+      derived = ((pcoll, )
+                 | beam.Flatten()
+                 | beam.Map(lambda x: (x, x))
+                 | beam.GroupByKey()
+                 | "Unkey" >> beam.Map(lambda kv: kv[0]))
+      assert_that(
+          pcoll | "FlatMap2" >> beam.FlatMap(
+              cross_product, beam.pvalue.AsList(derived)),
+          equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
+          label="assert_that2",
+      )
+
+  def test_groupby_with_fixed_windows(self):
+    def double(x):
+      return x * 2, x
+
+    def add_timestamp(pair):
+      delta = datetime.timedelta(seconds=pair[1] * 60)
+      now = (datetime.datetime.now() + delta).timestamp()
+      return window.TimestampedValue(pair, now)
+
+    with self.pipeline as p:
+      pcoll = (
+          p
+          | beam.Create([1, 2, 1, 2, 3])
+          | beam.Map(double)
+          | beam.WindowInto(window.FixedWindows(60))
+          | beam.Map(add_timestamp)
+          | beam.GroupByKey())
+      assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])]))
+
+  def test_groupby_string_keys(self):
+    with self.pipeline as p:
+      pcoll = (
+          p
+          | beam.Create([('a', 1), ('a', 2), ('b', 3), ('b', 4)])
+          | beam.GroupByKey())
+      assert_that(pcoll, equal_to([('a', [1, 2]), ('b', [3, 4])]))
+
+
+class ExpectingSideInputsFn(beam.DoFn):
+  def __init__(self, name):
+    self._name = name
+
+  def default_label(self):
+    return self._name
+
+  def process(self, element, *side_inputs):
+    if not all(list(s) for s in side_inputs):
+      raise ValueError(f"Missing data in side input {side_inputs}")
+    yield self._name
+
 
 if __name__ == '__main__':
   unittest.main()
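Note: the overrides below make `GroupByKey` window-aware (reify windows, group by key, then group also by window). A hedged sketch of the behavior this enables, mirroring `test_groupby_with_fixed_windows` above; the element values and timestamps are illustrative:

    import apache_beam as beam
    from apache_beam.runners.dask.dask_runner import DaskRunner
    from apache_beam.transforms import window

    with beam.Pipeline(runner=DaskRunner()) as p:
      grouped = (
          p
          | beam.Create([('a', 1), ('a', 2), ('b', 3)])
          | beam.WindowInto(window.FixedWindows(60))
          | beam.Map(lambda kv: window.TimestampedValue(kv, kv[1] * 60))
          | beam.GroupByKey())
      # With 60s fixed windows and timestamps 60, 120, 180, each element lands
      # in a different window, so 'a' is grouped per window rather than
      # globally: [('a', [1]), ('a', [2]), ('b', [3])]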
diff --git a/sdks/python/apache_beam/runners/dask/overrides.py b/sdks/python/apache_beam/runners/dask/overrides.py
index d07c7cd518af..b952834f12d7 100644
--- a/sdks/python/apache_beam/runners/dask/overrides.py
+++ b/sdks/python/apache_beam/runners/dask/overrides.py
@@ -73,7 +73,6 @@ def infer_output_type(self, input_type):
 
 @typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupAlsoByWindow(beam.ParDo):
-  """Not used yet..."""
   def __init__(self, windowing):
     super().__init__(_GroupAlsoByWindowDoFn(windowing))
     self.windowing = windowing
@@ -86,12 +85,23 @@ def expand(self, input_or_inputs):
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupByKey(beam.PTransform):
   def expand(self, input_or_inputs):
-    return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()
+    return (
+        input_or_inputs
+        | "ReifyWindows" >> beam.ParDo(beam.GroupByKey.ReifyWindows())
+        | "GroupByKey" >> _GroupByKeyOnly()
+        | "GroupByWindow" >> _GroupAlsoByWindow(input_or_inputs.windowing))
 
 
 class _Flatten(beam.PTransform):
   def expand(self, input_or_inputs):
-    is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
+    if isinstance(input_or_inputs, beam.PCollection):
+      # NOTE(cisaacstern): I needed this to avoid
+      # `TypeError: 'PCollection' object is not iterable`
+      # being raised by the `all(...)` call below for single-element flattens,
+      # i.e., `(pcoll, ) | beam.Flatten() | ...`
+      is_bounded = input_or_inputs.is_bounded
+    else:
+      is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
     return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index d4d58879b7fe..e3bd5fd87763 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -26,19 +26,110 @@
 import dataclasses
 import math
 import typing as t
+from dataclasses import field
 
 import apache_beam
 import dask.bag as db
+from apache_beam import DoFn
+from apache_beam import TaggedOutput
 from apache_beam.pipeline import AppliedPTransform
+from apache_beam.runners.common import DoFnContext
+from apache_beam.runners.common import DoFnInvoker
+from apache_beam.runners.common import DoFnSignature
+from apache_beam.runners.common import Receiver
+from apache_beam.runners.common import _OutputHandler
 from apache_beam.runners.dask.overrides import _Create
 from apache_beam.runners.dask.overrides import _Flatten
 from apache_beam.runners.dask.overrides import _GroupByKeyOnly
+from apache_beam.transforms.sideinputs import SideInputMap
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.transforms.window import WindowFn
+from apache_beam.utils.windowed_value import WindowedValue
 
+# Inputs to DaskBagOps.
 OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]
+OpSide = t.Optional[t.Sequence[SideInputMap]]
+
+# Value types for PCollections (possibly Windowed Values).
+PCollVal = t.Union[WindowedValue, t.Any]
+
+
+def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
+  """Wraps a value (item) inside a Window."""
+  if isinstance(item, TaggedOutput):
+    item = item.value
+
+  if isinstance(item, WindowedValue):
+    windowed_value = item
+  elif isinstance(item, TimestampedValue):
+    assign_context = WindowFn.AssignContext(item.timestamp, item.value)
+    windowed_value = WindowedValue(
+        item.value, item.timestamp, tuple(window_fn.assign(assign_context)))
+  else:
+    windowed_value = WindowedValue(item, 0, (GlobalWindow(), ))
+
+  return windowed_value
+
+
+def defenestrate(x):
+  """Extracts the underlying item from a Window."""
+  if isinstance(x, WindowedValue):
+    return x.value
+  return x
+
+
+@dataclasses.dataclass
+class DaskBagWindowedIterator:
+  """Iterator for `apache_beam.transforms.sideinputs.SideInputMap`."""
+
+  bag: db.Bag
+  window_fn: WindowFn
+
+  def __iter__(self):
+    # FIXME(cisaacstern): list() is likely inefficient, since it presumably
+    # materializes the full result before iterating over it. doing this for
+    # now as a proof-of-concept. can we generate results incrementally?
+    for result in list(self.bag):
+      yield get_windowed_value(result, self.window_fn)
+
+
+@dataclasses.dataclass
+class TaggingReceiver(Receiver):
+  """A Receiver that handles tagged `WindowedValue`s."""
+  tag: str
+  values: t.List[PCollVal]
+
+  def receive(self, windowed_value: WindowedValue):
+    if self.tag:
+      output = TaggedOutput(self.tag, windowed_value)
+    else:
+      output = windowed_value
+    self.values.append(output)
+
+
+@dataclasses.dataclass
+class OneReceiver(dict):
+  """A Receiver that tags values via a dictionary lookup key."""
+  values: t.List[PCollVal] = field(default_factory=list)
+
+  def __missing__(self, key):
+    if key not in self:
+      self[key] = TaggingReceiver(key, self.values)
+    return self[key]
 
 
 @dataclasses.dataclass
 class DaskBagOp(abc.ABC):
+  """Abstract base class for all Dask-supported operations.
+
+  All DaskBagOps must support an `apply()` operation, which applies the
+  operation to the previous op's input (a dask bag).
+
+  Attributes:
+    applied: The underlying `AppliedPTransform` which holds the code for the
+      target operation.
+  """
   applied: AppliedPTransform
 
   @property
@@ -46,17 +137,19 @@ def transform(self):
     return self.applied.transform
 
   @abc.abstractmethod
-  def apply(self, input_bag: OpInput) -> db.Bag:
+  def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
     pass
 
 
 class NoOp(DaskBagOp):
-  def apply(self, input_bag: OpInput) -> db.Bag:
+  """An identity on a dask bag: returns the input as-is."""
+  def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
     return input_bag
 
 
 class Create(DaskBagOp):
-  def apply(self, input_bag: OpInput) -> db.Bag:
+  """The beginning of a Beam pipeline; the input must be `None`."""
+  def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
    assert input_bag is None, 'Create expects no input!'
     original_transform = t.cast(_Create, self.transform)
     items = original_transform.values
@@ -66,42 +159,95 @@ def apply(self, input_bag: OpInput) -> db.Bag:
             1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
 
 
+def apply_dofn_to_bundle(
+    items, do_fn_invoker_args, do_fn_invoker_kwargs, tagged_receivers):
+  """Invokes a DoFn within a bundle, implemented as a Dask partition."""
+
+  do_fn_invoker = DoFnInvoker.create_invoker(
+      *do_fn_invoker_args, **do_fn_invoker_kwargs)
+
+  do_fn_invoker.invoke_setup()
+  do_fn_invoker.invoke_start_bundle()
+
+  for it in items:
+    do_fn_invoker.invoke_process(it)
+
+  results = [v.value for v in tagged_receivers.values]
+
+  do_fn_invoker.invoke_finish_bundle()
+  do_fn_invoker.invoke_teardown()
+
+  return results
+
+
 class ParDo(DaskBagOp):
-  def apply(self, input_bag: db.Bag) -> db.Bag:
-    transform = t.cast(apache_beam.ParDo, self.transform)
-    return input_bag.map(
-        transform.fn.process, *transform.args, **transform.kwargs).flatten()
+  """Apply a pure function in an embarrassingly-parallel way.
+  This consumes a sequence of items and returns a sequence of items.
+  """
+  def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
+    transform = t.cast(apache_beam.ParDo, self.transform)
 
-class Map(DaskBagOp):
-  def apply(self, input_bag: db.Bag) -> db.Bag:
-    transform = t.cast(apache_beam.Map, self.transform)
-    return input_bag.map(
-        transform.fn.process, *transform.args, **transform.kwargs)
+    args, kwargs = transform.raw_side_inputs
+    args = list(args)
+    main_input = next(iter(self.applied.main_inputs.values()))
+    window_fn = main_input.windowing.windowfn if hasattr(
+        main_input, "windowing") else None
+
+    tagged_receivers = OneReceiver()
+
+    do_fn_invoker_args = [
+        DoFnSignature(transform.fn),
+        _OutputHandler(
+            window_fn=window_fn,
+            main_receivers=tagged_receivers[None],
+            tagged_receivers=tagged_receivers,
+            per_element_output_counter=None,
+            output_batch_converter=None,
+            process_yields_batches=False,
+            process_batch_yields_elements=False),
+    ]
+    do_fn_invoker_kwargs = dict(
+        context=DoFnContext(transform.label, state=None),
+        side_inputs=side_inputs,
+        input_args=args,
+        input_kwargs=kwargs,
+        user_state_context=None,
+        bundle_finalizer_param=DoFn.BundleFinalizerParam(),
+    )
+
+    return input_bag.map(get_windowed_value, window_fn).map_partitions(
+        apply_dofn_to_bundle,
+        do_fn_invoker_args,
+        do_fn_invoker_kwargs,
+        tagged_receivers,
+    )
 
 
 class GroupByKey(DaskBagOp):
-  def apply(self, input_bag: db.Bag) -> db.Bag:
+  """Group a PCollection into a mapping of keys to elements."""
+  def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
     def key(item):
       return item[0]
 
     def value(item):
       k, v = item
-      return k, [elm[1] for elm in v]
+      return k, [defenestrate(elm[1]) for elm in v]
 
     return input_bag.groupby(key).map(value)
 
 
 class Flatten(DaskBagOp):
-  def apply(self, input_bag: OpInput) -> db.Bag:
-    assert type(input_bag) is list, 'Must take a sequence of bags!'
+  """Produces a flattened bag from a collection of bags."""
+  def apply(
+      self, input_bag: t.List[db.Bag], side_inputs: OpSide = None) -> db.Bag:
+    assert isinstance(input_bag, list), 'Must take a sequence of bags!'
     return db.concat(input_bag)
 
 
 TRANSLATIONS = {
     _Create: Create,
     apache_beam.ParDo: ParDo,
-    apache_beam.Map: Map,
     _GroupByKeyOnly: GroupByKey,
     _Flatten: Flatten,
 }
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index 21561e1bf6a9..4922b61169d1 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -64,6 +64,7 @@ excluded_patterns=(
   'apache_beam/runners/portability/'
   'apache_beam/runners/test/'
   'apache_beam/runners/worker/'
+  'apache_beam/runners/dask/transform_evaluator.*'
   'apache_beam/testing/benchmarks/chicago_taxi/'
   'apache_beam/testing/benchmarks/cloudml/'
   'apache_beam/testing/benchmarks/inference/'
@@ -134,7 +135,7 @@ autodoc_member_order = 'bysource'
 autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx",
   "tensorflow", "tensorflow_hub", "tensorflow_transform", "tensorflow_metadata",
   "transformers", "xgboost", "datatable", "transformers",
-  "sentence_transformers", "redis", "tensorflow_text", "feast",
+  "sentence_transformers", "redis", "tensorflow_text", "feast", "dask",
 ]
 
 # Allow a special section for documenting DataFrame API
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 9ae5d3153f51..3b45cbf82fc1 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -512,8 +512,15 @@ def get_portability_package_data():
     ],
     'dataframe': dataframe_dependency,
     'dask': [
-        'dask >= 2022.6',
-        'distributed >= 2022.6',
+        'distributed >= 2024.4.2',
+        'dask >= 2024.4.2',
+        # For development, 'distributed >= 2023.12.1' should work with
+        # the above dask PR; however, it can't be installed as part of
+        # a single `pip` call, since distributed releases are pinned to
+        # specific dask releases. As a workaround, distributed can be
+        # installed first, and then `.[dask]` installed second, with the
+        # `--upgrade` / `-U` flag to replace the dask release brought in
+        # by distributed.
     ],
     'yaml': [
         'docstring-parser>=0.15,<1.0',
diff --git a/sdks/python/test-suites/tox/common.gradle b/sdks/python/test-suites/tox/common.gradle
index df42a2c384c2..01265a6eeff5 100644
--- a/sdks/python/test-suites/tox/common.gradle
+++ b/sdks/python/test-suites/tox/common.gradle
@@ -31,7 +31,6 @@ test.dependsOn "testPy${pythonVersionSuffix}ML"
 // toxTask "testPy${pythonVersionSuffix}Dask", "py${pythonVersionSuffix}-dask", "${posargs}"
 // test.dependsOn "testPy${pythonVersionSuffix}Dask"
 
-
 project.tasks.register("preCommitPy${pythonVersionSuffix}") {
   // Since codecoverage reports will always be generated for py38,
   // all tests will be exercised.
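Note: the setup.py comment above describes an install-order workaround. As a hedged sketch (the exact versions are illustrative, taken from the pins in this patch), a development install would look like:

    # Install distributed first, then Beam's dask extra with --upgrade so that
    # pip replaces the dask release that distributed pinned.
    pip install 'distributed>=2024.4.2'
    pip install --upgrade -e '.[dask]'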
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index c7713498d87d..ad5d7ec5505e 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -109,9 +109,20 @@ commands =
   bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"
 
 [testenv:py{39,310,311,312}-dask]
-extras = test,dask
+extras = test,dask,dataframes
+commands_pre =
+  pip install 'distributed>=2024.4.2' 'dask>=2024.4.2'
 commands =
-  bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"
+  bash {toxinidir}/scripts/run_pytest.sh {envname} {toxinidir}/apache_beam/runners/dask/
+
+[testenv:py{39,310,311,312}-win-dask]
+commands_pre =
+  pip install 'distributed>=2024.4.2' 'dask>=2024.4.2'
+commands =
+  python apache_beam/examples/complete/autocomplete_test.py
+  bash {toxinidir}/scripts/run_pytest.sh {envname} {toxinidir}/apache_beam/runners/dask/
+install_command = {envbindir}/python.exe {envbindir}/pip.exe install --retries 10 {opts} {packages}
+list_dependencies_command = {envbindir}/python.exe {envbindir}/pip.exe freeze
 
 [testenv:py39-cloudcoverage]
 deps =
@@ -394,7 +405,7 @@ commands =
 
 [testenv:py39-tensorflow-212]
 deps =
-  212:
+  212: tensorflow>=2.12rc1,<2.13
   # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852
   pydantic<2.7
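Note: with the tox changes above in place, the Dask suite can be exercised locally against any of the listed interpreters; a hedged example invocation (the Python version and working directory are assumptions):

    cd sdks/python
    tox -e py39-dask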