Skip to content

Commit

Permalink
Allow annotations to be attached to transforms via a context.
Browse files Browse the repository at this point in the history
This API is offered both on the pipeline object itself, and also
as a (thread-local) top-level function, as one may not always
have an easy reference to the pipeline.
  • Loading branch information
robertwb committed Dec 6, 2024
1 parent 78bde63 commit d491769
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 29 deletions.
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
from apache_beam import metrics
from apache_beam import typehints
from apache_beam import version
from apache_beam.pipeline import Pipeline
from apache_beam.pipeline import *
from apache_beam.transforms import *
from apache_beam.pvalue import PCollection
from apache_beam.pvalue import Row
Expand Down
106 changes: 83 additions & 23 deletions sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import re
import shutil
import tempfile
import threading
import unicodedata
import uuid
from collections import defaultdict
Expand Down Expand Up @@ -109,7 +110,7 @@
from apache_beam.runners.runner import PipelineResult
from apache_beam.transforms import environments

__all__ = ['Pipeline', 'PTransformOverride']
__all__ = ['Pipeline', 'transform_annotations']


class Pipeline(HasDisplayData):
Expand Down Expand Up @@ -226,7 +227,9 @@ def __init__(
self.runner = runner
# Stack of transforms generated by nested apply() calls. The stack will
# contain a root node as an enclosing (parent) node for top transforms.
self.transforms_stack = [AppliedPTransform(None, None, '', None)]
self.transforms_stack = [
AppliedPTransform(None, None, '', None, None, None)
]
# Set of transform labels (full labels) applied to the pipeline.
# If a transform is applied and the full label is already in the set
# then the transform will have to be cloned with a new label.
Expand All @@ -244,6 +247,7 @@ def __init__(

self._display_data = display_data or {}
self._error_handlers = []
self._annotations_stack = [{}]

def display_data(self):
# type: () -> Dict[str, Any]
Expand All @@ -268,6 +272,24 @@ def _current_transform(self):
"""Returns the transform currently on the top of the stack."""
return self.transforms_stack[-1]

@contextlib.contextmanager
def transform_annotations(self, **annotations):
  """A context manager for attaching annotations to a set of transforms.

  All transforms applied while this context is active will have these
  annotations attached. This includes sub-transforms applied within
  composite transforms.

  Contexts nest: the new frame starts from the annotations of the enclosing
  frame, and keys given here override keys of the same name.

  Example::

    with pipeline.transform_annotations(team='search'):
      pcoll | 'Step' >> beam.Map(fn)

  Args:
    **annotations: annotation keys and values. Values are converted to bytes
      by ``encode_annotations`` before being stored.

  Raises:
    TypeError: if ``encode_annotations`` rejects a value's type.
  """
  self._annotations_stack.append({
      **self._annotations_stack[-1], **encode_annotations(annotations)
  })
  try:
    yield
  finally:
    # Pop even when the body raises; otherwise the frame would stay on the
    # stack and its annotations would leak onto every later transform.
    self._annotations_stack.pop()

def _current_annotations(self):
  """Returns the annotations to attach to a transform applied right now.

  The result is a fresh dict built from the top frame of the global
  annotation stack, overlaid with the top frame of this pipeline's own
  stack, so pipeline-level values win when a key appears in both.
  """
  merged = dict(_global_annotations_stack()[-1])
  merged.update(self._annotations_stack[-1])
  return merged

def _root_transform(self):
# type: () -> AppliedPTransform

Expand Down Expand Up @@ -316,7 +338,9 @@ def _replace_if_needed(self, original_transform_node):
original_transform_node.parent,
replacement_transform,
original_transform_node.full_label,
original_transform_node.main_inputs)
original_transform_node.main_inputs,
None,
annotations=original_transform_node.annotations)

# TODO(https://github.com/apache/beam/issues/21178): Merge rather
# than override.
Expand Down Expand Up @@ -741,7 +765,12 @@ def apply(
'returned %s from %s' % (transform, inputs, pvalueish))

current = AppliedPTransform(
self._current_transform(), transform, full_label, inputs)
self._current_transform(),
transform,
full_label,
inputs,
None,
annotations=self._current_annotations())
self._current_transform().add_part(current)

