Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[yaml] Add use cases for Enrichment transform in YAML #32289

Merged
merged 15 commits into from
Dec 16, 2024
7 changes: 6 additions & 1 deletion sdks/python/apache_beam/io/gcp/bigquery_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,12 @@ class BigQueryWrapper(object):

HISTOGRAM_METRIC_LOGGER = MetricLogger()

def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):
def __init__(
self,
client=None,
temp_dataset_id=None,
temp_table_ref=None,
project_id=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we do anything with this project_id?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I remember, running the code without assigning project_id resulted in the following error:
OSError: Project ID was not passed and could not be determined from the environment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is actually getting assigned, though, and other functions take in project_id as parameters. If it is passed in to initialize the object, I think it will be ignored as written

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. Should I go ahead and remove it, then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, lets go ahead and get rid of it - thanks

self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions())
self.gcp_bq_client = client or gcp_bigquery.Client(
client_info=ClientInfo(
Expand Down
168 changes: 161 additions & 7 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import os
import random
import unittest
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from unittest import mock

Expand All @@ -34,11 +36,63 @@
from apache_beam.examples.snippets.util import assert_matches_stdout
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
from apache_beam.yaml.readme_test import TestEnvironment
from apache_beam.yaml.readme_test import replace_recursive


# Used to simulate Enrichment transform during tests
# The GitHub action that invokes these tests does not
# have gcp dependencies installed which is a prerequisite
# to apache_beam.transforms.enrichment.Enrichment as a top-level
# import.
@beam.ptransform.ptransform_fn
def test_enrichment(
pcoll,
enrichment_handler: str,
handler_config: Dict[str, Any],
timeout: Optional[float] = 30):
if enrichment_handler == 'BigTable':
row_key = handler_config['row_key']
bt_data = INPUT_TABLES[(
'BigTable', handler_config['instance_id'], handler_config['table_id'])]
products = {str(data[row_key]): data for data in bt_data}

def _fn(row):
left = row._asdict()
right = products[str(row[row_key])]
left['product'] = left.get('product', None) or right
return beam.Row(**row)
reeba212 marked this conversation as resolved.
Show resolved Hide resolved
elif enrichment_handler == 'BigQuery':
row_key = handler_config['fields']
dataset, table = handler_config['table_name'].split('.')[-2:]
bq_data = INPUT_TABLES[('BigQuery', str(dataset), str(table))]
products = {
tuple(str(data[key]) for key in row_key): data
for data in bq_data
}

def _fn(row):
left = row._asdict()
right = products[tuple(str(left[k]) for k in row_key)]
row = {
key: left.get(key, None) or right[key]
for key in {*left.keys(), *right.keys()}
}
return beam.Row(**row)
reeba212 marked this conversation as resolved.
Show resolved Hide resolved

else:
raise ValueError(f'{enrichment_handler} is not a valid enrichment_handler.')

return pcoll | beam.Map(_fn)


TEST_PROVIDERS = {
'TestEnrichment': test_enrichment,
}


def check_output(expected: List[str]):
def _check_inner(actual: List[PCollection[str]]):
formatted_actual = actual | beam.Flatten() | beam.Map(
Expand All @@ -59,7 +113,31 @@ def products_csv():
])


def spanner_data():
def spanner_orders_data():
return [{
'order_id': 1,
'customer_id': 1001,
'product_id': 2001,
'order_date': '24-03-24',
'order_amount': 150,
},
{
'order_id': 2,
'customer_id': 1002,
'product_id': 2002,
'order_date': '19-04-24',
'order_amount': 90,
},
{
'order_id': 3,
'customer_id': 1003,
'product_id': 2003,
'order_date': '7-05-24',
'order_amount': 110,
}]


def spanner_shipments_data():
return [{
'shipment_id': 'S1',
'customer_id': 'C1',
Expand Down Expand Up @@ -110,6 +188,44 @@ def spanner_data():
}]


def bigtable_data():
return [{
'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'
}, {
'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'
}, {
'product_id': '3', 'product_name': 'pixel 7', 'product_stock': '20'
}, {
'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'
}, {
'product_id': '5', 'product_name': 'pixel 11', 'product_stock': '3'
}, {
'product_id': '6', 'product_name': 'pixel 12', 'product_stock': '7'
}, {
'product_id': '7', 'product_name': 'pixel 13', 'product_stock': '8'
}, {
'product_id': '8', 'product_name': 'pixel 14', 'product_stock': '3'
}]


def bigquery_data():
return [{
'customer_id': 1001,
'customer_name': 'Alice',
'customer_email': '[email protected]'
},
{
'customer_id': 1002,
'customer_name': 'Bob',
'customer_email': '[email protected]'
},
{
'customer_id': 1003,
'customer_name': 'Claire',
'customer_email': '[email protected]'
}]


def create_test_method(
pipeline_spec_file: str,
custom_preprocessors: List[Callable[..., Union[Dict, List]]]):
Expand All @@ -135,7 +251,11 @@ def test_yaml_example(self):
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
'options', {})))) as p:
actual = [yaml_transform.expand_pipeline(p, pipeline_spec)]
actual = [
yaml_transform.expand_pipeline(
p,
pipeline_spec, [yaml_provider.InlineProvider(TEST_PROVIDERS)])
]
if not actual[0]:
actual = list(p.transforms_stack[0].parts[-1].outputs.values())
for transform in p.transforms_stack[0].parts[:-1]:
Expand Down Expand Up @@ -213,7 +333,8 @@ def _wordcount_test_preprocessor(
'test_simple_filter_yaml',
'test_simple_filter_and_combine_yaml',
'test_spanner_read_yaml',
'test_spanner_write_yaml'
'test_spanner_write_yaml',
'test_enrich_spanner_with_bigquery_yaml'
])
def _io_write_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
Expand Down Expand Up @@ -249,7 +370,8 @@ def _file_io_read_test_preprocessor(
return test_spec


@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml'])
@YamlExamplesTestSuite.register_test_preprocessor(
['test_spanner_read_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
def _spanner_io_read_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):

Expand All @@ -265,14 +387,42 @@ def _spanner_io_read_test_preprocessor(
k: v
for k, v in config.items() if k.startswith('__')
}
transform['config']['elements'] = INPUT_TABLES[(
str(instance), str(database), str(table))]
elements = INPUT_TABLES[(str(instance), str(database), str(table))]
if config.get('query', None):
config['query'].replace('select ',
'SELECT ').replace(' from ', ' FROM ')
columns = set(
''.join(config['query'].split('SELECT ')[1:]).split(
' FROM', maxsplit=1)[0])
reeba212 marked this conversation as resolved.
Show resolved Hide resolved
if columns != {'*'}:
elements = [{
column: element[column]
for column in element if column in columns
} for element in elements]
transform['config']['elements'] = elements

return test_spec


@YamlExamplesTestSuite.register_test_preprocessor(
['test_bigtable_enrichment_yaml', 'test_enrich_spanner_with_bigquery_yaml'])
def _enrichment_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '').startswith('Enrichment'):
transform['type'] = 'TestEnrichment'

return test_spec


INPUT_FILES = {'products.csv': products_csv()}
INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()}
INPUT_TABLES = {
('shipment-test', 'shipment', 'shipments'): spanner_shipments_data(),
('orders-test', 'order-database', 'orders'): spanner_orders_data(),
('BigTable', 'beam-test', 'bigtable-enrichment-test'): bigtable_data(),
('BigQuery', 'ALL_TEST', 'customers'): bigquery_data()
}

YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__))
ExamplesTest = YamlExamplesTestSuite(
Expand All @@ -290,6 +440,10 @@ def _spanner_io_read_test_preprocessor(
'IOExamplesTest', os.path.join(YAML_DOCS_DIR,
'../transforms/io/*.yaml')).run()

MLTest = YamlExamplesTestSuite(
'MLExamplesTest', os.path.join(YAML_DOCS_DIR,
'../transforms/ml/*.yaml')).run()

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

pipeline:
type: chain
transforms:

# Step 1: Creating a collection of elements that needs
# to be enriched. Here we are simulating sales data
- type: Create
config:
elements:
- sale_id: 1
customer_id: 1
product_id: 1
quantity: 1

# Step 2: Enriching the data with Bigtable
# This specific bigtable stores product data in the below format
# product:product_id, product:product_name, product:product_stock
- type: Enrichment
config:
enrichment_handler: 'BigTable'
handler_config:
project_id: 'apache-beam-testing'
instance_id: 'beam-test'
table_id: 'bigtable-enrichment-test'
row_key: 'product_id'
timeout: 30

# Step 3: Logging for testing
# This is a simple way to view the enriched data
# We can also store it somewhere like a json file
- type: LogForTesting

options:
yaml_experimental_features: Enrichment

# Expected:
# Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
Loading
Loading