Skip to content

Commit

Permalink
[yaml] Add more complete UDF support to mapping transforms (#28114)
Browse files Browse the repository at this point in the history
  • Loading branch information
Polber authored Sep 14, 2023
1 parent 1b42ded commit b557cae
Show file tree
Hide file tree
Showing 9 changed files with 401 additions and 58 deletions.
31 changes: 16 additions & 15 deletions sdks/python/apache_beam/utils/python_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,34 +77,35 @@ def load_from_fully_qualified_name(fully_qualified_name):
return o

@staticmethod
def load_from_script(source):
def load_from_script(source, method_name=None):
lines = [
line for line in source.split('\n')
if line.strip() and line.strip()[0] != '#'
]
common_indent = min(len(line) - len(line.lstrip()) for line in lines)
lines = [line[common_indent:] for line in lines]

for ix, line in reversed(list(enumerate(lines))):
if line[0] != ' ':
if line.startswith('def '):
name = line[4:line.index('(')].strip()
elif line.startswith('class '):
name = line[5:line.index('(') if '(' in
line else line.index(':')].strip()
else:
name = '__python_callable__'
lines[ix] = name + ' = ' + line
break
else:
raise ValueError("Unable to identify callable from %r" % source)
if method_name is None:
for ix, line in reversed(list(enumerate(lines))):
if line[0] != ' ':
if line.startswith('def '):
method_name = line[4:line.index('(')].strip()
elif line.startswith('class '):
method_name = line[5:line.index('(') if '(' in
line else line.index(':')].strip()
else:
method_name = '__python_callable__'
lines[ix] = method_name + ' = ' + line
break
else:
raise ValueError("Unable to identify callable from %r" % source)

# pylint: disable=exec-used
# pylint: disable=ungrouped-imports
import apache_beam as beam
exec_globals = {'beam': beam}
exec('\n'.join(lines), exec_globals)
return exec_globals[name]
return exec_globals[method_name]

def default_label(self):
src = self._source.strip()
Expand Down
170 changes: 127 additions & 43 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,127 @@
#

"""This module defines the basic MapToFields operation."""

import itertools

import js2py

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.typehints import row_type
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils import python_callable
from apache_beam.yaml import yaml_provider


def _as_callable(original_fields, expr):
def _check_mapping_arguments(
transform_name, expression=None, callable=None, name=None, path=None):
# Argument checking
if not expression and not callable and not path and not name:
raise ValueError(
f'{transform_name} must specify either "expression", "callable", '
f'or both "path" and "name"')
if expression and callable:
raise ValueError(
f'{transform_name} cannot specify both "expression" and "callable"')
if (expression or callable) and (path or name):
raise ValueError(
f'{transform_name} cannot specify "expression" or "callable" with '
f'"path" or "name"')
if path and not name:
raise ValueError(f'{transform_name} cannot specify "path" without "name"')
if name and not path:
raise ValueError(f'{transform_name} cannot specify "name" without "path"')


# js2py's JsObjectWrapper object has a self-referencing __dict__ property
# that cannot be pickled without implementing the __getstate__ and
# __setstate__ methods.
class _CustomJsObjectWrapper(js2py.base.JsObjectWrapper):
def __init__(self, js_obj):
super().__init__(js_obj.__dict__['_obj'])

def __getstate__(self):
return self.__dict__.copy()

def __setstate__(self, state):
self.__dict__.update(state)


# TODO(yaml) Consider adding optional language version parameter to support
# ECMAScript 5 and 6
def _expand_javascript_mapping_func(
original_fields, expression=None, callable=None, path=None, name=None):
if expression:
args = ', '.join(original_fields)
js_func = f'function fn({args}) {{return ({expression})}}'
js_callable = _CustomJsObjectWrapper(js2py.eval_js(js_func))
return lambda __row__: js_callable(*__row__._asdict().values())

elif callable:
js_callable = _CustomJsObjectWrapper(js2py.eval_js(callable))
return lambda __row__: js_callable(__row__._asdict())

else:
if not path.endswith('.js'):
raise ValueError(f'File "{path}" is not a valid .js file.')
udf_code = FileSystems.open(path).read().decode()
js = js2py.EvalJs()
js.eval(udf_code)
js_callable = _CustomJsObjectWrapper(getattr(js, name))
return lambda __row__: js_callable(__row__._asdict())


def _expand_python_mapping_func(
original_fields, expression=None, callable=None, path=None, name=None):
if path and name:
if not path.endswith('.py'):
raise ValueError(f'File "{path}" is not a valid .py file.')
py_file = FileSystems.open(path).read().decode()

return python_callable.PythonCallableWithSource.load_from_script(
py_file, name)

elif expression:
# TODO(robertwb): Consider constructing a single callable that takes
# the row and returns the new row, rather than invoking (and unpacking)
# for each field individually.
source = '\n'.join(['def fn(__row__):'] + [
f' {name} = __row__.{name}'
for name in original_fields if name in expression
] + [' return (' + expression + ')'])

else:
source = callable

return python_callable.PythonCallableWithSource(source)


def _as_callable(original_fields, expr, transform_name, language):
if expr in original_fields:
return expr

# TODO(yaml): support a type parameter
# TODO(yaml): support an imports parameter
# TODO(yaml): support a requirements parameter (possibly at a higher level)
if isinstance(expr, str):
expr = {'expression': expr}
if not isinstance(expr, dict):
raise ValueError(
f"Ambiguous expression type (perhaps missing quoting?): {expr}")
elif len(expr) != 1 and ('path' not in expr or 'name' not in expr):
raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")

_check_mapping_arguments(transform_name, **expr)

if language == "javascript":
return _expand_javascript_mapping_func(original_fields, **expr)
elif language == "python":
return _expand_python_mapping_func(original_fields, **expr)
else:
# TODO(yaml): support a type parameter
# TODO(yaml): support an imports parameter
# TODO(yaml): support a requirements parameter (possibly at a higher level)
if isinstance(expr, str):
expr = {'expression': expr}
if not isinstance(expr, dict):
raise ValueError(
f"Ambiguous expression type (perhaps missing quoting?): {expr}")
elif len(expr) != 1:
raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")
if 'expression' in expr:
# TODO(robertwb): Consider constructing a single callable that takes
# the row and returns the new row, rather than invoking (and unpacking)
# for each field individually.
source = '\n'.join(['def fn(__row__):'] + [
f' {name} = __row__.{name}'
for name in original_fields if name in expr['expression']
] + [' return (' + expr['expression'] + ')'])
elif 'callable' in expr:
source = expr['callable']
else:
raise ValueError(f"Unknown expression type: {list(expr.keys())}")
return python_callable.PythonCallableWithSource(source)
raise ValueError(
f'Unknown language for mapping transform: {language}. '
'Supported languages are "javascript" and "python."')


# TODO(yaml): This should be available in all environments, in which case
Expand Down Expand Up @@ -88,14 +171,12 @@ def explode_zip(base, fields):
yield beam.Row(**copy)

return (
beam.core._MaybePValueWithErrors(
pcoll, self._exception_handling_args)
beam.core._MaybePValueWithErrors(pcoll, self._exception_handling_args)
| beam.FlatMap(
lambda row: (
explode_cross_product if self._cross_product else explode_zip)(
{name: getattr(row, name) for name in all_fields}, # yapf
to_explode))
).as_result()
lambda row:
(explode_cross_product if self._cross_product else explode_zip)
({name: getattr(row, name)
for name in all_fields}, to_explode))).as_result()

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint.from_fields([(
Expand All @@ -116,6 +197,8 @@ def _PythonProjectionTransform(
pcoll,
*,
fields,
transform_name,
language,
keep=None,
explode=(),
cross_product=True,
Expand All @@ -138,19 +221,16 @@ def _PythonProjectionTransform(
if isinstance(keep, str) and keep in original_fields:
keep_fn = lambda row: getattr(row, keep)
else:
keep_fn = _as_callable(original_fields, keep)
keep_fn = _as_callable(original_fields, keep, transform_name, language)
filtered = pcoll | beam.Filter(keep_fn)
else:
filtered = pcoll

if list(fields.items()) == [(name, name) for name in original_fields]:
projected = filtered
else:
projected = filtered | beam.Select(
**{
name: _as_callable(original_fields, expr)
for (name, expr) in fields.items()
})
projected = filtered | beam.Select(
**{
name: _as_callable(original_fields, expr, transform_name, language)
for (name, expr) in fields.items()
})

if explode:
result = projected | _Explode(explode, cross_product=cross_product)
Expand All @@ -177,8 +257,8 @@ def MapToFields(
drop=(),
language=None,
error_handling=None,
transform_name="MapToFields",
**language_keywords):

if isinstance(explode, str):
explode = [explode]
if cross_product is None:
Expand Down Expand Up @@ -242,13 +322,15 @@ def MapToFields(

return result

elif language == 'python':
elif language == 'python' or language == 'javascript':
return pcoll | yaml_create_transform({
'type': 'PyTransform',
'config': {
'constructor': __name__ + '._PythonProjectionTransform',
'kwargs': {
'fields': fields,
'transform_name': transform_name,
'language': language,
'keep': keep,
'explode': explode,
'cross_product': cross_product,
Expand Down Expand Up @@ -281,6 +363,7 @@ def create_mapping_provider():
keep=keep,
fields={},
append=True,
transform_name='Filter',
**kwargs)),
'Explode': (
lambda yaml_create_transform,
Expand All @@ -290,5 +373,6 @@ def create_mapping_provider():
explode=explode,
fields={},
append=True,
transform_name='Explode',
**kwargs)),
})
Loading

0 comments on commit b557cae

Please sign in to comment.