try:
Expand Down Expand Up @@ -1014,7 +1043,7 @@ def from_runner_api(
root_transform_id, = proto.root_transform_ids
p.transforms_stack = [context.transforms.get_by_id(root_transform_id)]
else:
p.transforms_stack = [AppliedPTransform(None, None, '', None)]
p.transforms_stack = [AppliedPTransform(None, None, '', None, None, None)]
# TODO(robertwb): These are only needed to continue construction. Omit?
p.applied_labels = {
t.unique_name
Expand Down Expand Up @@ -1124,8 +1153,8 @@ def __init__(
transform, # type: Optional[ptransform.PTransform]
full_label, # type: str
main_inputs, # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
environment_id=None, # type: Optional[str]
annotations=None, # type: Optional[Dict[str, bytes]]
environment_id, # type: Optional[str]
annotations, # type: Optional[Dict[str, bytes]]
):
# type: (...) -> None
self.parent = parent
Expand All @@ -1149,24 +1178,11 @@ def __init__(
transform.get_resource_hints()) if transform else {
} # type: Dict[str, bytes]

if annotations is None and transform:

def annotation_to_bytes(key, a: Any) -> bytes:
if isinstance(a, bytes):
return a
elif isinstance(a, str):
return a.encode('ascii')
elif isinstance(a, message.Message):
return a.SerializeToString()
else:
raise TypeError(
'Unknown annotation type %r (type %s) for %s' % (a, type(a), key))

if transform:
annotations = {
key: annotation_to_bytes(key, a)
for key,
a in transform.annotations().items()
**(annotations or {}), **encode_annotations(transform.annotations())
}

self.annotations = annotations

@property
Expand Down Expand Up @@ -1478,6 +1494,50 @@ def _merge_outer_resource_hints(self):
part._merge_outer_resource_hints()


def encode_annotations(annotations: Optional[Dict[str, Any]]):
  """Encodes non-byte annotation values as bytes.

  Args:
    annotations: mapping of annotation key to value, or None. Supported
      value types are ``bytes`` (kept as is), ``str`` (ASCII-encoded) and
      protobuf messages (serialized with ``SerializeToString``).

  Returns:
    A new dict with the same keys and bytes values. Empty when
    ``annotations`` is None or empty.

  Raises:
    TypeError: if a value is of any other type.
  """
  encoded = {}
  for key, value in (annotations or {}).items():
    if isinstance(value, bytes):
      encoded[key] = value
    elif isinstance(value, str):
      encoded[key] = value.encode('ascii')
    elif isinstance(value, message.Message):
      encoded[key] = value.SerializeToString()
    else:
      raise TypeError(
          'Unknown annotation type %r (type %s) for %s' %
          (value, type(value), key))
  return encoded


# Thread-local holder for the annotation stack behind the module-level
# transform_annotations() context manager. The ``stack`` attribute is created
# lazily, per thread, by _global_annotations_stack().
_global_annotations_stack_data = threading.local()


def _global_annotations_stack():
  """Returns the calling thread's annotation stack, creating it if needed.

  A newly created stack holds a single empty dict, so ``stack[-1]`` is
  always valid.
  """
  local_state = _global_annotations_stack_data
  if not hasattr(local_state, 'stack'):
    local_state.stack = [{}]
  return local_state.stack


@contextlib.contextmanager
def transform_annotations(**annotations):
  """A context manager for attaching annotations to a set of transforms.

  All transforms applied while this context is active will have these
  annotations attached. This includes sub-transforms applied within
  composite transforms.

  The annotations are kept on a thread-local stack, so they apply to
  transforms applied from the current thread only. Contexts nest: the new
  frame starts from the enclosing frame's annotations, and keys given here
  override keys of the same name.

  Args:
    **annotations: annotation keys and values. Values are converted to bytes
      by ``encode_annotations`` before being stored.

  Raises:
    TypeError: if ``encode_annotations`` rejects a value's type.
  """
  cur_stack = _global_annotations_stack()
  cur_stack.append({**cur_stack[-1], **encode_annotations(annotations)})
  try:
    yield
  finally:
    # Pop even when the body raises. This stack outlives any single
    # pipeline, so a leaked frame would annotate unrelated transforms
    # applied later on the same thread.
    cur_stack.pop()


class PTransformOverride(metaclass=abc.ABCMeta):
"""For internal use only; no backwards-compatibility guarantees.
Expand Down
45 changes: 45 additions & 0 deletions sdks/python/apache_beam/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,51 @@ def annotations(self):
transform.annotations['proto'], some_proto.SerializeToString())
self.assertEqual(seen, 2)

def assertHasAnnotation(self, pipeline_proto, transform, key, value):
  """Asserts that the named transform carries the annotation key -> value.

  Args:
    pipeline_proto: pipeline proto whose ``components.transforms`` are
      searched.
    transform: ``unique_name`` of the transform to check.
    key: annotation key expected on that transform.
    value: expected annotation value.

  Fails the test if no transform has that name; the failure message lists
  the names that do exist.
  """
  matching = [
      candidate
      for candidate in pipeline_proto.components.transforms.values()
      if candidate.unique_name == transform
  ]
  if not matching:
    self.fail(
        "Unknown transform: %r not in %s" % (
            transform,
            sorted([
                t.unique_name
                for t in pipeline_proto.components.transforms.values()
            ])))
  # Only the first transform with this name is checked.
  found = matching[0]
  self.assertIn(key, found.annotations.keys())
  self.assertEqual(found.annotations[key], value)

def test_pipeline_context_annotations(self):
  """Checks annotations attached through Pipeline.transform_annotations."""
  pipeline = beam.Pipeline()
  with pipeline.transform_annotations(foo='first'):
    numbers = (
        pipeline | beam.Create([1, 2, 3])
        | 'First' >> beam.Map(lambda x: x + 1))
  with pipeline.transform_annotations(foo='second'):
    numbers | 'Second' >> beam.Map(lambda x: x * 2)
    # The inner context overrides 'foo' and adds a second key.
    with pipeline.transform_annotations(foo='nested', another='more'):
      numbers | 'Nested' >> beam.Map(lambda x: x * 3)

  pipeline_proto = pipeline.to_runner_api()
  expected = [
      ('First', 'foo', b'first'),
      ('Second', 'foo', b'second'),
      ('Nested', 'foo', b'nested'),
      ('Nested', 'another', b'more'),
  ]
  for label, key, value in expected:
    self.assertHasAnnotation(pipeline_proto, label, key, value)

def test_beam_context_annotations(self):
  """Checks annotations attached through beam.transform_annotations."""
  pipeline = beam.Pipeline()
  with beam.transform_annotations(foo='first'):
    numbers = (
        pipeline | beam.Create([1, 2, 3])
        | 'First' >> beam.Map(lambda x: x + 1))
  with beam.transform_annotations(foo='second'):
    numbers | 'Second' >> beam.Map(lambda x: x * 2)
    # The inner context overrides 'foo' and adds a second key.
    with beam.transform_annotations(foo='nested', another='more'):
      numbers | 'Nested' >> beam.Map(lambda x: x * 3)

  pipeline_proto = pipeline.to_runner_api()
  expected = [
      ('First', 'foo', b'first'),
      ('Second', 'foo', b'second'),
      ('Nested', 'foo', b'nested'),
      ('Nested', 'another', b'more'),
  ]
  for label, key, value in expected:
    self.assertHasAnnotation(pipeline_proto, label, key, value)

def test_transform_ids(self):
class MyPTransform(beam.PTransform):
def expand(self, p):
Expand Down
16 changes: 11 additions & 5 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
for pcoll in [pcoll1, pcoll2, pcoll3]:
applied = AppliedPTransform(
None, beam.GroupByKey(), "label", {'pcoll': pcoll})
None, beam.GroupByKey(), "label", {'pcoll': pcoll}, None, None)
applied.outputs[None] = PCollection(None)
common.group_by_key_input_visitor().visit_transform(applied)
self.assertEqual(
Expand All @@ -291,15 +291,19 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
for pcoll in [pcoll1, pcoll2]:
  with self.assertRaisesRegex(ValueError, err_msg):
    # environment_id and annotations are required arguments of
    # AppliedPTransform, so the two trailing Nones belong inside its call.
    # Passed to visit_transform instead, the constructor raises TypeError
    # before the visitor can raise the expected ValueError.
    common.group_by_key_input_visitor().visit_transform(
        AppliedPTransform(
            None, beam.GroupByKey(), "label", {'in': pcoll}, None, None))

def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
  """Checks the GBK input visitor leaves non-GroupByKey inputs untouched."""
  p = TestPipeline()
  pcoll = PCollection(p)
  for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
    pcoll.element_type = typehints.Any
    # environment_id and annotations are required arguments of
    # AppliedPTransform, so the two trailing Nones belong inside its call,
    # not in the visit_transform() argument list.
    common.group_by_key_input_visitor().visit_transform(
        AppliedPTransform(
            None, transform, "label", {'in': pcoll}, None, None))
    self.assertEqual(pcoll.element_type, typehints.Any)

def test_flatten_input_with_visitor_with_single_input(self):
Expand All @@ -319,7 +323,8 @@ def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
output_pcoll = PCollection(p)
output_pcoll.element_type = output_type

flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
flatten = AppliedPTransform(
None, beam.Flatten(), "label", inputs, None, None)
flatten.add_output(output_pcoll, None)
DataflowRunner.flatten_input_visitor().visit_transform(flatten)
for _ in range(num_inputs):
Expand Down Expand Up @@ -357,7 +362,8 @@ def test_side_input_visitor(self):
z: (x, y, z),
beam.pvalue.AsSingleton(pc),
beam.pvalue.AsMultiMap(pc))
applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
applied_transform = AppliedPTransform(
None, transform, "label", {'pc': pc}, None, None)
DataflowRunner.side_input_visitor().visit_transform(applied_transform)
self.assertEqual(2, len(applied_transform.side_inputs))
self.assertEqual(
Expand Down

0 comments on commit d491769

Please sign in to comment.