From ac7ae9f3d59bd90068d7c58963c6791e3ba08d67 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Fri, 6 Dec 2024 15:00:50 -0800
Subject: [PATCH] Allow annotations to be attached to transforms via a context.

This API is offered both on the pipeline object itself and as a
(thread-local) top-level function, as one may not always have an easy
reference to the pipeline.
---
 sdks/python/apache_beam/__init__.py           |   2 +-
 sdks/python/apache_beam/pipeline.py           | 106 ++++++++++++++----
 sdks/python/apache_beam/pipeline_test.py      |  45 ++++++++
 .../runners/dataflow/dataflow_runner_test.py  |  16 ++-
 4 files changed, 140 insertions(+), 29 deletions(-)

diff --git a/sdks/python/apache_beam/__init__.py b/sdks/python/apache_beam/__init__.py
index af88934b0e71..650b639760dc 100644
--- a/sdks/python/apache_beam/__init__.py
+++ b/sdks/python/apache_beam/__init__.py
@@ -89,7 +89,7 @@
 from apache_beam import metrics
 from apache_beam import typehints
 from apache_beam import version
-from apache_beam.pipeline import Pipeline
+from apache_beam.pipeline import *
 from apache_beam.transforms import *
 from apache_beam.pvalue import PCollection
 from apache_beam.pvalue import Row
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 6209ca1ddae8..eb08cd1115f9 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -54,6 +54,7 @@
 import re
 import shutil
 import tempfile
+import threading
 import unicodedata
 import uuid
 from collections import defaultdict
@@ -109,7 +110,7 @@
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.transforms import environments
 
-__all__ = ['Pipeline', 'PTransformOverride']
+__all__ = ['Pipeline', 'transform_annotations']
 
 
 class Pipeline(HasDisplayData):
@@ -226,7 +227,9 @@ def __init__(
     self.runner = runner
     # Stack of transforms generated by nested apply() calls. The stack will
     # contain a root node as an enclosing (parent) node for top transforms.
-    self.transforms_stack = [AppliedPTransform(None, None, '', None)]
+    self.transforms_stack = [
+        AppliedPTransform(None, None, '', None, None, None)
+    ]
     # Set of transform labels (full labels) applied to the pipeline.
     # If a transform is applied and the full label is already in the set
     # then the transform will have to be cloned with a new label.
@@ -244,6 +247,7 @@ def __init__(
     self._display_data = display_data or {}
 
     self._error_handlers = []
+    self._annotations_stack = [{}]
 
   def display_data(self):
     # type: () -> Dict[str, Any]
@@ -268,6 +272,24 @@ def _current_transform(self):
     """Returns the transform currently on the top of the stack."""
     return self.transforms_stack[-1]
 
+  @contextlib.contextmanager
+  def transform_annotations(self, **annotations):
+    """A context manager for attaching annotations to a set of transforms.
+
+    All transforms applied while this context is active will have these
+    annotations attached. This includes sub-transforms applied within
+    composite transforms.
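+
+    For example (an illustrative sketch mirroring the usage exercised in
+    pipeline_test.py below; the annotation key and value are arbitrary)::
+
+      with p.transform_annotations(foo='first'):
+        pcoll = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)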
+    """
+    self._annotations_stack.append({
+        **self._annotations_stack[-1], **encode_annotations(annotations)
+    })
+    yield
+    self._annotations_stack.pop()
+
+  def _current_annotations(self):
+    """Returns the set of annotations that should be used on apply."""
+    return {**_global_annotations_stack()[-1], **self._annotations_stack[-1]}
+
   def _root_transform(self):
     # type: () -> AppliedPTransform
 
@@ -316,7 +338,9 @@ def _replace_if_needed(self, original_transform_node):
           original_transform_node.parent,
           replacement_transform,
           original_transform_node.full_label,
-          original_transform_node.main_inputs)
+          original_transform_node.main_inputs,
+          None,
+          annotations=original_transform_node.annotations)
 
       # TODO(https://github.com/apache/beam/issues/21178): Merge rather
       # than override.
@@ -741,7 +765,12 @@ def apply(
           'returned %s from %s' % (transform, inputs, pvalueish))
 
     current = AppliedPTransform(
-        self._current_transform(), transform, full_label, inputs)
+        self._current_transform(),
+        transform,
+        full_label,
+        inputs,
+        None,
+        annotations=self._current_annotations())
     self._current_transform().add_part(current)
 
     try:
@@ -1014,7 +1043,7 @@ def from_runner_api(
       root_transform_id, = proto.root_transform_ids
       p.transforms_stack = [context.transforms.get_by_id(root_transform_id)]
     else:
-      p.transforms_stack = [AppliedPTransform(None, None, '', None)]
+      p.transforms_stack = [AppliedPTransform(None, None, '', None, None, None)]
     # TODO(robertwb): These are only needed to continue construction. Omit?
     p.applied_labels = {
         t.unique_name
@@ -1124,8 +1153,8 @@ def __init__(
       transform,  # type: Optional[ptransform.PTransform]
       full_label,  # type: str
       main_inputs,  # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
-      environment_id=None,  # type: Optional[str]
-      annotations=None,  # type: Optional[Dict[str, bytes]]
+      environment_id,  # type: Optional[str]
+      annotations,  # type: Optional[Dict[str, bytes]]
   ):
     # type: (...) -> None
 
     self.parent = parent
@@ -1149,24 +1178,11 @@ def __init__(
         transform.get_resource_hints()) if transform else {
         }  # type: Dict[str, bytes]
 
-    if annotations is None and transform:
-
-      def annotation_to_bytes(key, a: Any) -> bytes:
-        if isinstance(a, bytes):
-          return a
-        elif isinstance(a, str):
-          return a.encode('ascii')
-        elif isinstance(a, message.Message):
-          return a.SerializeToString()
-        else:
-          raise TypeError(
-              'Unknown annotation type %r (type %s) for %s' % (a, type(a), key))
-
+    if transform:
       annotations = {
-          key: annotation_to_bytes(key, a)
-          for key,
-          a in transform.annotations().items()
+          **(annotations or {}), **encode_annotations(transform.annotations())
       }
+
     self.annotations = annotations
 
   @property
@@ -1478,6 +1494,50 @@ def _merge_outer_resource_hints(self):
       part._merge_outer_resource_hints()
 
 
+def encode_annotations(annotations: Optional[Dict[str, Any]]):
+  """Encodes non-byte annotation values as bytes."""
+  if not annotations:
+    return {}
+
+  def annotation_to_bytes(key, a: Any) -> bytes:
+    if isinstance(a, bytes):
+      return a
+    elif isinstance(a, str):
+      return a.encode('ascii')
+    elif isinstance(a, message.Message):
+      return a.SerializeToString()
+    else:
+      raise TypeError(
+          'Unknown annotation type %r (type %s) for %s' % (a, type(a), key))
+
+  return {key: annotation_to_bytes(key, a) for (key, a) in annotations.items()}
+
+
+_global_annotations_stack_data = threading.local()
+
+
+def _global_annotations_stack():
+  try:
+    return _global_annotations_stack_data.stack
+  except AttributeError:
+    _global_annotations_stack_data.stack = [{}]
+  return _global_annotations_stack_data.stack
+
+
+@contextlib.contextmanager
+def transform_annotations(**annotations):
+  """A context manager for attaching annotations to a set of transforms.
+
+  All transforms applied while this context is active will have these
+  annotations attached. This includes sub-transforms applied within
+  composite transforms.
+  """
+  cur_stack = _global_annotations_stack()
+  cur_stack.append({**cur_stack[-1], **encode_annotations(annotations)})
+  yield
+  cur_stack.pop()
+
+
 class PTransformOverride(metaclass=abc.ABCMeta):
   """For internal use only; no backwards-compatibility guarantees.
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 1c11f953c58d..8c334ef2c44d 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -1016,6 +1016,51 @@ def annotations(self):
         transform.annotations['proto'], some_proto.SerializeToString())
     self.assertEqual(seen, 2)
 
+  def assertHasAnnotation(self, pipeline_proto, transform, key, value):
+    for transform_proto in pipeline_proto.components.transforms.values():
+      if transform_proto.unique_name == transform:
+        self.assertIn(key, transform_proto.annotations.keys())
+        self.assertEqual(transform_proto.annotations[key], value)
+        break
+    else:
+      self.fail(
+          "Unknown transform: %r not in %s" % (
+              transform,
+              sorted([
+                  t.unique_name
+                  for t in pipeline_proto.components.transforms.values()
+              ])))
+
+  def test_pipeline_context_annotations(self):
+    p = beam.Pipeline()
+    with p.transform_annotations(foo='first'):
+      pcoll = p | beam.Create([1, 2, 3]) | 'First' >> beam.Map(lambda x: x + 1)
+    with p.transform_annotations(foo='second'):
+      pcoll | 'Second' >> beam.Map(lambda x: x * 2)
+      with p.transform_annotations(foo='nested', another='more'):
+        pcoll | 'Nested' >> beam.Map(lambda x: x * 3)
+
+    proto = p.to_runner_api()
+    self.assertHasAnnotation(proto, 'First', 'foo', b'first')
+    self.assertHasAnnotation(proto, 'Second', 'foo', b'second')
+    self.assertHasAnnotation(proto, 'Nested', 'foo', b'nested')
+    self.assertHasAnnotation(proto, 'Nested', 'another', b'more')
+
+  def test_beam_context_annotations(self):
+    p = beam.Pipeline()
+    with beam.transform_annotations(foo='first'):
+      pcoll = p | beam.Create([1, 2, 3]) | 'First' >> beam.Map(lambda x: x + 1)
+    with beam.transform_annotations(foo='second'):
+      pcoll | 'Second' >> beam.Map(lambda x: x * 2)
+      with beam.transform_annotations(foo='nested', another='more'):
+        pcoll | 'Nested' >> beam.Map(lambda x: x * 3)
+
+    proto = p.to_runner_api()
+    self.assertHasAnnotation(proto, 'First', 'foo', b'first')
+    self.assertHasAnnotation(proto, 'Second', 'foo', b'second')
+    self.assertHasAnnotation(proto, 'Nested', 'foo', b'nested')
+    self.assertHasAnnotation(proto, 'Nested', 'another', b'more')
+
   def test_transform_ids(self):
     class MyPTransform(beam.PTransform):
       def expand(self, p):
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index b5568305ce65..fcbde090dbb3 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -272,7 +272,7 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
     pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
     for pcoll in [pcoll1, pcoll2, pcoll3]:
       applied = AppliedPTransform(
-          None, beam.GroupByKey(), "label", {'pcoll': pcoll})
+          None, beam.GroupByKey(), "label", {'pcoll': pcoll}, None, None)
       applied.outputs[None] = PCollection(None)
       common.group_by_key_input_visitor().visit_transform(applied)
       self.assertEqual(
@@ -291,7 +291,9 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
     for pcoll in [pcoll1, pcoll2]:
       with self.assertRaisesRegex(ValueError, err_msg):
         common.group_by_key_input_visitor().visit_transform(
-            AppliedPTransform(None, beam.GroupByKey(), "label", {'in': pcoll}))
+            AppliedPTransform(
+                None, beam.GroupByKey(), "label",
+                {'in': pcoll}, None, None))
 
   def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
     p = TestPipeline()
@@ -299,7 +301,9 @@ def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
     for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
       pcoll.element_type = typehints.Any
       common.group_by_key_input_visitor().visit_transform(
-          AppliedPTransform(None, transform, "label", {'in': pcoll}))
+          AppliedPTransform(
+              None, transform, "label",
+              {'in': pcoll}, None, None))
       self.assertEqual(pcoll.element_type, typehints.Any)
 
   def test_flatten_input_with_visitor_with_single_input(self):
@@ -319,7 +323,8 @@ def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
     output_pcoll = PCollection(p)
     output_pcoll.element_type = output_type
 
-    flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
+    flatten = AppliedPTransform(
+        None, beam.Flatten(), "label", inputs, None, None)
     flatten.add_output(output_pcoll, None)
     DataflowRunner.flatten_input_visitor().visit_transform(flatten)
     for _ in range(num_inputs):
@@ -357,7 +362,8 @@ def test_side_input_visitor(self):
         z: (x, y, z),
         beam.pvalue.AsSingleton(pc),
         beam.pvalue.AsMultiMap(pc))
-    applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
+    applied_transform = AppliedPTransform(
+        None, transform, "label", {'pc': pc}, None, None)
     DataflowRunner.side_input_visitor().visit_transform(applied_transform)
     self.assertEqual(2, len(applied_transform.side_inputs))
     self.assertEqual(