[YAML] Add optional output type parameter to mappings. #29077

Merged on Oct 25, 2023 (9 commits)
11 changes: 9 additions & 2 deletions sdks/python/apache_beam/transforms/core.py
@@ -3364,9 +3364,16 @@ def expand(self, pcoll):
                    for name, expr in self._fields}))).as_result()

  def infer_output_type(self, input_type):
    def extract_return_type(expr):
      expr_hints = get_type_hints(expr)
      if (expr_hints and expr_hints.has_simple_output_type() and
          expr_hints.simple_output_type(None) != typehints.Any):
        return expr_hints.simple_output_type(None)
      else:
        return trivial_inference.infer_return_type(expr, [input_type])

    return row_type.RowTypeConstraint.from_fields([
-        (name, trivial_inference.infer_return_type(expr, [input_type]))
-        for (name, expr) in self._fields
+        (name, extract_return_type(expr)) for (name, expr) in self._fields
    ])


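The `core.py` change means that an explicit output type hint on a per-field callable now takes precedence over trivial inference. As a rough illustration of the signal `infer_output_type` looks for (a sketch using Beam's public typehints API; the `upper` function below is hypothetical):

```
import apache_beam as beam
from apache_beam.typehints.decorators import get_type_hints

# A hypothetical per-field callable with an explicitly declared output type.
@beam.typehints.with_output_types(str)
def upper(row):
  return row.col1.upper()

hints = get_type_hints(upper)
print(hints.has_simple_output_type())  # expected: True
print(hints.simple_output_type(None))  # expected: str
```

When such a hint is present (and is not `Any`), the new `extract_return_type` helper uses it directly instead of falling back to `trivial_inference.infer_return_type`.
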
37 changes: 37 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping.md
@@ -200,3 +200,40 @@ criteria. This can be accomplished with a `Filter` transform, e.g.
    language: sql
    keep: "col2 > 0"
```

## Types

Beam will try to infer the types involved in the mappings, but sometimes this
is not possible. In these cases one can explicitly denote the expected output
type, e.g.

```
- type: MapToFields
  config:
    language: python
    fields:
      new_col:
        expression: "col1.upper()"
        type: string
```

The expected type is given in JSON schema notation, with the addition that
top-level basic types may be given as a literal string rather than requiring
a `{type: 'basic_type_name'}` nesting.

```
- type: MapToFields
  config:
    language: python
    fields:
      new_col:
        expression: "col1.upper()"
        type: string
      another_col:
        expression: "beam.Row(a=col1, b=[col2])"
        type:
          type: 'object'
          properties:
            a: {type: 'string'}
            b: {type: 'array', items: {type: 'number'}}
```
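
To get a feel for how these declared types map onto Beam's Python type hints, the sketch below (illustrative only, not part of the documentation file) feeds a JSON-schema fragment through the same helpers the implementation in `yaml_mapping.py` uses:

```
from apache_beam.typehints import schemas
from apache_beam.yaml import json_utils

# 'number' in JSON schema notation corresponds to a floating-point Beam type.
beam_type = json_utils.json_type_to_beam_type(
    {'type': 'array', 'items': {'type': 'number'}})
print(schemas.typing_from_runner_api(beam_type))
# Expect a sequence-of-floats hint (e.g. Sequence of float64).
```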
67 changes: 62 additions & 5 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -30,10 +30,13 @@

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints import row_type
from apache_beam.typehints import trivial_inference
from apache_beam.typehints import schemas
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils import python_callable
from apache_beam.yaml import json_utils
from apache_beam.yaml import yaml_provider


Expand Down Expand Up @@ -120,6 +123,44 @@ def _expand_python_mapping_func(
  return python_callable.PythonCallableWithSource(source)


def _validator(beam_type: schema_pb2.FieldType) -> Callable[[Any], bool]:
  """Returns a callable checking that a value conforms to the given type."""
  type_info = beam_type.WhichOneof("type_info")
  if type_info == "atomic_type":
    if beam_type.atomic_type == schema_pb2.BOOLEAN:
      return lambda x: isinstance(x, bool)
    elif beam_type.atomic_type == schema_pb2.INT64:
      return lambda x: isinstance(x, int)
    elif beam_type.atomic_type == schema_pb2.DOUBLE:
      return lambda x: isinstance(x, (int, float))
    elif beam_type.atomic_type == schema_pb2.STRING:
      return lambda x: isinstance(x, str)
    else:
      raise ValueError(
          f'Unknown or unsupported atomic type: {beam_type.atomic_type}')
  elif type_info == "array_type":
    element_validator = _validator(beam_type.array_type.element_type)
    return lambda value: all(element_validator(e) for e in value)
  elif type_info == "iterable_type":
    element_validator = _validator(beam_type.iterable_type.element_type)
    return lambda value: all(element_validator(e) for e in value)
  elif type_info == "map_type":
    key_validator = _validator(beam_type.map_type.key_type)
    value_validator = _validator(beam_type.map_type.value_type)
    return lambda value: all(
        key_validator(k) and value_validator(v) for (k, v) in value.items())
  elif type_info == "row_type":
    validators = {
        field.name: _validator(field.type)
        for field in beam_type.row_type.schema.fields
    }
    return lambda row: all(
        validator(getattr(row, name))
        for (name, validator) in validators.items())
  else:
    raise ValueError(f"Unrecognized type_info: {type_info!r}")


def _as_callable(original_fields, expr, transform_name, language):
  if expr in original_fields:
    return expr
@@ -132,20 +173,36 @@ def _as_callable(original_fields, expr, transform_name, language):
  if not isinstance(expr, dict):
    raise ValueError(
        f"Ambiguous expression type (perhaps missing quoting?): {expr}")
  elif len(expr) != 1 and ('path' not in expr or 'name' not in expr):
    raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")

  explicit_type = expr.pop('type', None)
  _check_mapping_arguments(transform_name, **expr)

  if language == "javascript":
-    return _expand_javascript_mapping_func(original_fields, **expr)
+    func = _expand_javascript_mapping_func(original_fields, **expr)
  elif language == "python":
-    return _expand_python_mapping_func(original_fields, **expr)
+    func = _expand_python_mapping_func(original_fields, **expr)
  else:
    raise ValueError(
        f'Unknown language for mapping transform: {language}. '
        'Supported languages are "javascript" and "python."')

  if explicit_type:
    if isinstance(explicit_type, str):
      explicit_type = {'type': explicit_type}
    beam_type = json_utils.json_type_to_beam_type(explicit_type)
    validator = _validator(beam_type)

    @beam.typehints.with_output_types(schemas.typing_from_runner_api(beam_type))
    def checking_func(row):
      result = func(row)
      if not validator(result):
        raise TypeError(f'{result} violates schema {explicit_type}')
      return result

    return checking_func

  else:
    return func


def exception_handling_args(error_handling_spec):
if error_handling_spec:
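To make the behavior of the new `_validator` helper concrete, here is a small usage sketch (hypothetical, not part of the change itself) that builds a field type directly from the `schema_pb2` protos imported above:

```
from apache_beam.portability.api import schema_pb2
from apache_beam.yaml.yaml_mapping import _validator

# An array-of-strings field type.
string_list = schema_pb2.FieldType(
    array_type=schema_pb2.ArrayType(
        element_type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING)))

validate = _validator(string_list)
print(validate(['a', 'b']))  # True: every element is a str
print(validate(['a', 1]))    # False: 1 is not a str
```

This is the check that `checking_func` applies to each produced value when an explicit `type` is configured, raising a `TypeError` on mismatch.
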
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -136,6 +136,27 @@ def test_explode(self):
              beam.Row(a=3, b='y', c=.125, range=2),
          ]))

  def test_validate_explicit_types(self):
    with self.assertRaisesRegex(TypeError, r'.*violates schema.*'):
      with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
          pickle_library='cloudpickle')) as p:
        elements = p | beam.Create([
            beam.Row(a=2, b='abc', c=.25),
            beam.Row(a=3, b='xy', c=.125),
        ])
        result = elements | YamlTransform(
            '''
            type: MapToFields
            input: input
            config:
              language: python
              fields:
                bad:
                  expression: "a + c"
                  type: string  # This is a lie.
            ''')
        self.assertEqual(result.element_type._fields[0][1], str)


YamlMappingDocTest = createTestSuite(
    'YamlMappingDocTest',