Skip to content

Commit

Permalink
[YAML] Add optional output type parameter to mappings. (#29077)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Oct 25, 2023
1 parent 208c028 commit ffb4332
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 8 deletions.
11 changes: 9 additions & 2 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3372,9 +3372,16 @@ def expand(self, pcoll):
for name, expr in self._fields}))).as_result()

def infer_output_type(self, input_type):
  """Infers the schema'd row type produced by this transform.

  For each named field, an explicitly declared output type hint on the
  expression callable takes precedence; otherwise the return type is
  inferred from the callable's bytecode given the input type.
  """
  def extract_return_type(expr):
    # Prefer an explicit, non-Any simple output type hint if one was
    # attached to the expression (e.g. via with_output_types).
    expr_hints = get_type_hints(expr)
    if (expr_hints and expr_hints.has_simple_output_type() and
        expr_hints.simple_output_type(None) != typehints.Any):
      return expr_hints.simple_output_type(None)
    else:
      # Fall back to bytecode-level inference from the input type.
      return trivial_inference.infer_return_type(expr, [input_type])

  return row_type.RowTypeConstraint.from_fields([
      (name, extract_return_type(expr)) for (name, expr) in self._fields
  ])


Expand Down
44 changes: 44 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,47 @@ criteria. This can be accomplished with a `Filter` transform, e.g.
language: sql
keep: "col2 > 0"
```

## Types

Beam will try to infer the types involved in the mappings, but sometimes this
is not possible. In these cases one can explicitly denote the expected output
type, e.g.

```
- type: MapToFields
config:
language: python
fields:
new_col:
expression: "col1.upper()"
output_type: string
```

The expected type is given in JSON schema notation, with the addition that
top-level basic types may be given as literal strings rather than requiring
a `{type: 'basic_type_name'}` nesting.

```
- type: MapToFields
config:
language: python
fields:
new_col:
expression: "col1.upper()"
output_type: string
another_col:
expression: "beam.Row(a=col1, b=[col2])"
output_type:
type: 'object'
properties:
a:
type: 'string'
b:
type: 'array'
items:
type: 'number'
```

This can be especially useful to resolve errors involving the inability to
handle the `beam:logical:pythonsdk_any:v1` type.
68 changes: 62 additions & 6 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints import row_type
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils import python_callable
from apache_beam.yaml import json_utils
from apache_beam.yaml import options
from apache_beam.yaml import yaml_provider

Expand Down Expand Up @@ -121,32 +124,85 @@ def _expand_python_mapping_func(
return python_callable.PythonCallableWithSource(source)


def _validator(beam_type: schema_pb2.FieldType) -> Callable[[Any], bool]:
  """Returns a predicate checking that a value conforms to the given type.

  The returned callable takes a single value and returns True iff that value
  is an instance of (recursively, for container and row types) the Python
  type corresponding to the given Beam schema field type.

  Args:
    beam_type: The Beam schema FieldType to validate against.

  Raises:
    ValueError: If the type (or a nested type) is not supported.
  """
  type_info = beam_type.WhichOneof("type_info")
  if type_info == "atomic_type":
    if beam_type.atomic_type == schema_pb2.BOOLEAN:
      return lambda x: isinstance(x, bool)
    elif beam_type.atomic_type == schema_pb2.INT64:
      # NOTE(review): isinstance(x, int) also accepts bool (a subclass of
      # int), so True/False pass the INT64 check — confirm this is intended.
      return lambda x: isinstance(x, int)
    elif beam_type.atomic_type == schema_pb2.DOUBLE:
      # Ints are acceptable where doubles are expected.
      return lambda x: isinstance(x, (int, float))
    elif beam_type.atomic_type == schema_pb2.STRING:
      return lambda x: isinstance(x, str)
    else:
      raise ValueError(
          f'Unknown or unsupported atomic type: {beam_type.atomic_type}')
  elif type_info == "array_type":
    # Every element must validate against the element type.
    element_validator = _validator(beam_type.array_type.element_type)
    return lambda value: all(element_validator(e) for e in value)
  elif type_info == "iterable_type":
    element_validator = _validator(beam_type.iterable_type.element_type)
    return lambda value: all(element_validator(e) for e in value)
  elif type_info == "map_type":
    # Both keys and values must validate.
    key_validator = _validator(beam_type.map_type.key_type)
    value_validator = _validator(beam_type.map_type.value_type)
    return lambda value: all(
        key_validator(k) and value_validator(v) for (k, v) in value.items())
  elif type_info == "row_type":
    # Validate each declared field, looked up by attribute on the row.
    validators = {
        field.name: _validator(field.type)
        for field in beam_type.row_type.schema.fields
    }
    return lambda row: all(
        validator(getattr(row, name))
        for (name, validator) in validators.items())
  else:
    raise ValueError(f"Unrecognized type_info: {type_info!r}")


def _as_callable(original_fields, expr, transform_name, language):
  """Converts a mapping field spec into a callable over input rows.

  Args:
    original_fields: Names of the input row's fields; a bare field name is
        passed through unchanged.
    expr: Either a string expression or a dict with one of 'expression',
        'callable', or 'path'/'name', plus an optional 'output_type' given
        in JSON schema notation.
    transform_name: Name of the enclosing transform, used in error messages.
    language: Either "javascript" or "python".

  Returns:
    A callable (or field name) suitable for use in the mapping transform.
    If an explicit output type was given, the callable validates each
    result against it and carries the corresponding output type hint.

  Raises:
    ValueError: On malformed specs or an unknown language.
    TypeError: (from the returned callable) when a result violates the
        declared output type.
  """
  if expr in original_fields:
    return expr

  # TODO(yaml): support an imports parameter
  # TODO(yaml): support a requirements parameter (possibly at a higher level)
  if isinstance(expr, str):
    expr = {'expression': expr}
  if not isinstance(expr, dict):
    raise ValueError(
        f"Ambiguous expression type (perhaps missing quoting?): {expr}")
  # Pop the optional output type before the single-key ambiguity check so
  # that {'expression': ..., 'output_type': ...} specs are accepted.
  explicit_type = expr.pop('output_type', None)
  if len(expr) != 1 and ('path' not in expr or 'name' not in expr):
    raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")
  _check_mapping_arguments(transform_name, **expr)

  if language == "javascript":
    func = _expand_javascript_mapping_func(original_fields, **expr)
  elif language == "python":
    func = _expand_python_mapping_func(original_fields, **expr)
  else:
    raise ValueError(
        f'Unknown language for mapping transform: {language}. '
        'Supported languages are "javascript" and "python."')

  if explicit_type:
    # A bare string like 'string' is shorthand for {'type': 'string'}.
    if isinstance(explicit_type, str):
      explicit_type = {'type': explicit_type}
    beam_type = json_utils.json_type_to_beam_type(explicit_type)
    validator = _validator(beam_type)

    # Wrap the mapping function so every result is checked against the
    # declared type, and advertise that type to downstream inference.
    @beam.typehints.with_output_types(schemas.typing_from_runner_api(beam_type))
    def checking_func(row):
      result = func(row)
      if not validator(result):
        raise TypeError(f'{result} violates schema {explicit_type}')
      return result

    return checking_func

  else:
    return func


def exception_handling_args(error_handling_spec):
if error_handling_spec:
Expand Down
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ def test_explode(self):
beam.Row(a=3, b='y', c=.125, range=2),
]))

def test_validate_explicit_types(self):
  """A declared output_type that the results don't satisfy must raise."""
  # The expression yields a float, but output_type claims string, so the
  # runtime validator is expected to fail the pipeline with a TypeError.
  with self.assertRaisesRegex(TypeError, r'.*violates schema.*'):
    pipeline_opts = beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')
    with beam.Pipeline(options=pipeline_opts) as p:
      input_rows = p | beam.Create([
          beam.Row(a=2, b='abc', c=.25),
          beam.Row(a=3, b='xy', c=.125),
      ])
      mapped = input_rows | YamlTransform(
          '''
          type: MapToFields
          input: input
          config:
            language: python
            fields:
              bad:
                expression: "a + c"
                output_type: string  # This is a lie.
          ''')
  # The declared (if dishonest) type should still drive the element type.
  self.assertEqual(mapped.element_type._fields[0][1], str)


YamlMappingDocTest = createTestSuite(
'YamlMappingDocTest',
Expand Down

0 comments on commit ffb4332

Please sign in to comment.