From ab5c0695530fb73568cf4810da101a683da56874 Mon Sep 17 00:00:00 2001 From: Jeff Kinard Date: Wed, 13 Nov 2024 18:31:36 -0500 Subject: [PATCH] [yaml] Fix examples catalog tests (#33027) --- .../yaml/examples/testing/examples_test.py | 105 ++++++++++++++++-- .../{ => transforms}/io/spanner_read.yaml | 23 ++-- .../{ => transforms}/io/spanner_write.yaml | 17 +-- .../apache_beam/yaml/generate_yaml_docs.py | 2 +- sdks/python/apache_beam/yaml/yaml_errors.py | 88 +++++++++++++++ sdks/python/apache_beam/yaml/yaml_io.py | 6 +- sdks/python/apache_beam/yaml/yaml_mapping.py | 78 +------------ sdks/python/apache_beam/yaml/yaml_provider.py | 7 +- .../yaml/yaml_transform_scope_test.py | 6 +- .../yaml/yaml_transform_unit_test.py | 2 +- 10 files changed, 224 insertions(+), 110 deletions(-) rename sdks/python/apache_beam/yaml/examples/{ => transforms}/io/spanner_read.yaml (73%) rename sdks/python/apache_beam/yaml/examples/{ => transforms}/io/spanner_write.yaml (69%) create mode 100644 sdks/python/apache_beam/yaml/yaml_errors.py diff --git a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py index 6c8efac980aa..3b497ed1efab 100644 --- a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py +++ b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py @@ -40,8 +40,8 @@ def check_output(expected: List[str]): - def _check_inner(actual: PCollection[str]): - formatted_actual = actual | beam.Map( + def _check_inner(actual: List[PCollection[str]]): + formatted_actual = actual | beam.Flatten() | beam.Map( lambda row: str(beam.Row(**row._asdict()))) assert_matches_stdout(formatted_actual, expected) @@ -59,6 +59,57 @@ def products_csv(): ]) +def spanner_data(): + return [{ + 'shipment_id': 'S1', + 'customer_id': 'C1', + 'shipment_date': '2023-05-01', + 'shipment_cost': 150.0, + 'customer_name': 'Alice', + 'customer_email': 'alice@example.com' + }, + { + 'shipment_id': 'S2', + 'customer_id': 'C2', + 'shipment_date': '2023-06-12', + 'shipment_cost': 300.0, + 'customer_name': 'Bob', + 'customer_email': 'bob@example.com' + }, + { + 'shipment_id': 'S3', + 'customer_id': 'C1', + 'shipment_date': '2023-05-10', + 'shipment_cost': 20.0, + 'customer_name': 'Alice', + 'customer_email': 'alice@example.com' + }, + { + 'shipment_id': 'S4', + 'customer_id': 'C4', + 'shipment_date': '2024-07-01', + 'shipment_cost': 150.0, + 'customer_name': 'Derek', + 'customer_email': 'derek@example.com' + }, + { + 'shipment_id': 'S5', + 'customer_id': 'C5', + 'shipment_date': '2023-05-09', + 'shipment_cost': 300.0, + 'customer_name': 'Erin', + 'customer_email': 'erin@example.com' + }, + { + 'shipment_id': 'S6', + 'customer_id': 'C4', + 'shipment_date': '2024-07-02', + 'shipment_cost': 150.0, + 'customer_name': 'Derek', + 'customer_email': 'derek@example.com' + }] + + def create_test_method( pipeline_spec_file: str, custom_preprocessors: List[Callable[..., Union[Dict, List]]]): @@ -84,9 +135,12 @@ def test_yaml_example(self): pickle_library='cloudpickle', **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( 'options', {})))) as p: - actual = yaml_transform.expand_pipeline(p, pipeline_spec) - if not actual: - actual = p.transforms_stack[0].parts[-1].outputs[None] + actual = [yaml_transform.expand_pipeline(p, pipeline_spec)] + if not actual[0]: + actual = list(p.transforms_stack[0].parts[-1].outputs.values()) + for transform in p.transforms_stack[0].parts[:-1]: + if transform.transform.label == 'log_for_testing': + actual += list(transform.outputs.values()) check_output(expected)(actual) return test_yaml_example @@ -155,9 +209,13 @@ def _wordcount_test_preprocessor( env.input_file('kinglear.txt', '\n'.join(lines))) -@YamlExamplesTestSuite.register_test_preprocessor( - ['test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml']) -def _file_io_write_test_preprocessor( +@YamlExamplesTestSuite.register_test_preprocessor([ + 'test_simple_filter_yaml', + 'test_simple_filter_and_combine_yaml', + 'test_spanner_read_yaml', + 'test_spanner_write_yaml' +]) +def _io_write_test_preprocessor( test_spec: dict, expected: List[str], env: TestEnvironment): if pipeline := test_spec.get('pipeline', None): @@ -166,8 +224,8 @@ def _file_io_write_test_preprocessor( transform['type'] = 'LogForTesting' transform['config'] = { k: v - for k, - v in transform.get('config', {}).items() if k.startswith('__') + for (k, v) in transform.get('config', {}).items() + if (k.startswith('__') or k == 'error_handling') } return test_spec @@ -191,7 +249,30 @@ def _file_io_read_test_preprocessor( return test_spec +@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml']) +def _spanner_io_read_test_preprocessor( + test_spec: dict, expected: List[str], env: TestEnvironment): + + if pipeline := test_spec.get('pipeline', None): + for transform in pipeline.get('transforms', []): + if transform.get('type', '').startswith('ReadFromSpanner'): + config = transform['config'] + instance, database = config['instance_id'], config['database_id'] + if table := config.get('table', None) is None: + table = config.get('query', '').split('FROM')[-1].strip() + transform['type'] = 'Create' + transform['config'] = { + k: v + for k, v in config.items() if k.startswith('__') + } + transform['config']['elements'] = INPUT_TABLES[( + str(instance), str(database), str(table))] + + return test_spec + + INPUT_FILES = {'products.csv': products_csv()} +INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()} YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__)) ExamplesTest = YamlExamplesTestSuite( @@ -205,6 +286,10 @@ def _file_io_read_test_preprocessor( 'AggregationExamplesTest', os.path.join(YAML_DOCS_DIR, '../transforms/aggregation/*.yaml')).run() +IOTest = YamlExamplesTestSuite( + 'IOExamplesTest', os.path.join(YAML_DOCS_DIR, + '../transforms/io/*.yaml')).run() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml similarity index 73% rename from sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml rename to sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml index c86d42c1e0c6..26f68b68d931 100644 --- a/sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml +++ b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml @@ -18,10 +18,10 @@ pipeline: transforms: - # Reading data from a Spanner database. The table used here has the following columns: - # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String) - # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query - # A table with a list of columns can also be specified instead of a query + # Reading data from a Spanner database. The table used here has the following columns: + # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String) + # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query + # A table with a list of columns can also be specified instead of a query - type: ReadFromSpanner name: ReadShipments config: @@ -30,8 +30,8 @@ pipeline: database_id: 'shipment' query: 'SELECT * FROM shipments' - # Filtering the data based on a specific condition - # Here, the condition is used to keep only the rows where the customer_id is 'C1' + # Filtering the data based on a specific condition + # Here, the condition is used to keep only the rows where the customer_id is 'C1' - type: Filter name: FilterShipments input: ReadShipments @@ -39,9 +39,9 @@ pipeline: language: python keep: "customer_id == 'C1'" - # Mapping the data fields and applying transformations - # A new field 'shipment_cost_category' is added with a custom transformation - # A callable is defined to categorize shipment cost + # Mapping the data fields and applying transformations + # A new field 'shipment_cost_category' is added with a custom transformation + # A callable is defined to categorize shipment cost - type: MapToFields name: MapFieldsForSpanner input: FilterShipments @@ -65,7 +65,7 @@ pipeline: else: return 'High Cost' - # Writing the transformed data to a CSV file + # Writing the transformed data to a CSV file - type: WriteToCsv name: WriteBig input: MapFieldsForSpanner @@ -73,8 +73,7 @@ pipeline: path: shipments.csv - # On executing the above pipeline, a new CSV file is created with the following records - +# On executing the above pipeline, a new CSV file is created with the following records # Expected: # Row(shipment_id='S1', customer_id='C1', shipment_date='2023-05-01', shipment_cost=150.0, customer_name='Alice', customer_email='alice@example.com', shipment_cost_category='Medium Cost') # Row(shipment_id='S3', customer_id='C1', shipment_date='2023-05-10', shipment_cost=20.0, customer_name='Alice', customer_email='alice@example.com', shipment_cost_category='Low Cost') diff --git a/sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml similarity index 69% rename from sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml rename to sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml index 74ac35de260f..1667fcfcc163 100644 --- a/sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml +++ b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml @@ -18,8 +18,8 @@ pipeline: transforms: - # Step 1: Creating rows to be written to Spanner - # The element names correspond to the column names in the Spanner table + # Step 1: Creating rows to be written to Spanner + # The element names correspond to the column names in the Spanner table - type: Create name: CreateRows config: @@ -31,10 +31,10 @@ pipeline: customer_name: "Erin" customer_email: "erin@example.com" - # Step 2: Writing the created rows to a Spanner database - # We require the project ID, instance ID, database ID and table ID to connect to Spanner - # Error handling can be specified optionally to ensure any failed operations aren't lost - # The failed data is passed on in the pipeline and can be handled + # Step 2: Writing the created rows to a Spanner database + # We require the project ID, instance ID, database ID and table ID to connect to Spanner + # Error handling can be specified optionally to ensure any failed operations aren't lost + # The failed data is passed on in the pipeline and can be handled - type: WriteToSpanner name: WriteSpanner input: CreateRows @@ -46,8 +46,11 @@ pipeline: error_handling: output: my_error_output - # Step 3: Writing the failed records to a JSON file + # Step 3: Writing the failed records to a JSON file - type: WriteToJson input: WriteSpanner.my_error_output config: path: errors.json + +# Expected: +# Row(shipment_id='S5', customer_id='C5', shipment_date='2023-05-09', shipment_cost=300.0, customer_name='Erin', customer_email='erin@example.com') diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 4088e17afe2c..2123c7a9f202 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -30,7 +30,7 @@ from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils from apache_beam.yaml import yaml_provider -from apache_beam.yaml.yaml_mapping import ErrorHandlingConfig +from apache_beam.yaml.yaml_errors import ErrorHandlingConfig def _singular(name): diff --git a/sdks/python/apache_beam/yaml/yaml_errors.py b/sdks/python/apache_beam/yaml/yaml_errors.py new file mode 100644 index 000000000000..c0d448473f42 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_errors.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import inspect +from typing import NamedTuple + +import apache_beam as beam +from apache_beam.typehints.row_type import RowTypeConstraint + + +class ErrorHandlingConfig(NamedTuple): + """Class to define Error Handling parameters. + + Args: + output (str): Name to use for the output error collection + """ + output: str + # TODO: Other parameters are valid here too, but not common to Java. + + +def exception_handling_args(error_handling_spec): + if error_handling_spec: + return { + 'dead_letter_tag' if k == 'output' else k: v + for (k, v) in error_handling_spec.items() + } + else: + return None + + +def map_errors_to_standard_format(input_type): + # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. + + return beam.Map( + lambda x: beam.Row( + element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) + ).with_output_types( + RowTypeConstraint.from_fields([("element", input_type), ("msg", str), + ("stack", str)])) + + +def maybe_with_exception_handling(inner_expand): + def expand(self, pcoll): + wrapped_pcoll = beam.core._MaybePValueWithErrors( + pcoll, self._exception_handling_args) + return inner_expand(self, wrapped_pcoll).as_result( + map_errors_to_standard_format(pcoll.element_type)) + + return expand + + +def maybe_with_exception_handling_transform_fn(transform_fn): + @functools.wraps(transform_fn) + def expand(pcoll, error_handling=None, **kwargs): + wrapped_pcoll = beam.core._MaybePValueWithErrors( + pcoll, exception_handling_args(error_handling)) + return transform_fn(wrapped_pcoll, **kwargs).as_result( + map_errors_to_standard_format(pcoll.element_type)) + + original_signature = inspect.signature(transform_fn) + new_parameters = list(original_signature.parameters.values()) + error_handling_param = inspect.Parameter( + 'error_handling', + inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=ErrorHandlingConfig) + if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + new_parameters.insert(-1, error_handling_param) + else: + new_parameters.append(error_handling_param) + expand.__signature__ = original_signature.replace(parameters=new_parameters) + + return expand diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index 22663bdb8461..a6525aef9877 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -45,7 +45,7 @@ from apache_beam.portability.api import schema_pb2 from apache_beam.typehints import schemas from apache_beam.yaml import json_utils -from apache_beam.yaml import yaml_mapping +from apache_beam.yaml import yaml_errors from apache_beam.yaml import yaml_provider @@ -289,7 +289,7 @@ def formatter(row): @beam.ptransform_fn -@yaml_mapping.maybe_with_exception_handling_transform_fn +@yaml_errors.maybe_with_exception_handling_transform_fn def read_from_pubsub( root, *, @@ -393,7 +393,7 @@ def mapper(msg): @beam.ptransform_fn -@yaml_mapping.maybe_with_exception_handling_transform_fn +@yaml_errors.maybe_with_exception_handling_transform_fn def write_to_pubsub( pcoll, *, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 130bde75ed96..3bef1a0a1101 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -16,8 +16,6 @@ # """This module defines the basic MapToFields operation.""" -import functools -import inspect import itertools import re from collections import abc @@ -27,7 +25,6 @@ from typing import Dict from typing import List from typing import Mapping -from typing import NamedTuple from typing import Optional from typing import TypeVar from typing import Union @@ -41,7 +38,6 @@ from apache_beam.typehints import trivial_inference from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type -from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import named_fields_from_element_type from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.typehints.schemas import typing_from_runner_api @@ -49,6 +45,10 @@ from apache_beam.yaml import json_utils from apache_beam.yaml import options from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_errors import exception_handling_args +from apache_beam.yaml.yaml_errors import map_errors_to_standard_format +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn from apache_beam.yaml.yaml_provider import dicts_to_rows # Import js2py package if it exists @@ -418,71 +418,6 @@ def checking_func(row): return func -class ErrorHandlingConfig(NamedTuple): - """Class to define Error Handling parameters. - - Args: - output (str): Name to use for the output error collection - """ - output: str - # TODO: Other parameters are valid here too, but not common to Java. - - -def exception_handling_args(error_handling_spec): - if error_handling_spec: - return { - 'dead_letter_tag' if k == 'output' else k: v - for (k, v) in error_handling_spec.items() - } - else: - return None - - -def _map_errors_to_standard_format(input_type): - # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. - - return beam.Map( - lambda x: beam.Row( - element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) - ).with_output_types( - RowTypeConstraint.from_fields([("element", input_type), ("msg", str), - ("stack", str)])) - - -def maybe_with_exception_handling(inner_expand): - def expand(self, pcoll): - wrapped_pcoll = beam.core._MaybePValueWithErrors( - pcoll, self._exception_handling_args) - return inner_expand(self, wrapped_pcoll).as_result( - _map_errors_to_standard_format(pcoll.element_type)) - - return expand - - -def maybe_with_exception_handling_transform_fn(transform_fn): - @functools.wraps(transform_fn) - def expand(pcoll, error_handling=None, **kwargs): - wrapped_pcoll = beam.core._MaybePValueWithErrors( - pcoll, exception_handling_args(error_handling)) - return transform_fn(wrapped_pcoll, **kwargs).as_result( - _map_errors_to_standard_format(pcoll.element_type)) - - original_signature = inspect.signature(transform_fn) - new_parameters = list(original_signature.parameters.values()) - error_handling_param = inspect.Parameter( - 'error_handling', - inspect.Parameter.KEYWORD_ONLY, - default=None, - annotation=ErrorHandlingConfig) - if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: - new_parameters.insert(-1, error_handling_param) - else: - new_parameters.append(error_handling_param) - expand.__signature__ = original_signature.replace(parameters=new_parameters) - - return expand - - class _StripErrorMetadata(beam.PTransform): """Strips error metadata from outputs returned via error handling. @@ -845,9 +780,8 @@ def split(element): splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T) result = {out: getattr(splits, out) for out in output_set} if error_output: - result[ - error_output] = result[error_output] | _map_errors_to_standard_format( - pcoll.element_type) + result[error_output] = result[error_output] | map_errors_to_standard_format( + pcoll.element_type) return result diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index ef2316f51f0e..a07638953551 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -63,6 +63,7 @@ from apache_beam.utils import subprocess_server from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn class Provider: @@ -876,8 +877,10 @@ def _parse_window_spec(spec): return beam.WindowInto(window_fn) @staticmethod + @beam.ptransform_fn + @maybe_with_exception_handling_transform_fn def log_for_testing( - level: Optional[str] = 'INFO', prefix: Optional[str] = ''): + pcoll, *, level: Optional[str] = 'INFO', prefix: Optional[str] = ''): """Logs each element of its input PCollection. The output of this transform is a copy of its input for ease of use in @@ -918,7 +921,7 @@ def log_and_return(x): logger(prefix + json.dumps(to_loggable_json_recursive(x))) return x - return "LogForTesting" >> beam.Map(log_and_return) + return pcoll | "LogForTesting" >> beam.Map(log_and_return) @staticmethod def create_builtin_provider(): diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index f00403b07e2a..2a5a96aa42df 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -72,10 +72,12 @@ def test_get_pcollection_output(self): str(scope.get_pcollection("Create"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("Square"))) + "PCollection[Square/LogForTesting.None]", + str(scope.get_pcollection("Square"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("LogForTesting"))) + "PCollection[Square/LogForTesting.None]", + str(scope.get_pcollection("LogForTesting"))) self.assertTrue( scope.get_pcollection("Square") == scope.get_pcollection( diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 8c4b00351b24..bc0493509d5a 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -213,7 +213,7 @@ def test_expand_composite_transform_with_name_input(self): inputs={'elements': elements}) self.assertRegex( str(expand_composite_transform(spec, scope)['output']), - r"PCollection.*Composite/LogForTesting.*") + r"PCollection.*Composite/log_for_testing/LogForTesting.*") def test_expand_composite_transform_root(self): with new_pipeline() as p: