From b0d80139cbe089d33acee655b343d5697ba172a2 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 19 Sep 2023 14:10:34 -0700 Subject: [PATCH] Remove special PyMap, PyFlatMap, etc. Prefer to use generic MapToFields instead. --- sdks/python/apache_beam/yaml/README.md | 44 +++--- sdks/python/apache_beam/yaml/yaml_mapping.py | 9 +- sdks/python/apache_beam/yaml/yaml_provider.py | 16 +- .../python/apache_beam/yaml/yaml_transform.py | 13 +- .../yaml/yaml_transform_scope_test.py | 48 ++---- .../apache_beam/yaml/yaml_transform_test.py | 139 +++++++++--------- .../yaml/yaml_transform_unit_test.py | 54 +------ 7 files changed, 132 insertions(+), 191 deletions(-) diff --git a/sdks/python/apache_beam/yaml/README.md b/sdks/python/apache_beam/yaml/README.md index 6e5b11ee0631..3ba78784c997 100644 --- a/sdks/python/apache_beam/yaml/README.md +++ b/sdks/python/apache_beam/yaml/README.md @@ -72,14 +72,15 @@ pipeline: - type: ReadFromCsv config: path: /path/to/input*.csv - - type: PyFilter + - type: Filter config: - keep: "lambda x: x.col3 > 100" + language: python + keep: "col3 > 100" input: ReadFromCsv - type: WriteToJson config: path: /path/to/output.json - input: PyFilter + input: Filter ``` or two. @@ -90,15 +91,16 @@ pipeline: - type: ReadFromCsv config: path: /path/to/input*.csv - - type: PyFilter + - type: Filter config: - keep: "lambda x: x.col3 > 100" + language: python + keep: "col3 > 100" input: ReadFromCsv - type: Sql name: MySqlTransform config: query: "select col1, count(*) as cnt from PCOLLECTION group by col1" - input: PyFilter + input: Filter - type: WriteToJson config: path: /path/to/output.json @@ -116,9 +118,10 @@ pipeline: - type: ReadFromCsv config: path: /path/to/input*.csv - - type: PyFilter + - type: Filter config: - keep: "lambda x: x.col3 > 100" + language: python + keep: "col3 > 100" - type: Sql name: MySqlTransform config: @@ -141,9 +144,10 @@ pipeline: path: /path/to/input*.csv transforms: - - type: PyFilter + - type: Filter config: - keep: "lambda x: x.col3 > 100" + language: python + keep: "col3 > 100" - type: Sql name: MySqlTransform @@ -185,11 +189,12 @@ pipeline: config: path: /path/to/all.json - - type: PyFilter + - type: Filter name: FilterToBig input: Sql config: - keep: "lambda x: x.col2 > 100" + language: python + keep: "col2 > 100" - type: WriteToCsv name: WriteBig @@ -231,15 +236,18 @@ pipeline: name: ExtraProcessingForBigRows input: Sql transforms: - - type: PyFilter + - type: Filter config: - keep: "lambda x: x.col2 > 100" - - type: PyFilter + language: python + keep: "col2 > 100" + - type: Filter config: - keep: "lambda x: len(x.col1) > 10" - - type: PyFilter + language: python + keep: "len(col1) > 10" + - type: Filter config: - keep: "lambda x: x.col1 > 'z'" + language: python + keep: "col1 > 'z'" sink: type: WriteToCsv config: diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 90234d079da9..38d31d8bfa9f 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -257,8 +257,13 @@ def with_exception_handling(self, **kwargs): @maybe_with_exception_handling_transform_fn def _PyJsFilter( pcoll, keep: Union[str, Dict[str, str]], language: Optional[str] = None): + try: + input_schema = dict(named_fields_from_element_type(pcoll.element_type)) + except (TypeError, ValueError) as exn: + if is_expr(keep): + raise ValueError("Can only use expressions on a schema'd input.") from exn + input_schema = {} # unused - input_schema = 
dict(named_fields_from_element_type(pcoll.element_type)) if isinstance(keep, str) and keep in input_schema: keep_fn = lambda row: getattr(row, keep) else: @@ -273,7 +278,7 @@ def is_expr(v): def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'): try: input_schema = dict(named_fields_from_element_type(pcoll.element_type)) - except ValueError as exn: + except (TypeError, ValueError) as exn: if drop: raise ValueError("Can only drop fields on a schema'd input.") from exn if append: diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 52e1230a24f1..964877ac2a2e 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -22,6 +22,7 @@ import collections import hashlib import json +import logging import os import subprocess import sys @@ -509,19 +510,14 @@ def _parse_window_spec(spec): # TODO: Triggering, etc. return beam.WindowInto(window_fn) + def log_and_return(x): + logging.info(x) + return x + return InlineProvider( dict({ 'Create': create, - 'PyMap': lambda fn: beam.Map( - python_callable.PythonCallableWithSource(fn)), - 'PyMapTuple': lambda fn: beam.MapTuple( - python_callable.PythonCallableWithSource(fn)), - 'PyFlatMap': lambda fn: beam.FlatMap( - python_callable.PythonCallableWithSource(fn)), - 'PyFlatMapTuple': lambda fn: beam.FlatMapTuple( - python_callable.PythonCallableWithSource(fn)), - 'PyFilter': lambda keep: beam.Filter( - python_callable.PythonCallableWithSource(keep)), + 'LogForTesting': lambda: beam.Map(log_and_return), 'PyTransform': fully_qualified_named_transform, 'PyToRow': lambda fields: beam.Select( **{ diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 59f2e897fa21..f30a2ea26335 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -827,7 +827,8 @@ def ensure_transforms_have_providers(spec): if known_transforms: if spec['type'] not in known_transforms: raise ValueError( - f'Unknown type or missing provider for {identify_object(spec)}') + 'Unknown type or missing provider ' + f'for type {spec["type"]} for {identify_object(spec)}') return spec def preprocess_langauges(spec): @@ -869,12 +870,14 @@ class YamlTransform(beam.PTransform): def __init__(self, spec, providers={}): # pylint: disable=dangerous-default-value if isinstance(spec, str): spec = yaml.load(spec, Loader=SafeLineLoader) - # TODO(BEAM-26941): Validate as a transform. - self._providers = yaml_provider.merge_providers( - { + if isinstance(providers, dict): + providers = { key: yaml_provider.as_provider_list(key, value) for (key, value) in providers.items() - }, + } + # TODO(BEAM-26941): Validate as a transform. 
+ self._providers = yaml_provider.merge_providers( + providers, yaml_provider.standard_providers()) self._spec = preprocess(spec, known_transforms=self._providers.keys()) diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index 733f47583a7f..ed0988967e85 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -60,11 +60,9 @@ def test_get_pcollection_output(self): - type: Create config: elements: [0, 1, 3, 4] - - type: PyMap + - type: LogForTesting name: Square input: Create - config: - fn: "lambda x: x*x" ''' scope, spec = self.get_scope_by_spec(p, spec) @@ -77,30 +75,30 @@ def test_get_pcollection_output(self): "PCollection[Square.None]", str(scope.get_pcollection("Square"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("PyMap"))) + "PCollection[Square.None]", str(scope.get_pcollection("LogForTesting"))) self.assertTrue( - scope.get_pcollection("Square") == scope.get_pcollection("PyMap")) + scope.get_pcollection("Square") == scope.get_pcollection("LogForTesting")) def test_create_ptransform(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: spec = ''' transforms: - - type: PyMap + - type: Create config: - fn: "lambda x: x*x" + elements: [1, 2, 3] ''' scope, spec = self.get_scope_by_spec(p, spec) result = scope.create_ptransform(spec['transforms'][0], []) - self.assertIsInstance(result, beam.transforms.ParDo) - self.assertEqual(result.label, 'Map(lambda x: x*x)') + self.assertIsInstance(result, beam.transforms.Create) + self.assertEqual(result.label, 'Create') result_annotations = {**result.annotations()} target_annotations = { - 'yaml_type': 'PyMap', - 'yaml_args': '{"fn": "lambda x: x*x"}', + 'yaml_type': 'Create', + 'yaml_args': '{"elements": [1, 2, 3]}', 'yaml_provider': '{"type": "InlineProvider"}' } @@ -110,34 +108,6 @@ def test_create_ptransform(self): **result_annotations, **target_annotations }) - def test_create_ptransform_with_inputs(self): - with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( - pickle_library='cloudpickle')) as p: - spec = ''' - transforms: - - type: PyMap - config: - fn: "lambda x: x*x" - ''' - scope, spec = self.get_scope_by_spec(p, spec) - - result = scope.create_ptransform(spec['transforms'][0], []) - self.assertIsInstance(result, beam.transforms.ParDo) - self.assertEqual(result.label, 'Map(lambda x: x*x)') - - result_annotations = { - key: value - for (key, value) in result.annotations().items() - if key.startswith('yaml') - } - target_annotations = { - 'yaml_type': 'PyMap', - 'yaml_args': '{"fn": "lambda x: x*x"}', - 'yaml_provider': '{"type": "InlineProvider"}' - } - self.assertDictEqual(result_annotations, target_annotations) - - class TestProvider(yaml_provider.InlineProvider): def __init__(self, transform, name): super().__init__({ diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index 136722052c33..118b399b772f 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -25,10 +25,52 @@ import apache_beam as beam from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.utils import python_callable from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_transform import 
YamlTransform +class CreateTimestamped(beam.PTransform): + def __init__(self, elements): + self._elements = elements + + def expand(self, p): + return ( + p + | beam.Create(self._elements) + | beam.Map(lambda x: beam.transforms.window.TimestampedValue(x, x))) + + +class SumGlobally(beam.PTransform): + def expand(self, pcoll): + return pcoll | beam.CombineGlobally(sum).without_defaults() + + +class SizeLimiter(beam.PTransform): + def __init__(self, limit, error_handling): + self._limit = limit + self._error_handling = error_handling + + def expand(self, pcoll): + def raise_on_big(element): + if len(element) > self._limit: + raise ValueError(element) + else: + return element + + good, bad = pcoll | beam.Map(raise_on_big).with_exception_handling() + return {'small_elements': good, self._error_handling['output']: bad} + + +TEST_PROVIDERS = { + 'CreateInts': lambda elements: beam.Create(elements), + 'CreateTimestamped': CreateTimestamped, + 'SumGlobally': SumGlobally, + 'SizeLimiter': SizeLimiter, + 'PyMap': lambda fn: beam.Map(python_callable.PythonCallableWithSource(fn)), +} + + class YamlTransformE2ETest(unittest.TestCase): def test_composite(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( @@ -55,7 +97,8 @@ def test_composite(self): input: [Square, Cube] output: Flatten - ''') + ''', + providers=TEST_PROVIDERS) assert_that(result, equal_to([1, 4, 9, 1, 8, 27])) def test_chain_with_input(self): @@ -74,7 +117,7 @@ def test_chain_with_input(self): - type: PyMap config: fn: "lambda x: x + 41" - ''') + ''', providers=TEST_PROVIDERS) assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131])) def test_chain_with_source_sink(self): @@ -84,7 +127,7 @@ def test_chain_with_source_sink(self): ''' type: chain source: - type: Create + type: CreateInts config: elements: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] transforms: @@ -95,7 +138,7 @@ def test_chain_with_source_sink(self): type: PyMap config: fn: "lambda x: x + 41" - ''') + ''', providers=TEST_PROVIDERS) assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131])) def test_chain_with_root(self): @@ -105,7 +148,7 @@ def test_chain_with_root(self): ''' type: chain transforms: - - type: Create + - type: CreateInts config: elements: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - type: PyMap @@ -114,7 +157,7 @@ def test_chain_with_root(self): - type: PyMap config: fn: "lambda x: x + 41" - ''') + ''', providers=TEST_PROVIDERS) assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131])) def create_has_schema(self): @@ -132,10 +175,7 @@ def create_has_schema(self): language: python fields: repeated: a * b - - type: PyMap - config: - fn: "lambda x: x.repeated" - ''') + ''') | beam.Map(lambda x: x.repeated) assert_that(result, equal_to(['x', 'yy'])) def test_implicit_flatten(self): @@ -145,11 +185,11 @@ def test_implicit_flatten(self): ''' type: composite transforms: - - type: Create + - type: CreateInts name: CreateSmall config: elements: [1, 2, 3] - - type: Create + - type: CreateInts name: CreateBig config: elements: [100, 200] @@ -158,7 +198,7 @@ def test_implicit_flatten(self): config: fn: "lambda x: x * x" output: PyMap - ''') + ''', providers=TEST_PROVIDERS) assert_that(result, equal_to([1, 4, 9, 10000, 40000])) def test_csv_to_json(self): @@ -214,15 +254,15 @@ def test_name_is_not_ambiguous(self): name: Create config: elements: [0, 1, 3, 4] - - type: PyFilter - name: Filter + - type: PyMap + name: PyMap config: - keep: "lambda elem: elem > 2" + fn: "lambda elem: elem * elem" input: Create - output: 
Filter - ''') + output: PyMap + ''', providers=TEST_PROVIDERS) # No exception raised - assert_that(result, equal_to([3, 4])) + assert_that(result, equal_to([0, 1, 9, 16])) def test_name_is_ambiguous(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( @@ -237,18 +277,18 @@ def test_name_is_ambiguous(self): name: CreateData config: elements: [0, 1, 3, 4] - - type: PyFilter - name: PyFilter + - type: PyMap + name: PyMap config: - keep: "lambda elem: elem > 2" + fn: "lambda elem: elem + 2" input: CreateData - - type: PyFilter - name: AnotherFilter + - type: PyMap + name: AnotherMap config: - keep: "lambda elem: elem > 3" - input: PyFilter - output: AnotherFilter - ''') + fn: "lambda elem: elem + 3" + input: PyMap + output: AnotherMap + ''', providers=TEST_PROVIDERS) def test_annotations(self): t = LinearTransform(5, b=100) @@ -268,45 +308,6 @@ def test_annotations(self): assert_that(result, equal_to([100, 105, 110, 115])) -class CreateTimestamped(beam.PTransform): - def __init__(self, elements): - self._elements = elements - - def expand(self, p): - return ( - p - | beam.Create(self._elements) - | beam.Map(lambda x: beam.transforms.window.TimestampedValue(x, x))) - - -class SumGlobally(beam.PTransform): - def expand(self, pcoll): - return pcoll | beam.CombineGlobally(sum).without_defaults() - - -class SizeLimiter(beam.PTransform): - def __init__(self, limit, error_handling): - self._limit = limit - self._error_handling = error_handling - - def expand(self, pcoll): - def raise_on_big(element): - if len(element) > self._limit: - raise ValueError(element) - else: - return element - - good, bad = pcoll | beam.Map(raise_on_big).with_exception_handling() - return {'small_elements': good, self._error_handling['output']: bad} - - -TEST_PROVIDERS = { - 'CreateTimestamped': CreateTimestamped, - 'SumGlobally': SumGlobally, - 'SizeLimiter': SizeLimiter, -} - - class ErrorHandlingTest(unittest.TestCase): def test_error_handling_outputs(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( @@ -364,12 +365,12 @@ def test_mapping_errors(self): ''' type: composite transforms: - - type: Create + - type: CreateInts config: elements: [0, 1, 2, 4] - type: PyMap name: ToRow - input: Create + input: CreateInts config: fn: "lambda x: beam.Row(num=x, str='a' * x or 'bbb')" - type: Filter diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index d57a77d326fb..a47de306b626 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -201,12 +201,10 @@ def test_expand_composite_transform_with_name_input(self): type: composite input: elements transforms: - - type: PyMap + - type: LogForTesting input: input - config: - fn: 'lambda x: x*x' output: - PyMap + LogForTesting ''' elements = p | beam.Create(range(3)) scope, spec = self.get_scope_by_spec(p, spec, @@ -950,43 +948,6 @@ def test_ensure_errors_consumed_no_output_in_error_handling(self): with self.assertRaisesRegex(ValueError, r"Missing output.*"): ensure_errors_consumed(spec) - def test_expand_pipeline_with_string_spec(self): - with new_pipeline() as p: - spec = ''' - pipeline: - type: chain - transforms: - - type: Create - config: - elements: [1,2,3] - - type: PyMap - config: - fn: 'lambda x: x*x' - ''' - result = expand_pipeline(p, spec) - - self.assertIsInstance(result, PCollection) - self.assertEqual(str(result), 'PCollection[Map(lambda x: x*x).None]') - - def 
test_expand_pipeline_with_spec(self): - with new_pipeline() as p: - spec = ''' - pipeline: - type: chain - transforms: - - type: Create - config: - elements: [1,2,3] - - type: PyMap - config: - fn: 'lambda x: x*x' - ''' - spec = yaml.load(spec, Loader=SafeLineLoader) - result = expand_pipeline(p, spec) - - self.assertIsInstance(result, PCollection) - self.assertEqual(str(result), 'PCollection[Map(lambda x: x*x).None]') - def test_only_element(self): self.assertEqual(only_element((1, )), 1) @@ -1003,13 +964,12 @@ def test_init_with_string(self): transforms: - type: Create elements: [1,2,3] - - type: PyMap - fn: 'lambda x: x*x' + - type: LogForTesting ''' result = YamlTransform(spec, providers_dict) self.assertIn('p1', result._providers) # check for custom providers self.assertIn('p2', result._providers) # check for custom providers - self.assertIn('PyMap', result._providers) # check for standard provider + self.assertIn('LogForTesting', result._providers) # check for standard provider self.assertEqual(result._spec['type'], "composite") # preprocessed spec def test_init_with_dict(self): @@ -1019,13 +979,11 @@ def test_init_with_dict(self): - type: Create config: elements: [1,2,3] - - type: PyMap - config: - fn: 'lambda x: x*x' + - type: LogForTesting ''' spec = yaml.load(spec, Loader=SafeLineLoader) result = YamlTransform(spec) - self.assertIn('PyMap', result._providers) # check for standard provider + self.assertIn('LogForTesting', result._providers) # check for standard provider self.assertEqual(result._spec['type'], "composite") # preprocessed spec
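
For pipelines being migrated off the removed transforms, a minimal sketch of the replacement named in the commit subject, written in the same style as the updated tests above. The `MapToFields` field expressions and the element values here are illustrative, not taken from this patch; before this change the two mapping steps would have been `PyMap` and `PyFilter` configs taking Python lambdas.

```python
# Sketch: PyMap/PyFilter become MapToFields/Filter with an explicit
# `language`; expressions reference schema fields instead of lambdas.
# The pipeline options mirror the pickling setup used by the tests above.
import apache_beam as beam
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
    pickle_library='cloudpickle')) as p:
  _ = p | YamlTransform('''
      type: chain
      transforms:
        - type: Create
          config:
            elements: [{col1: 1}, {col1: 2}, {col1: 3}]
        - type: MapToFields  # was: PyMap with fn: "lambda x: x.col1 * 2"
          config:
            language: python
            fields:
              doubled: "col1 * 2"
        - type: Filter       # was: PyFilter with keep: "lambda x: x.doubled > 2"
          config:
            language: python
            keep: "doubled > 2"
      ''')
```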
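The quick "print what's flowing through" debugging that `PyMap` lambdas were often used for is now a first-class transform. `LogForTesting`, registered in `yaml_provider.py` above, takes no config: it wraps `beam.Map(log_and_return)`, logging each element via `logging.info` and passing it through unchanged, so it can be dropped anywhere in a chain without altering the data.

```python
# LogForTesting needs no config and does not change the elements.
import apache_beam as beam
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline() as p:
  _ = p | YamlTransform('''
      type: chain
      transforms:
        - type: Create
          config:
            elements: [1, 2, 3]
        - type: LogForTesting
      ''')
```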
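The `YamlTransform.__init__` change above also means callers can hand in a plain `name -> transform` dict, which is normalized through `yaml_provider.as_provider_list`; the updated tests rely on exactly this when passing `providers=TEST_PROVIDERS`. A sketch under that assumption, with a hypothetical `MyDoubler` transform:

```python
# MyDoubler is hypothetical; any PTransform class (or constructor
# callable) works as a dict value and is wrapped into an inline
# provider list by YamlTransform itself.
import apache_beam as beam
from apache_beam.yaml.yaml_transform import YamlTransform


class MyDoubler(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | beam.Map(lambda x: x * 2)


with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
    pickle_library='cloudpickle')) as p:
  _ = p | YamlTransform(
      '''
      type: chain
      transforms:
        - type: Create
          config:
            elements: [1, 2, 3]
        - type: MyDoubler
      ''',
      providers={'MyDoubler': MyDoubler})
```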