Skip to content

Commit

Permalink
Add WaitOn transform to Python, analogous to Java's Wait.on. (#30807)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Apr 2, 2024
1 parent 0e86118 commit eb5c73d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
43 changes: 40 additions & 3 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
from typing import TypeVar
from typing import Union

import apache_beam as beam
from apache_beam import coders
from apache_beam import pvalue
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.portability import common_urns
Expand Down Expand Up @@ -76,7 +78,6 @@
from apache_beam.utils.sharded_key import ShardedKey

if TYPE_CHECKING:
from apache_beam import pvalue
from apache_beam.runners.pipeline_context import PipelineContext

__all__ = [
Expand Down Expand Up @@ -751,7 +752,7 @@ def flush_batch(

class SharedKey():
"""A class that holds a per-process UUID used to key elements for streaming
BatchElements.
BatchElements.
"""
def __init__(self):
self.key = uuid.uuid4().hex
Expand All @@ -763,7 +764,7 @@ def load_shared_key():

class WithSharedKey(DoFn):
"""A DoFn that keys elements with a per-process UUID. Used in streaming
BatchElements.
BatchElements.
"""
def __init__(self):
self.shared_handle = shared.Shared()
Expand Down Expand Up @@ -1644,3 +1645,39 @@ def _process(element):
yield r

return pcoll | FlatMap(_process)


@typehints.with_input_types(T)
@typehints.with_output_types(T)
class WaitOn(PTransform):
  """Holds back a PCollection until other PCollections are fully processed.

  For example::

    X | WaitOn(Y, Z) | SomeTransform()

  guarantees that PCollections Y and Z (and therefore the transforms that
  produce them) have completed before SomeTransform sees any element of X.
  This is particularly handy when the waited-on PCollections are outputs of
  transforms that talk to external systems (for instance a write to a
  database or some other sink).

  In streaming pipelines the delay applies per window: the matching window of
  every waited-on PCollection is computed before the main collection's
  elements are let through.

  This barrier often induces a fusion break.

  Args:
    *to_be_waited_on: the PCollections that must finish first.
  """
  def __init__(self, *to_be_waited_on):
    self._to_be_waited_on = to_be_waited_on

  def expand(self, pcoll):
    # Only the watermark of each waited-on PCollection matters, never its
    # contents, so every element is dropped before being used as a side input.
    # The GroupByKey avoids writing empty files for each shard, and also
    # ensures the respective window finishes before advancing the timestamp.
    sides = []
    for ix, side in enumerate(self._to_be_waited_on):
      drop_everything = beam.FlatMap(lambda x: ()) | GroupByKey()
      sides.append(pvalue.AsIter(side | f"WaitOn{ix}" >> drop_everything))

    # The side inputs are accepted and ignored; elements pass through as-is.
    return pcoll | beam.Map(lambda x, *unused_sides: x, *sides)
29 changes: 29 additions & 0 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,35 @@ def test_split_without_empty(self):
assert_that(result, equal_to(expected_result))


class WaitOnTest(unittest.TestCase):
  """Checks that WaitOn sequences otherwise independent pipeline branches."""
  def test_find(self):
    # The counter is stored as an attribute on the `beam` module so that every
    # (pickled and unpickled) copy of the functions below shares one reference.
    counter_attr = '_WAIT_ON_TEST_COUNTER'

    # Start from a known value and remove the attribute afterwards, so that
    # running this test more than once in the same process does not begin
    # from a stale count and break the hard-coded expectations below.
    setattr(beam, counter_attr, 0)
    self.addCleanup(delattr, beam, counter_attr)

    def increment_global_counter():
      # Returns the current count, then bumps it (post-increment semantics).
      value = getattr(beam, counter_attr, 0)
      setattr(beam, counter_attr, value + 1)
      return value

    def record(tag):
      # Tags each element with the branch name and the order in which it was
      # processed, as observed through the shared counter.
      return f'Record({tag})' >> beam.Map(
          lambda x: (x[0], tag, increment_global_counter()))

    with TestPipeline() as p:
      start = p | beam.Create([(None, ), (None, )])
      x = start | record('x')
      y = start | 'WaitForX' >> util.WaitOn(x) | record('y')
      z = start | 'WaitForY' >> util.WaitOn(y) | record('z')
      result = x | 'WaitForYZ' >> util.WaitOn(y, z) | record('result')
      # Each stage must see strictly later counter values than the stage(s)
      # it waited on.
      assert_that(x, equal_to([(None, 'x', 0), (None, 'x', 1)]), label='x')
      assert_that(y, equal_to([(None, 'y', 2), (None, 'y', 3)]), label='y')
      assert_that(z, equal_to([(None, 'z', 4), (None, 'z', 5)]), label='z')
      assert_that(
          result,
          equal_to([(None, 'result', 6), (None, 'result', 7)]),
          label='result')


# When executed directly as a script: set the root logger to INFO, then run
# the unittest command-line runner.
if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()

0 comments on commit eb5c73d

Please sign in to comment.