diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
index d5aa4038ef7a..1b74a765e54b 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -18,9 +18,13 @@
 import logging
 import unittest
 
+import numpy as np
+
 import apache_beam as beam
+from apache_beam import schema_pb2
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.typehints import schemas
 from apache_beam.yaml.yaml_transform import YamlTransform
 
 DATA = [
@@ -390,6 +394,43 @@ def test_partition_bad_runtime_type(self):
               language: python
             ''')
 
+  def test_append_type_inference(self):
+    """Checks the output element type of MapToFields with append: true.
+
+    The input PCollection is given an explicit three-field schema; the
+    result's element type is expected to list those fields followed by the
+    appended new_label field.
+    """
+    # The pipeline is only constructed, never run: the assertion below
+    # inspects the element type attached to the resulting PCollection.
+    p = beam.Pipeline(
+        options=beam.options.pipeline_options.PipelineOptions(
+            pickle_library='cloudpickle'))
+    elements = p | beam.Create(DATA)
+    # Attach an explicit schema-derived element type to the input.
+    elements.element_type = schemas.named_tuple_from_schema(
+        schema_pb2.Schema(
+            fields=[
+                schemas.schema_field('label', str),
+                schemas.schema_field('conductor', int),
+                schemas.schema_field('rank', int)
+            ]))
+    result = elements | YamlTransform(
+        '''
+        type: MapToFields
+        config:
+          language: python
+          append: true
+          fields:
+            new_label: label
+        ''')
+    # Fields declared as Python int are expected back as np.int64
+    # (presumably normalized by the schema round trip -- TODO confirm).
+    self.assertSequenceEqual(
+        result.element_type._fields,
+        (('label', str), ('conductor', np.int64), ('rank', np.int64),
+         ('new_label', str)))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)