From b557cae4d86ece47e9486d3ca99cc0ce6217e936 Mon Sep 17 00:00:00 2001 From: Jeff Kinard <35542536+Polber@users.noreply.github.com> Date: Wed, 13 Sep 2023 20:53:03 -0400 Subject: [PATCH] [yaml] Add more complete UDF support to mapping transforms (#28114) --- .../apache_beam/utils/python_callable.py | 31 +-- sdks/python/apache_beam/yaml/yaml_mapping.py | 170 ++++++++---- sdks/python/apache_beam/yaml/yaml_udf_test.py | 242 ++++++++++++++++++ .../license_scripts/dep_urls_py.yaml | 2 + .../py310/base_image_requirements.txt | 3 + .../py311/base_image_requirements.txt | 3 + .../py38/base_image_requirements.txt | 4 + .../py39/base_image_requirements.txt | 3 + sdks/python/setup.py | 1 + 9 files changed, 401 insertions(+), 58 deletions(-) create mode 100644 sdks/python/apache_beam/yaml/yaml_udf_test.py diff --git a/sdks/python/apache_beam/utils/python_callable.py b/sdks/python/apache_beam/utils/python_callable.py index a7de214ec926..70aa7cb39e5c 100644 --- a/sdks/python/apache_beam/utils/python_callable.py +++ b/sdks/python/apache_beam/utils/python_callable.py @@ -77,7 +77,7 @@ def load_from_fully_qualified_name(fully_qualified_name): return o @staticmethod - def load_from_script(source): + def load_from_script(source, method_name=None): lines = [ line for line in source.split('\n') if line.strip() and line.strip()[0] != '#' @@ -85,26 +85,27 @@ def load_from_script(source): common_indent = min(len(line) - len(line.lstrip()) for line in lines) lines = [line[common_indent:] for line in lines] - for ix, line in reversed(list(enumerate(lines))): - if line[0] != ' ': - if line.startswith('def '): - name = line[4:line.index('(')].strip() - elif line.startswith('class '): - name = line[5:line.index('(') if '(' in - line else line.index(':')].strip() - else: - name = '__python_callable__' - lines[ix] = name + ' = ' + line - break - else: - raise ValueError("Unable to identify callable from %r" % source) + if method_name is None: + for ix, line in reversed(list(enumerate(lines))): + if line[0] != ' ': + if line.startswith('def '): + method_name = line[4:line.index('(')].strip() + elif line.startswith('class '): + method_name = line[5:line.index('(') if '(' in + line else line.index(':')].strip() + else: + method_name = '__python_callable__' + lines[ix] = method_name + ' = ' + line + break + else: + raise ValueError("Unable to identify callable from %r" % source) # pylint: disable=exec-used # pylint: disable=ungrouped-imports import apache_beam as beam exec_globals = {'beam': beam} exec('\n'.join(lines), exec_globals) - return exec_globals[name] + return exec_globals[method_name] def default_label(self): src = self._source.strip() diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 64c7ea726062..b6dea894b3e9 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -16,10 +16,12 @@ # """This module defines the basic MapToFields operation.""" - import itertools +import js2py + import apache_beam as beam +from apache_beam.io.filesystems import FileSystems from apache_beam.typehints import row_type from apache_beam.typehints import trivial_inference from apache_beam.typehints.schemas import named_fields_from_element_type @@ -27,33 +29,114 @@ from apache_beam.yaml import yaml_provider -def _as_callable(original_fields, expr): +def _check_mapping_arguments( + transform_name, expression=None, callable=None, name=None, path=None): + # Argument checking + if not expression and not callable and not path and not name: + raise ValueError( + f'{transform_name} must specify either "expression", "callable", ' + f'or both "path" and "name"') + if expression and callable: + raise ValueError( + f'{transform_name} cannot specify both "expression" and "callable"') + if (expression or callable) and (path or name): + raise ValueError( + f'{transform_name} cannot specify "expression" or "callable" with ' + f'"path" or "name"') + if path and not name: + raise ValueError(f'{transform_name} cannot specify "path" without "name"') + if name and not path: + raise ValueError(f'{transform_name} cannot specify "name" without "path"') + + +# js2py's JsObjectWrapper object has a self-referencing __dict__ property +# that cannot be pickled without implementing the __getstate__ and +# __setstate__ methods. +class _CustomJsObjectWrapper(js2py.base.JsObjectWrapper): + def __init__(self, js_obj): + super().__init__(js_obj.__dict__['_obj']) + + def __getstate__(self): + return self.__dict__.copy() + + def __setstate__(self, state): + self.__dict__.update(state) + + +# TODO(yaml) Consider adding optional language version parameter to support +# ECMAScript 5 and 6 +def _expand_javascript_mapping_func( + original_fields, expression=None, callable=None, path=None, name=None): + if expression: + args = ', '.join(original_fields) + js_func = f'function fn({args}) {{return ({expression})}}' + js_callable = _CustomJsObjectWrapper(js2py.eval_js(js_func)) + return lambda __row__: js_callable(*__row__._asdict().values()) + + elif callable: + js_callable = _CustomJsObjectWrapper(js2py.eval_js(callable)) + return lambda __row__: js_callable(__row__._asdict()) + + else: + if not path.endswith('.js'): + raise ValueError(f'File "{path}" is not a valid .js file.') + udf_code = FileSystems.open(path).read().decode() + js = js2py.EvalJs() + js.eval(udf_code) + js_callable = _CustomJsObjectWrapper(getattr(js, name)) + return lambda __row__: js_callable(__row__._asdict()) + + +def _expand_python_mapping_func( + original_fields, expression=None, callable=None, path=None, name=None): + if path and name: + if not path.endswith('.py'): + raise ValueError(f'File "{path}" is not a valid .py file.') + py_file = FileSystems.open(path).read().decode() + + return python_callable.PythonCallableWithSource.load_from_script( + py_file, name) + + elif expression: + # TODO(robertwb): Consider constructing a single callable that takes + # the row and returns the new row, rather than invoking (and unpacking) + # for each field individually. + source = '\n'.join(['def fn(__row__):'] + [ + f' {name} = __row__.{name}' + for name in original_fields if name in expression + ] + [' return (' + expression + ')']) + + else: + source = callable + + return python_callable.PythonCallableWithSource(source) + + +def _as_callable(original_fields, expr, transform_name, language): if expr in original_fields: return expr + + # TODO(yaml): support a type parameter + # TODO(yaml): support an imports parameter + # TODO(yaml): support a requirements parameter (possibly at a higher level) + if isinstance(expr, str): + expr = {'expression': expr} + if not isinstance(expr, dict): + raise ValueError( + f"Ambiguous expression type (perhaps missing quoting?): {expr}") + elif len(expr) != 1 and ('path' not in expr or 'name' not in expr): + raise ValueError(f"Ambiguous expression type: {list(expr.keys())}") + + _check_mapping_arguments(transform_name, **expr) + + if language == "javascript": + return _expand_javascript_mapping_func(original_fields, **expr) + elif language == "python": + return _expand_python_mapping_func(original_fields, **expr) else: - # TODO(yaml): support a type parameter - # TODO(yaml): support an imports parameter - # TODO(yaml): support a requirements parameter (possibly at a higher level) - if isinstance(expr, str): - expr = {'expression': expr} - if not isinstance(expr, dict): - raise ValueError( - f"Ambiguous expression type (perhaps missing quoting?): {expr}") - elif len(expr) != 1: - raise ValueError(f"Ambiguous expression type: {list(expr.keys())}") - if 'expression' in expr: - # TODO(robertwb): Consider constructing a single callable that takes - # the row and returns the new row, rather than invoking (and unpacking) - # for each field individually. - source = '\n'.join(['def fn(__row__):'] + [ - f' {name} = __row__.{name}' - for name in original_fields if name in expr['expression'] - ] + [' return (' + expr['expression'] + ')']) - elif 'callable' in expr: - source = expr['callable'] - else: - raise ValueError(f"Unknown expression type: {list(expr.keys())}") - return python_callable.PythonCallableWithSource(source) + raise ValueError( + f'Unknown language for mapping transform: {language}. ' + 'Supported languages are "javascript" and "python."') # TODO(yaml): This should be available in all environments, in which case @@ -88,14 +171,12 @@ def explode_zip(base, fields): yield beam.Row(**copy) return ( - beam.core._MaybePValueWithErrors( - pcoll, self._exception_handling_args) + beam.core._MaybePValueWithErrors(pcoll, self._exception_handling_args) | beam.FlatMap( - lambda row: ( - explode_cross_product if self._cross_product else explode_zip)( - {name: getattr(row, name) for name in all_fields}, # yapf - to_explode)) - ).as_result() + lambda row: + (explode_cross_product if self._cross_product else explode_zip) + ({name: getattr(row, name) + for name in all_fields}, to_explode))).as_result() def infer_output_type(self, input_type): return row_type.RowTypeConstraint.from_fields([( @@ -116,6 +197,8 @@ def _PythonProjectionTransform( pcoll, *, fields, + transform_name, + language, keep=None, explode=(), cross_product=True, @@ -138,19 +221,16 @@ def _PythonProjectionTransform( if isinstance(keep, str) and keep in original_fields: keep_fn = lambda row: getattr(row, keep) else: - keep_fn = _as_callable(original_fields, keep) + keep_fn = _as_callable(original_fields, keep, transform_name, language) filtered = pcoll | beam.Filter(keep_fn) else: filtered = pcoll - if list(fields.items()) == [(name, name) for name in original_fields]: - projected = filtered - else: - projected = filtered | beam.Select( - **{ - name: _as_callable(original_fields, expr) - for (name, expr) in fields.items() - }) + projected = filtered | beam.Select( + **{ + name: _as_callable(original_fields, expr, transform_name, language) + for (name, expr) in fields.items() + }) if explode: result = projected | _Explode(explode, cross_product=cross_product) @@ -177,8 +257,8 @@ def MapToFields( drop=(), language=None, error_handling=None, + transform_name="MapToFields", **language_keywords): - if isinstance(explode, str): explode = [explode] if cross_product is None: @@ -242,13 +322,15 @@ def MapToFields( return result - elif language == 'python': + elif language == 'python' or language == 'javascript': return pcoll | yaml_create_transform({ 'type': 'PyTransform', 'config': { 'constructor': __name__ + '._PythonProjectionTransform', 'kwargs': { 'fields': fields, + 'transform_name': transform_name, + 'language': language, 'keep': keep, 'explode': explode, 'cross_product': cross_product, @@ -281,6 +363,7 @@ def create_mapping_provider(): keep=keep, fields={}, append=True, + transform_name='Filter', **kwargs)), 'Explode': ( lambda yaml_create_transform, @@ -290,5 +373,6 @@ def create_mapping_provider(): explode=explode, fields={}, append=True, + transform_name='Explode', **kwargs)), }) diff --git a/sdks/python/apache_beam/yaml/yaml_udf_test.py b/sdks/python/apache_beam/yaml/yaml_udf_test.py new file mode 100644 index 000000000000..bb15cd494757 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_udf_test.py @@ -0,0 +1,242 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import os +import shutil +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.io import localfilesystem +from apache_beam.options import pipeline_options +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.yaml.yaml_transform import YamlTransform + + +class YamlUDFMappingTest(unittest.TestCase): + def __init__(self, method_name='runYamlMappingTest'): + super().__init__(method_name) + self.data = [ + beam.Row(label='11a', conductor=11, rank=0), + beam.Row(label='37a', conductor=37, rank=1), + beam.Row(label='389a', conductor=389, rank=2), + ] + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.fs = localfilesystem.LocalFileSystem(pipeline_options) + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def test_map_to_fields_filter_inline_js(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + ''' + type: MapToFields + input: input + config: + language: javascript + fields: + label: + callable: "function label_map(x) {return x.label + 'x'}" + conductor: + callable: "function conductor_map(x) {return x.conductor + 1}" + keep: + callable: "function filter(x) {return x.rank > 0}" + ''') + assert_that( + result, + equal_to([ + beam.Row(label='37ax', conductor=38), + beam.Row(label='389ax', conductor=390), + ])) + + def test_map_to_fields_filter_inline_py(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + ''' + type: MapToFields + input: input + config: + language: python + fields: + label: + callable: "lambda x: x.label + 'x'" + conductor: + callable: "lambda x: x.conductor + 1" + keep: + callable: "lambda x: x.rank > 0" + ''') + assert_that( + result, + equal_to([ + beam.Row(label='37ax', conductor=38), + beam.Row(label='389ax', conductor=390), + ])) + + def test_filter_inline_js(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + ''' + type: Filter + input: input + config: + language: javascript + keep: + callable: "function filter(x) {return x.rank > 0}" + ''') + assert_that( + result, + equal_to([ + beam.Row(label='37a', conductor=37, rank=1), + beam.Row(label='389a', conductor=389, rank=2), + ])) + + def test_filter_inline_py(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + ''' + type: Filter + input: input + config: + language: python + keep: + callable: "lambda x: x.rank > 0" + ''') + assert_that( + result, + equal_to([ + beam.Row(label='37a', conductor=37, rank=1), + beam.Row(label='389a', conductor=389, rank=2), + ])) + + def test_filter_expression_js(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + ''' + type: Filter + input: input + config: + language: javascript + keep: + expression: "label.toUpperCase().indexOf('3') == -1 && conductor" + ''') + assert_that( + result, equal_to([ + beam.Row(label='11a', conductor=11, rank=0), + ])) + + def test_filter_expression_py(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + ''' + type: Filter + input: input + config: + language: python + keep: + expression: "'3' not in label" + ''') + assert_that( + result, equal_to([ + beam.Row(label='11a', conductor=11, rank=0), + ])) + + def test_filter_inline_js_file(self): + data = ''' + function f(x) { + return x.rank > 0 + } + + function g(x) { + return x.rank > 1 + } + '''.replace(' ', '') + + path = os.path.join(self.tmpdir, 'udf.js') + self.fs.create(path).write(data.encode('utf8')) + + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + f''' + type: Filter + input: input + config: + language: javascript + keep: + path: {path} + name: "f" + ''') + assert_that( + result, + equal_to([ + beam.Row(label='37a', conductor=37, rank=1), + beam.Row(label='389a', conductor=389, rank=2), + ])) + + def test_filter_inline_py_file(self): + data = ''' + def f(x): + return x.rank > 0 + + def g(x): + return x.rank > 1 + '''.replace(' ', '') + + path = os.path.join(self.tmpdir, 'udf.py') + self.fs.create(path).write(data.encode('utf8')) + + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create(self.data) + result = elements | YamlTransform( + f''' + type: Filter + input: input + config: + language: python + keep: + path: {path} + name: "f" + ''') + assert_that( + result, + equal_to([ + beam.Row(label='37a', conductor=37, rank=1), + beam.Row(label='389a', conductor=389, rank=2), + ])) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/container/license_scripts/dep_urls_py.yaml b/sdks/python/container/license_scripts/dep_urls_py.yaml index beea506ca91c..36efb36c321c 100644 --- a/sdks/python/container/license_scripts/dep_urls_py.yaml +++ b/sdks/python/container/license_scripts/dep_urls_py.yaml @@ -129,6 +129,8 @@ pip_dependencies: notice: "https://raw.githubusercontent.com/apache/arrow/master/NOTICE.txt" pyhamcrest: license: "https://raw.githubusercontent.com/hamcrest/PyHamcrest/master/LICENSE.txt" + pyjsparser: + license: "https://github.com/PiotrDabkowski/pyjsparser/blob/master/LICENSE" pymongo: license: "https://raw.githubusercontent.com/mongodb/mongo-python-driver/master/LICENSE" python-gflags: diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 58aca4a4aea7..340f3e1b9691 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -80,6 +80,7 @@ hypothesis==6.84.3 idna==3.4 iniconfig==2.0.0 joblib==1.3.2 +Js2Py==0.74 mmh3==4.0.1 mock==5.1.0 nltk==3.8.1 @@ -102,6 +103,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 +pyjsparser==2.7.1 pymongo==4.5.0 PyMySQL==1.1.0 pyparsing==3.1.1 @@ -130,6 +132,7 @@ threadpoolctl==3.2.0 tomli==2.0.1 tqdm==4.66.1 typing_extensions==4.7.1 +tzlocal==5.0.1 uritemplate==4.1.1 urllib3==1.26.16 websocket-client==1.6.3 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 5aaeba15c69e..cb1637c2eb12 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -77,6 +77,7 @@ hypothesis==6.84.3 idna==3.4 iniconfig==2.0.0 joblib==1.3.2 +Js2Py==0.74 mmh3==4.0.1 mock==5.1.0 nltk==3.8.1 @@ -99,6 +100,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 +pyjsparser==2.7.1 pymongo==4.5.0 PyMySQL==1.1.0 pyparsing==3.1.1 @@ -125,6 +127,7 @@ testcontainers==3.7.1 threadpoolctl==3.2.0 tqdm==4.66.1 typing_extensions==4.7.1 +tzlocal==5.0.1 urllib3==1.26.16 websocket-client==1.6.3 wrapt==1.15.0 diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt index 472ee0c0bf8d..3d59060cd3ee 100644 --- a/sdks/python/container/py38/base_image_requirements.txt +++ b/sdks/python/container/py38/base_image_requirements.txt @@ -22,6 +22,7 @@ # Reach out to a committer if you need help. attrs==23.1.0 +backports.zoneinfo==0.2.1 beautifulsoup4==4.12.2 bs4==0.0.1 cachetools==5.3.1 @@ -80,6 +81,7 @@ hypothesis==6.84.3 idna==3.4 iniconfig==2.0.0 joblib==1.3.2 +Js2Py==0.74 mmh3==4.0.1 mock==5.1.0 nltk==3.8.1 @@ -102,6 +104,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 +pyjsparser==2.7.1 pymongo==4.5.0 PyMySQL==1.1.0 pyparsing==3.1.1 @@ -130,6 +133,7 @@ threadpoolctl==3.2.0 tomli==2.0.1 tqdm==4.66.1 typing_extensions==4.7.1 +tzlocal==5.0.1 uritemplate==4.1.1 urllib3==1.26.16 websocket-client==1.6.3 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 257bcf9869e2..6342cfe6edc1 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -80,6 +80,7 @@ hypothesis==6.84.3 idna==3.4 iniconfig==2.0.0 joblib==1.3.2 +Js2Py==0.74 mmh3==4.0.1 mock==5.1.0 nltk==3.8.1 @@ -102,6 +103,7 @@ pyasn1-modules==0.3.0 pycparser==2.21 pydot==1.4.2 PyHamcrest==2.0.4 +pyjsparser==2.7.1 pymongo==4.5.0 PyMySQL==1.1.0 pyparsing==3.1.1 @@ -130,6 +132,7 @@ threadpoolctl==3.2.0 tomli==2.0.1 tqdm==4.66.1 typing_extensions==4.7.1 +tzlocal==5.0.1 uritemplate==4.1.1 urllib3==1.26.16 websocket-client==1.6.3 diff --git a/sdks/python/setup.py b/sdks/python/setup.py index d5ca354fcfbe..cadc4f34c86d 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -247,6 +247,7 @@ def get_portability_package_data(): 'grpcio>=1.33.1,!=1.48.0,<2', 'hdfs>=2.1.0,<3.0.0', 'httplib2>=0.8,<0.23.0', + 'js2py>=0.74,<1', # numpy can have breaking changes in minor versions. # Use a strict upper bound. 'numpy>=1.14.3,<1.25.0', # Update build-requirements.txt as well.