From eb5c73da86f87d65109ca84e917fdaaeb61f7cf1 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Tue, 2 Apr 2024 08:24:53 -0700
Subject: [PATCH] Add WaitOn transform to Python, analogous to Java's Wait.on. (#30807)

---
 sdks/python/apache_beam/transforms/util.py      | 43 +++++++++++++++++--
 .../apache_beam/transforms/util_test.py         | 29 +++++++++++++
 2 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index c554bef6c36d..edf79b7c7981 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -36,7 +36,9 @@
 from typing import TypeVar
 from typing import Union
 
+import apache_beam as beam
 from apache_beam import coders
+from apache_beam import pvalue
 from apache_beam import typehints
 from apache_beam.metrics import Metrics
 from apache_beam.portability import common_urns
@@ -76,7 +78,6 @@
 from apache_beam.utils.sharded_key import ShardedKey
 
 if TYPE_CHECKING:
-  from apache_beam import pvalue
   from apache_beam.runners.pipeline_context import PipelineContext
 
 __all__ = [
@@ -751,7 +752,7 @@ def flush_batch(
 
 class SharedKey():
   """A class that holds a per-process UUID used to key elements for streaming
-  BatchElements. 
+  BatchElements.
   """
   def __init__(self):
     self.key = uuid.uuid4().hex
@@ -763,7 +764,7 @@ def load_shared_key():
 
 class WithSharedKey(DoFn):
   """A DoFn that keys elements with a per-process UUID. Used in streaming
-  BatchElements. 
+  BatchElements.
   """
   def __init__(self):
     self.shared_handle = shared.Shared()
@@ -1644,3 +1645,39 @@ def _process(element):
         yield r
 
     return pcoll | FlatMap(_process)
+
+
+@typehints.with_input_types(T)
+@typehints.with_output_types(T)
+class WaitOn(PTransform):
+  """Delays processing of a PCollection until another set of PCollections
+  has finished being processed. For example::
+
+      X | WaitOn(Y, Z) | SomeTransform()
+
+  would ensure that PCollections Y and Z (and hence their producing
+  transforms) are complete before SomeTransform gets executed on the
+  elements of X. This can be especially useful when the waited-on
+  PCollections are the outputs of transforms that interact with external
+  systems (such as writing to a database or other sink).
+
+  For streaming, this delay is done on a per-window basis, i.e. the
+  corresponding window of each waited-on PCollection is computed before
+  elements are passed through the main collection.
+
+  This barrier often induces a fusion break.
+  """
+  def __init__(self, *to_be_waited_on):
+    self._to_be_waited_on = to_be_waited_on
+
+  def expand(self, pcoll):
+    # All we care about is the watermark, not the data itself.
+    # The GroupByKey avoids writing empty files for each shard, and also
+    # ensures the respective window finishes before advancing the timestamp.
+    sides = [
+        pvalue.AsIter(
+            side
+            | f"WaitOn{ix}" >> (beam.FlatMap(lambda x: ()) | GroupByKey()))
+        for (ix, side) in enumerate(self._to_be_waited_on)
+    ]
+    return pcoll | beam.Map(lambda x, *unused_sides: x, *sides)
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 5dfe166d3c31..53898d579983 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -1787,6 +1787,35 @@ def test_split_without_empty(self):
     assert_that(result, equal_to(expected_result))
 
 
+class WaitOnTest(unittest.TestCase):
+  def test_find(self):
+    # We need a shared reference that survives pickling.
+    def increment_global_counter():
+      try:
+        value = getattr(beam, '_WAIT_ON_TEST_COUNTER', 0)
+        return value
+      finally:
+        setattr(beam, '_WAIT_ON_TEST_COUNTER', value + 1)
+
+    def record(tag):
+      return f'Record({tag})' >> beam.Map(
+          lambda x: (x[0], tag, increment_global_counter()))
+
+    with TestPipeline() as p:
+      start = p | beam.Create([(None, ), (None, )])
+      x = start | record('x')
+      y = start | 'WaitForX' >> util.WaitOn(x) | record('y')
+      z = start | 'WaitForY' >> util.WaitOn(y) | record('z')
+      result = x | 'WaitForYZ' >> util.WaitOn(y, z) | record('result')
+      assert_that(x, equal_to([(None, 'x', 0), (None, 'x', 1)]), label='x')
+      assert_that(y, equal_to([(None, 'y', 2), (None, 'y', 3)]), label='y')
+      assert_that(z, equal_to([(None, 'z', 4), (None, 'z', 5)]), label='z')
+      assert_that(
+          result,
+          equal_to([(None, 'result', 6), (None, 'result', 7)]),
+          label='result')
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
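
Usage sketch (not part of the patch): the snippet below shows how the new WaitOn transform might be used to gate downstream processing on an external write finishing first, as described in the docstring above. The write_to_database helper and all step labels are illustrative assumptions; only WaitOn itself comes from this change, and it is imported directly from apache_beam.transforms.util since the patch does not add it to the module's __all__.

  # Hedged sketch of applying WaitOn; write_to_database is a hypothetical
  # side-effecting function, not part of the patch.
  import apache_beam as beam
  from apache_beam.transforms.util import WaitOn

  def write_to_database(element):
    # Hypothetical external write; returning the element keeps the output
    # PCollection non-empty so WaitOn can track its watermark.
    return element

  with beam.Pipeline() as p:
    events = p | 'Create' >> beam.Create(['a', 'b', 'c'])

    # The side effect we want to complete before continuing.
    written = events | 'WriteToDb' >> beam.Map(write_to_database)

    # Elements of `events` are held back (per window, in streaming) until
    # `written` is complete; this typically forces a fusion break.
    _ = (
        events
        | 'WaitForWrite' >> WaitOn(written)
        | 'ThenProcess' >> beam.Map(print))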