diff --git a/sdks/python/apache_beam/yaml/integration_tests.py b/sdks/python/apache_beam/yaml/integration_tests.py index aef0a3be4ad1..2b862e014bb4 100644 --- a/sdks/python/apache_beam/yaml/integration_tests.py +++ b/sdks/python/apache_beam/yaml/integration_tests.py @@ -92,12 +92,19 @@ def provider_sets(spec, require_available=False): """For transforms that are vended by multiple providers, yields all possible combinations of providers to use. """ - all_transform_types = set.union( - *( - set( - transform_types( - yaml_transform.preprocess(copy.deepcopy(p['pipeline'])))) - for p in spec['pipelines'])) + try: + for p in spec['pipelines']: + _ = yaml_transform.preprocess(copy.deepcopy(p['pipeline'])) + except Exception as exn: + print(exn) + all_transform_types = [] + else: + all_transform_types = set.union( + *( + set( + transform_types( + yaml_transform.preprocess(copy.deepcopy(p['pipeline'])))) + for p in spec['pipelines'])) def filter_to_available(t, providers): if require_available: diff --git a/sdks/python/apache_beam/yaml/tests/join.yaml b/sdks/python/apache_beam/yaml/tests/join.yaml new file mode 100644 index 000000000000..67013efe2295 --- /dev/null +++ b/sdks/python/apache_beam/yaml/tests/join.yaml @@ -0,0 +1,186 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the# Row(word='License'); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an# Row(word='AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +pipelines: + + - pipeline: + transforms: + - type: Create + name: A + config: + elements: + - {common: "x", a: 1} + - {common: "y", a: 2} + - {common: "z", a: 3} + + - type: Create + name: B + config: + elements: + - {common: "x", b: 10, other: "t"} + - {common: "y", b: 20, other: "u"} + - {common: "z", b: 30, other: "v"} + + - type: Create + name: C + config: + elements: + - {other: "t", c: 100} + - {other: "u", c: 200} + + - type: Join + input: + A: A + B: B + C: C + config: + type: inner + equalities: + - B: other + C: other + - A: common + B: common + fields: + A: [a] + B: [b] + C: [c] + + - type: AssertEqual + input: Join + config: + elements: + - {a: 1, b: 10, c: 100} + - {a: 2, b: 20, c: 200} + + + - pipeline: + transforms: + - type: Create + name: A + config: + elements: + - {common: "x", a: 1} + - {common: "y", a: 2} + - {common: "z", a: 3} + + - type: Create + name: B + config: + elements: + - {common: "x", b: 10} + - {common: "y", b: 20} + - {common: "z", b: 30} + + - type: Create + name: C + config: + elements: + - {common: "x", c: 100} + - {common: "y", c: 200} + + - type: Join + name: InnerJoin + input: + A: A + B: B + C: C + config: + type: inner + equalities: + - A: common + B: common + C: common + fields: + A: [common, a] + B: [b] + C: {c: c, common_c: common} + + - type: AssertEqual + input: InnerJoin + config: + elements: + - {common: "x", a: 1, b: 10, c: 100, common_c: "x"} + - {common: "y", a: 2, b: 20, c: 200, common_c: "y"} + + - type: Join + name: OuterJoin + input: + A: A + B: B + C: C + config: + type: outer + equalities: + - A: common + B: common + C: common + fields: + A: [common, a] + B: [b] + C: {c: c, common_c: common} + + - type: AssertEqual + input: OuterJoin + config: + elements: + - {common: "x", a: 1, b: 10, c: 100, common_c: "x"} + - {common: "y", a: 2, b: 20, c: 200, common_c: "y"} + - {common: "z", a: 3, b: 30, c: null, common_c: null} + + - type: Join + name: LeftJoin + input: + A: A + C: C + config: + type: left + equalities: + - A: common + C: common + fields: + A: [a] + C: [c] + + - type: AssertEqual + input: LeftJoin + config: + elements: + - {a: 1, c: 100} + - {a: 2, c: 200} + - {a: 3, c: null} + + - type: Join + name: RightJoin + input: + A: A + C: C + config: + type: right + equalities: + - A: common + C: common + fields: + A: [a] + C: [c] + + - type: AssertEqual + input: RightJoin + config: + elements: + - {a: 1, c: 100} + - {a: 2, c: 200} + diff --git a/sdks/python/apache_beam/yaml/yaml_join.py b/sdks/python/apache_beam/yaml/yaml_join.py new file mode 100644 index 000000000000..0b060b6a0ca8 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_join.py @@ -0,0 +1,281 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This module defines the Join operation.""" +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +import apache_beam as beam +from apache_beam.yaml import yaml_provider + + +def _validate_input(pcolls): + error_prefix = f'Invalid input {pcolls} specified.' + if not isinstance(pcolls, dict): + raise ValueError(f'{error_prefix} It must be a dict.') + if len(pcolls) < 2: + raise ValueError( + f'{error_prefix} There should be at least 2 inputs to join.') + + +def _validate_type(type, pcolls): + error_prefix = f'Invalid value "{type}" for "type".' + if not isinstance(type, dict) and not isinstance(type, str): + raise ValueError(f'{error_prefix} It must be a dict or a str.') + if isinstance(type, dict): + error = ValueError( + f'{error_prefix} When specifying a dict for type, ' + f'it must follow this format: ' + f'{{"outer": [list of inputs to outer join]}}. ' + f'Example: {{"outer": ["input1", "input2"]}}') + if (len(type) != 1 or next(iter(type)) != 'outer' or + not isinstance(type['outer'], list)): + raise error + for input in type['outer']: + if input not in list(pcolls.keys()): + raise ValueError( + f'{error_prefix} An invalid input "{input}" was specified.') + if isinstance(type, str) and type not in ('inner', 'outer', 'left', 'right'): + raise ValueError( + f'{error_prefix} When specifying the value for type as a str, ' + f'it must be one of the following: "inner", "outer", "left", "right"') + + +def _validate_equalities(equalities, pcolls): + error_prefix = f'Invalid value "{equalities}" for "equalities".' + + valid_cols = { + name: set(dict(pcoll.element_type._fields).keys()) + for name, + pcoll in pcolls.items() + } + + if isinstance(equalities, str): + for cols in valid_cols.values(): + if equalities not in cols: + raise ValueError( + f'{error_prefix} When "equalities" is a str, ' + f'it must be a field name that exists in all the specified inputs.') + equality = {pcoll_tag: equalities for pcoll_tag in pcolls} + return [equality] + + if not isinstance(equalities, list): + raise ValueError(f'{error_prefix} It should be a str or a list.') + + input_edge_list = [] + for equality in equalities: + invalid_dict_error = ValueError( + f'{error_prefix} {equality} ' + f'should be a dict[str, str] containing at least 2 items.') + if not isinstance(equality, dict): + raise invalid_dict_error + if len(equality) < 2: + raise invalid_dict_error + + for pcoll_tag, col in equality.items(): + if pcoll_tag not in pcolls: + raise ValueError( + f'{error_prefix} "{pcoll_tag}" is not a specified alias in "input"') + if col not in valid_cols[pcoll_tag]: + raise ValueError( + f'{error_prefix} "{col}" is not a valid field in "{pcoll_tag}".') + + input_edge_list.append(tuple(equality.keys())) + + if not _is_connected(input_edge_list, len(pcolls)): + raise ValueError( + f'{error_prefix} ' + f'The provided equalities do not connect all of {list(pcolls.keys())}.') + + return equalities + + +def _parse_fields(tables, fields): + error_prefix = f'Invalid value "{fields}" for "fields".' + if not isinstance(fields, dict): + raise ValueError(f'{error_prefix} Fields must be a dict.') + output_fields = [] + named_columns = set() + for input, cols in fields.items(): + if input not in tables: + raise ValueError(f'An invalid input "{input}" was specified in "fields".') + if isinstance(cols, list): + for col in cols: + if not isinstance(col, str): + raise ValueError( + f'Invalid column "{col}" in "fields". Column name must be a str.') + if col in named_columns: + raise ValueError( + f'The field name "{col}" was specified more than once.') + output_fields.append(f'{input}.{col} AS {col}') + named_columns.add(col) + elif isinstance(cols, dict): + for k, v in cols.items(): + if k in named_columns: + raise ValueError( + f'The field name "{k}" was specified more than once.') + if not isinstance(v, str): + raise ValueError( + f'Invalid column "{v}" in "fields". Column name must be a str.') + output_fields.append(f'{input}.{v} AS {k}') + named_columns.add(k) + else: + raise ValueError( + f'{error_prefix} ' + f'For every input key in fields, ' + f'the value must either be a list or dict.') + for table in tables: + if table not in fields.keys(): + output_fields.append(f'{table}.*') + return output_fields + + +def _is_connected(edge_list, expected_node_count): + graph = {} + for edge_set in edge_list: + for u in edge_set: + if u not in graph: + graph[u] = set() + for v in edge_set: + if u != v: + graph[u].add(v) + + visited = set() + stack = [next(iter(graph))] + while stack: + node = stack.pop() + visited.add(node) + for neighbor in graph[node]: + if neighbor not in visited: + stack.append(neighbor) + + return len(visited) == len(graph) == expected_node_count + + +@beam.ptransform.ptransform_fn +def _SqlJoinTransform( + pcolls, + sql_transform_constructor, + type: Union[str, Dict[str, List]], + equalities: Union[str, List[Dict[str, str]]], + fields: Optional[Dict[str, Any]] = None): + """Joins two or more inputs using a specified condition. + + Args: + type: The type of join. Could be a string value in + ["inner", "left", "right", "outer"] that specifies the type of join to + be performed. For scenarios with multiple inputs to join where different + join types are desired, specify the inputs to be outer joined. For + example, {outer: [input1, input2]} means that input1 & input2 will be + outer joined using the conditions specified, while other inputs will be + inner joined. + equalities: The condition to join on. A list of sets of columns that should + be equal to fulfill the join condition. For the simple scenario to join + on the same column across all inputs and the column name is the same, + specify the column name as a str. + fields: The fields to be outputted. A mapping with the input alias as the + key and the fields in the input to be outputted. The value in the map + can either be a dictionary with the new field name as the key and the + original field name as the value (e.g new_field_name: field_name), or a + list of the fields to be outputted with their original names + (e.g [col1, col2, col3]), or an '*' indicating all fields in the input + will be outputted. If not specified, all fields from all inputs will be + outputted. + """ + + _validate_input(pcolls) + _validate_type(type, pcolls) + validate_equalities = _validate_equalities(equalities, pcolls) + + equalities_in_pairs = [] + for equality in validate_equalities: + inputs = list(equality.keys()) + first_input = inputs[0] + for input in inputs[1:]: + equalities_in_pairs.append({ + first_input: equality[first_input], input: equality[input] + }) + + tables = list(pcolls.keys()) + if isinstance(type, dict): + outer = type['outer'] + elif type == 'outer': + outer = tables + else: + outer = [] + first_table = tables[0] + conditioned = [first_table] + + def generate_join_type(left, right): + if left in outer and right in outer: + return 'FULL' + if left in outer: + return 'LEFT' + if right in outer: + return 'RIGHT' + if not outer: + return type.upper() + return 'INNER' + + prev_table = tables[0] + join_conditions = {} + for i in range(1, len(tables)): + curr_table = tables[i] + join_type = generate_join_type(prev_table, curr_table) + join_conditions[curr_table] = f' {join_type} JOIN {curr_table}' + prev_table = curr_table + + for equality in equalities_in_pairs: + left, right = equality.keys() + if left in conditioned and right in conditioned: + t = tables[max(tables.index(left), tables.index(right))] + join_conditions[t] = ( + f'{join_conditions[t]} ' + f'AND {left}.{equality[left]} = {right}.{equality[right]}') + elif left in conditioned: + join_conditions[right] = ( + f'{join_conditions[right]} ' + f'ON {left}.{equality[left]} = {right}.{equality[right]}') + conditioned.append(right) + elif right in conditioned: + join_conditions[left] = ( + f'{join_conditions[left]} ' + f'ON {left}.{equality[left]} = {right}.{equality[right]}') + conditioned.append(left) + else: + t = tables[max(tables.index(left), tables.index(right))] + join_conditions[t] = ( + f'{join_conditions[t]} ' + f'ON {left}.{equality[left]} = {right}.{equality[right]}') + conditioned.append(t) + + if fields: + selects = ', '.join(_parse_fields(tables, fields)) + else: + selects = '*' + query = f'SELECT {selects} FROM {first_table}' + query += ' '.join(condition for condition in join_conditions.values()) + return pcolls | sql_transform_constructor(query) + + +def create_join_providers(): + return [ + yaml_provider.SqlBackedProvider({'Join': _SqlJoinTransform}), + ] diff --git a/sdks/python/apache_beam/yaml/yaml_join_test.py b/sdks/python/apache_beam/yaml/yaml_join_test.py new file mode 100644 index 000000000000..5d43b1cdb3ab --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_join_test.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import unittest + +import apache_beam as beam +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.yaml.yaml_transform import YamlTransform + + +class ToRow(beam.PTransform): + def expand(self, pcoll): + return pcoll | beam.Map(lambda row: beam.Row(**row._asdict())) + + +FRUITS = [ + beam.Row(id=1, name='raspberry'), + beam.Row(id=2, name='blackberry'), +] + +QUANTITIES = [ + beam.Row(name='raspberry', quantity=1), + beam.Row(name='blackberry', quantity=2), + beam.Row(name='blueberry', quantity=3), +] + +CATEGORIES = [ + beam.Row(name='raspberry', category='juicy'), + beam.Row(name='blackberry', category='dry'), + beam.Row(name='blueberry', category='dry'), + beam.Row(name='blueberry', category='juicy'), +] + + +@unittest.skipIf( + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is + None, + 'Do not run this test on precommit suites.') +class YamlJoinTest(unittest.TestCase): + def test_basic_join(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + fruits = p | "fruits" >> beam.Create(FRUITS) + quantities = p | "quantities" >> beam.Create(QUANTITIES) + result = { + "fruits": fruits, "quantities": quantities + } | YamlTransform( + ''' + type: Join + input: + f: fruits + q: quantities + config: + type: inner + equalities: + - f: name + q: name + ''') | ToRow() + assert_that( + result, + equal_to([ + beam.Row(id=1, name='raspberry', name0='raspberry', quantity=1), + beam.Row(id=2, name='blackberry', name0='blackberry', quantity=2) + ])) + + def test_join_three_inputs(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + fruits = p | "fruits" >> beam.Create(FRUITS) + quantities = p | "quantities" >> beam.Create(QUANTITIES) + categories = p | "categories" >> beam.Create(CATEGORIES) + result = { + "fruits": fruits, "quantities": quantities, "categories": categories + } | YamlTransform( + ''' + type: Join + input: + f: fruits + q: quantities + c: categories + config: + type: inner + equalities: + - f: name + q: name + - f: name + c: name + ''') | ToRow() + assert_that( + result, + equal_to([ + beam.Row( + id=1, + name='raspberry', + name0='raspberry', + quantity=1, + name1='raspberry', + category='juicy'), + beam.Row( + id=2, + name='blackberry', + name0='blackberry', + quantity=2, + name1='blackberry', + category='dry') + ])) + + def test_join_with_fields(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + fruits = p | "fruits" >> beam.Create(FRUITS) + quantities = p | "quantities" >> beam.Create(QUANTITIES) + categories = p | "categories" >> beam.Create(CATEGORIES) + result = { + "fruits": fruits, "quantities": quantities, "categories": categories + } | YamlTransform( + ''' + type: Join + input: + f: fruits + q: quantities + c: categories + config: + type: inner + equalities: + - f: name + q: name + - f: name + c: name + fields: + f: + f_id: id + q: + - name + - quantity + ''') | ToRow() + + assert_that( + result, + equal_to([ + beam.Row( + f_id=1, + name='raspberry', + quantity=1, + name0='raspberry', + category='juicy'), + beam.Row( + f_id=2, + name='blackberry', + quantity=2, + name0='blackberry', + category='dry') + ])) + + def test_join_with_equalities_shorthand(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + fruits = p | "fruits" >> beam.Create(FRUITS) + quantities = p | "quantities" >> beam.Create(QUANTITIES) + categories = p | "categories" >> beam.Create(CATEGORIES) + + result = { + "fruits": fruits, "quantities": quantities, "categories": categories + } | YamlTransform( + ''' + type: Join + input: + f: fruits + q: quantities + c: categories + config: + type: inner + equalities: name + ''') | ToRow() + + assert_that( + result, + equal_to([ + beam.Row( + id=1, + name='raspberry', + name0='raspberry', + quantity=1, + name1='raspberry', + category='juicy'), + beam.Row( + id=2, + name='blackberry', + name0='blackberry', + quantity=2, + name1='blackberry', + category='dry') + ])) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index dcf7ffaa6af3..5f53302028c8 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -1143,6 +1143,7 @@ def merge_providers(*provider_sets): def standard_providers(): from apache_beam.yaml.yaml_combine import create_combine_providers from apache_beam.yaml.yaml_mapping import create_mapping_providers + from apache_beam.yaml.yaml_join import create_join_providers from apache_beam.yaml.yaml_io import io_providers with open(os.path.join(os.path.dirname(__file__), 'standard_providers.yaml')) as fin: @@ -1153,5 +1154,6 @@ def standard_providers(): create_java_builtin_provider(), create_mapping_providers(), create_combine_providers(), + create_join_providers(), io_providers(), parse_providers(standard_providers))