Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YAML] Fix simple YAML mappings type hinting #31427

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/yaml/yaml_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def extract_return_type(expr):

for output, agg in self._combine.items():
expr = yaml_mapping._as_callable(
all_fields, agg['value'], 'Combine', self._language)
all_fields, agg['value'], 'Combine', self._language, input_types)
fn = create_combine_fn(agg['fn'])
transform = transform.aggregate_field(expr, fn, output)

Expand Down
23 changes: 19 additions & 4 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils import python_callable
Expand Down Expand Up @@ -276,12 +277,16 @@ def _as_callable_for_pcoll(
if isinstance(fn_spec, str) and fn_spec in input_schema:
return lambda row: getattr(row, fn_spec)
else:
return _as_callable(list(input_schema.keys()), fn_spec, msg, language)
return _as_callable(
list(input_schema.keys()), fn_spec, msg, language, input_schema)


def _as_callable(original_fields, expr, transform_name, language):
def _as_callable(original_fields, expr, transform_name, language, input_schema):

# Extract original type from upstream pcoll when doing simple mappings
original_type = input_schema.get(str(expr), None)
if expr in original_fields:
return expr
language = "python"

# TODO(yaml): support an imports parameter
# TODO(yaml): support a requirements parameter (possibly at a higher level)
Expand Down Expand Up @@ -317,6 +322,15 @@ def checking_func(row):

return checking_func

elif original_type:

@beam.typehints.with_output_types(convert_to_beam_type(original_type))
def checking_func(row):
result = func(row)
return result

return checking_func

else:
return func

Expand Down Expand Up @@ -554,7 +568,8 @@ def _PyJsMapToFields(pcoll, language='generic', **mapping_args):

return pcoll | beam.Select(
**{
name: _as_callable(original_fields, expr, name, language)
name: _as_callable(
original_fields, expr, name, language, input_schema)
for (name, expr) in fields.items()
})

Expand Down
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
import logging
import unittest

import numpy as np

import apache_beam as beam
from apache_beam import schema_pb2
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints import schemas
from apache_beam.yaml.yaml_transform import YamlTransform

DATA = [
Expand Down Expand Up @@ -390,6 +394,32 @@ def test_partition_bad_runtime_type(self):
language: python
''')

def test_append_type_inference(self):
p = beam.Pipeline(
options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle'))
elements = p | beam.Create(DATA)
elements.element_type = schemas.named_tuple_from_schema(
schema_pb2.Schema(
fields=[
schemas.schema_field('label', str),
schemas.schema_field('conductor', int),
schemas.schema_field('rank', int)
]))
result = elements | YamlTransform(
'''
type: MapToFields
config:
language: python
append: true
fields:
new_label: label
''')
self.assertSequenceEqual(
result.element_type._fields,
(('label', str), ('conductor', np.int64), ('rank', np.int64),
('new_label', str)))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
Loading