diff --git a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py index 3b497ed1efab..109e98410852 100644 --- a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py +++ b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py @@ -21,9 +21,11 @@ import os import random import unittest +from typing import Any from typing import Callable from typing import Dict from typing import List +from typing import Optional from typing import Union from unittest import mock @@ -34,11 +36,63 @@ from apache_beam.examples.snippets.util import assert_matches_stdout from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.yaml import yaml_provider from apache_beam.yaml import yaml_transform from apache_beam.yaml.readme_test import TestEnvironment from apache_beam.yaml.readme_test import replace_recursive +# Used to simulate Enrichment transform during tests +# The GitHub action that invokes these tests does not +# have gcp dependencies installed which is a prerequisite +# to apache_beam.transforms.enrichment.Enrichment as a top-level +# import. 
+@beam.ptransform.ptransform_fn +def test_enrichment( + pcoll, + enrichment_handler: str, + handler_config: Dict[str, Any], + timeout: Optional[float] = 30): + if enrichment_handler == 'BigTable': + row_key = handler_config['row_key'] + bt_data = INPUT_TABLES[( + 'BigTable', handler_config['instance_id'], handler_config['table_id'])] + products = {str(data[row_key]): data for data in bt_data} + + def _fn(row): + left = row._asdict() + right = products[str(left[row_key])] + left['product'] = left.get('product', None) or right + return beam.Row(**left) + elif enrichment_handler == 'BigQuery': + row_key = handler_config['fields'] + dataset, table = handler_config['table_name'].split('.')[-2:] + bq_data = INPUT_TABLES[('BigQuery', str(dataset), str(table))] + bq_data = { + tuple(str(data[key]) for key in row_key): data + for data in bq_data + } + + def _fn(row): + left = row._asdict() + right = bq_data[tuple(str(left[k]) for k in row_key)] + row = { + key: left.get(key, None) or right[key] + for key in {*left.keys(), *right.keys()} + } + return beam.Row(**row) + + else: + raise ValueError(f'{enrichment_handler} is not a valid enrichment_handler.') + + return pcoll | beam.Map(_fn) + + +TEST_PROVIDERS = { + 'TestEnrichment': test_enrichment, +} + + def check_output(expected: List[str]): def _check_inner(actual: List[PCollection[str]]): formatted_actual = actual | beam.Flatten() | beam.Map( @@ -59,7 +113,31 @@ def products_csv(): ]) -def spanner_data(): +def spanner_orders_data(): + return [{ + 'order_id': 1, + 'customer_id': 1001, + 'product_id': 2001, + 'order_date': '24-03-24', + 'order_amount': 150, + }, + { + 'order_id': 2, + 'customer_id': 1002, + 'product_id': 2002, + 'order_date': '19-04-24', + 'order_amount': 90, + }, + { + 'order_id': 3, + 'customer_id': 1003, + 'product_id': 2003, + 'order_date': '7-05-24', + 'order_amount': 110, + }] + + +def spanner_shipments_data(): return [{ 'shipment_id': 'S1', 'customer_id': 'C1', @@ -110,6 +188,44 @@ def spanner_data(): 
}] +def bigtable_data(): + return [{ + 'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2' + }, { + 'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4' + }, { + 'product_id': '3', 'product_name': 'pixel 7', 'product_stock': '20' + }, { + 'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10' + }, { + 'product_id': '5', 'product_name': 'pixel 11', 'product_stock': '3' + }, { + 'product_id': '6', 'product_name': 'pixel 12', 'product_stock': '7' + }, { + 'product_id': '7', 'product_name': 'pixel 13', 'product_stock': '8' + }, { + 'product_id': '8', 'product_name': 'pixel 14', 'product_stock': '3' + }] + + +def bigquery_data(): + return [{ + 'customer_id': 1001, + 'customer_name': 'Alice', + 'customer_email': 'alice@gmail.com' + }, + { + 'customer_id': 1002, + 'customer_name': 'Bob', + 'customer_email': 'bob@gmail.com' + }, + { + 'customer_id': 1003, + 'customer_name': 'Claire', + 'customer_email': 'claire@gmail.com' + }] + + def create_test_method( pipeline_spec_file: str, custom_preprocessors: List[Callable[..., Union[Dict, List]]]): @@ -135,7 +251,11 @@ def test_yaml_example(self): pickle_library='cloudpickle', **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( 'options', {})))) as p: - actual = [yaml_transform.expand_pipeline(p, pipeline_spec)] + actual = [ + yaml_transform.expand_pipeline( + p, + pipeline_spec, [yaml_provider.InlineProvider(TEST_PROVIDERS)]) + ] if not actual[0]: actual = list(p.transforms_stack[0].parts[-1].outputs.values()) for transform in p.transforms_stack[0].parts[:-1]: @@ -213,7 +333,8 @@ def _wordcount_test_preprocessor( 'test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml', 'test_spanner_read_yaml', - 'test_spanner_write_yaml' + 'test_spanner_write_yaml', + 'test_enrich_spanner_with_bigquery_yaml' ]) def _io_write_test_preprocessor( test_spec: dict, expected: List[str], env: TestEnvironment): @@ -249,7 +370,8 @@ def _file_io_read_test_preprocessor( return test_spec 
-@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml'])
+@YamlExamplesTestSuite.register_test_preprocessor(
+    ['test_spanner_read_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
 def _spanner_io_read_test_preprocessor(
     test_spec: dict, expected: List[str], env: TestEnvironment):
@@ -265,14 +387,42 @@ def _spanner_io_read_test_preprocessor(
           k: v
           for k, v in config.items() if k.startswith('__')
       }
-      transform['config']['elements'] = INPUT_TABLES[(
-          str(instance), str(database), str(table))]
+      elements = INPUT_TABLES[(str(instance), str(database), str(table))]
+      if config.get('query', None):
+        config['query'] = config['query'].replace(
+            'select ', 'SELECT ').replace(' from ', ' FROM ')
+        columns = set(
+            ''.join(config['query'].split('SELECT ')[1:]).split(
+                ' FROM', maxsplit=1)[0].split(', '))
+        if columns != {'*'}:
+          elements = [{
+              column: element[column]
+              for column in element if column in columns
+          } for element in elements]
+      transform['config']['elements'] = elements
+
+  return test_spec
+
+
+@YamlExamplesTestSuite.register_test_preprocessor(
+    ['test_bigtable_enrichment_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
+def _enrichment_test_preprocessor(
+    test_spec: dict, expected: List[str], env: TestEnvironment):
+  if pipeline := test_spec.get('pipeline', None):
+    for transform in pipeline.get('transforms', []):
+      if transform.get('type', '').startswith('Enrichment'):
+        transform['type'] = 'TestEnrichment'
 
   return test_spec
 
 
 INPUT_FILES = {'products.csv': products_csv()}
-INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()}
+INPUT_TABLES = {
+    ('shipment-test', 'shipment', 'shipments'): spanner_shipments_data(),
+    ('orders-test', 'order-database', 'orders'): spanner_orders_data(),
+    ('BigTable', 'beam-test', 'bigtable-enrichment-test'): bigtable_data(),
+    ('BigQuery', 'ALL_TEST', 'customers'): bigquery_data()
+}
 
 YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__))
 
 ExamplesTest = YamlExamplesTestSuite(
@@ -290,6 +440,10 @@ def
_spanner_io_read_test_preprocessor( 'IOExamplesTest', os.path.join(YAML_DOCS_DIR, '../transforms/io/*.yaml')).run() +MLTest = YamlExamplesTestSuite( + 'MLExamplesTest', os.path.join(YAML_DOCS_DIR, + '../transforms/ml/*.yaml')).run() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/yaml/examples/transforms/ml/bigtable_enrichment.yaml b/sdks/python/apache_beam/yaml/examples/transforms/ml/bigtable_enrichment.yaml new file mode 100644 index 000000000000..788b69de7857 --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/ml/bigtable_enrichment.yaml @@ -0,0 +1,55 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipeline: + type: chain + transforms: + + # Step 1: Creating a collection of elements that needs + # to be enriched. 
Here we are simulating sales data + - type: Create + config: + elements: + - sale_id: 1 + customer_id: 1 + product_id: 1 + quantity: 1 + + # Step 2: Enriching the data with Bigtable + # This specific bigtable stores product data in the below format + # product:product_id, product:product_name, product:product_stock + - type: Enrichment + config: + enrichment_handler: 'BigTable' + handler_config: + project_id: 'apache-beam-testing' + instance_id: 'beam-test' + table_id: 'bigtable-enrichment-test' + row_key: 'product_id' + timeout: 30 + + # Step 3: Logging for testing + # This is a simple way to view the enriched data + # We can also store it somewhere like a json file + - type: LogForTesting + +options: + yaml_experimental_features: Enrichment + +# Expected: +# Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/examples/transforms/ml/enrich_spanner_with_bigquery.yaml b/sdks/python/apache_beam/yaml/examples/transforms/ml/enrich_spanner_with_bigquery.yaml new file mode 100644 index 000000000000..e63b3105cc0c --- /dev/null +++ b/sdks/python/apache_beam/yaml/examples/transforms/ml/enrich_spanner_with_bigquery.yaml @@ -0,0 +1,102 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+pipeline:
+  transforms:
+    # Step 1: Read orders details from Spanner
+    - type: ReadFromSpanner
+      name: ReadOrders
+      config:
+        project_id: 'apache-beam-testing'
+        instance_id: 'orders-test'
+        database_id: 'order-database'
+        query: 'SELECT customer_id, product_id, order_date, order_amount FROM orders'
+
+    # Step 2: Enrich order details with customer details from BigQuery
+    - type: Enrichment
+      name: Enriched
+      input: ReadOrders
+      config:
+        enrichment_handler: 'BigQuery'
+        handler_config:
+          project: "apache-beam-testing"
+          table_name: "apache-beam-testing.ALL_TEST.customers"
+          row_restriction_template: "customer_id = 1001 or customer_id = 1003"
+          fields: ["customer_id"]
+
+    # Step 3: Map enriched values to Beam schema
+    # TODO: This should be removed when schema'd enrichment is available
+    - type: MapToFields
+      name: MapEnrichedValues
+      input: Enriched
+      config:
+        language: python
+        fields:
+          customer_id:
+            callable: 'lambda x: x.customer_id'
+            output_type: integer
+          customer_name:
+            callable: 'lambda x: x.customer_name'
+            output_type: string
+          customer_email:
+            callable: 'lambda x: x.customer_email'
+            output_type: string
+          product_id:
+            callable: 'lambda x: x.product_id'
+            output_type: integer
+          order_date:
+            callable: 'lambda x: x.order_date'
+            output_type: string
+          order_amount:
+            callable: 'lambda x: x.order_amount'
+            output_type: integer
+
+    # Step 4: Filter orders with amount greater than 110
+    - type: Filter
+      name: FilterHighValueOrders
+      input: MapEnrichedValues
+      config:
+        keep: "order_amount > 110"
+        language: "python"
+
+
+    # Step 5: Write processed order to another spanner table
+    # Note: Make sure to replace $VARS with your values.
+    - type: WriteToSpanner
+      name: WriteProcessedOrders
+      input: FilterHighValueOrders
+      config:
+        project_id: '$PROJECT'
+        instance_id: '$INSTANCE'
+        database_id: '$DATABASE'
+        table_id: '$TABLE'
+        error_handling:
+          output: my_error_output
+
+    # Step 6: Handle write errors by writing to JSON
+    - type: WriteToJson
+      name: WriteErrorsToJson
+      input: WriteProcessedOrders.my_error_output
+      config:
+        path: 'errors.json'
+
+options:
+  yaml_experimental_features: Enrichment
+
+# Expected:
+# Row(customer_id=1001, customer_name='Alice', customer_email='alice@gmail.com', product_id=2001, order_date='24-03-24', order_amount=150)