diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index d86f59e3a411..9225acf346e4 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -45,7 +45,13 @@
 # pytype: skip-file
 
 import os
 from functools import partial
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
+import fastavro
 
 from fastavro.read import block_reader
 from fastavro.write import Writer
@@ -54,8 +60,11 @@
 from apache_beam.io import filebasedsource
 from apache_beam.io import iobase
 from apache_beam.io.filesystem import CompressionTypes
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.io.iobase import Read
+from apache_beam.portability.api import schema_pb2
 from apache_beam.transforms import PTransform
+from apache_beam.typehints import schemas
 
 __all__ = [
     'ReadFromAvro',
@@ -73,7 +82,8 @@ def __init__(
       file_pattern=None,
       min_bundle_size=0,
       validate=True,
-      use_fastavro=True):
+      use_fastavro=True,
+      as_rows=False):
     """Initializes :class:`ReadFromAvro`.
 
     Uses source :class:`~apache_beam.io._AvroSource` to read a set of Avro
@@ -140,13 +150,26 @@
         creation time.
       use_fastavro (bool): This flag is left for API backwards compatibility
         and no longer has an effect. Do not use.
+      as_rows (bool): Whether to return a schema'd PCollection of Beam rows.
     """
     super().__init__()
-    self._source = _create_avro_source(
+    self._source = _FastAvroSource(
         file_pattern, min_bundle_size, validate=validate)
+    if as_rows:
+      path = FileSystems.match([file_pattern], [1])[0].metadata_list[0].path
+      with FileSystems.open(path) as fin:
+        avro_schema = fastavro.reader(fin).writer_schema
+      beam_schema = avro_schema_to_beam_schema(avro_schema)
+      self._post_process = avro_dict_to_beam_row(avro_schema, beam_schema)
+    else:
+      self._post_process = None
 
   def expand(self, pvalue):
-    return pvalue.pipeline | Read(self._source)
+    records = pvalue.pipeline | Read(self._source)
+    if self._post_process:
+      return records | beam.Map(self._post_process)
+    else:
+      return records
 
   def display_data(self):
     return {'source_dd': self._source}
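
A rough usage sketch of the new flag (not part of the patch; the path and field name below are hypothetical): with as_rows=True, ReadFromAvro yields schema'd Beam rows instead of dictionaries, so downstream steps can use attribute access.

import apache_beam as beam
from apache_beam.io.avroio import ReadFromAvro

with beam.Pipeline() as p:
    _ = (
        p
        | ReadFromAvro('/tmp/users*.avro', as_rows=True)  # hypothetical path
        | beam.Map(lambda row: row.name)  # fields become attributes on Beam rows
        | beam.Map(print))
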
""" - self._sink = _create_avro_sink( + self._schema = schema + self._sink_provider = lambda avro_schema: _create_avro_sink( file_path_prefix, - schema, + avro_schema, codec, file_name_suffix, num_shards, @@ -392,7 +409,21 @@ def __init__( mime_type) def expand(self, pcoll): - return pcoll | beam.io.iobase.Write(self._sink) + if self._schema: + avro_schema = self._schema + records = pcoll + else: + try: + beam_schema = schemas.schema_from_element_type(pcoll.element_type) + except TypeError as exn: + raise ValueError( + "An explicit schema is required to write non-schema'd PCollections." + ) from exn + avro_schema = beam_schema_to_avro_schema(beam_schema) + records = pcoll | beam.Map( + beam_row_to_avro_dict(avro_schema, beam_schema)) + self._sink = self._sink_provider(avro_schema) + return records | beam.io.iobase.Write(self._sink) def display_data(self): return {'sink_dd': self._sink} @@ -406,7 +437,7 @@ def _create_avro_sink( num_shards, shard_name_template, mime_type): - if "class \'avro.schema" in str(type(schema)): + if "class 'avro.schema" in str(type(schema)): raise ValueError( 'You are using Avro IO with fastavro (default with Beam on ' 'Python 3), but supplying a schema parsed by avro-python3. ' @@ -483,3 +514,205 @@ def write_record(self, writer, value): def close(self, writer): writer.flush() self.file_handle.close() + + +AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES = { + 'boolean': schema_pb2.BOOLEAN, + 'int': schema_pb2.INT32, + 'long': schema_pb2.INT64, + 'float': schema_pb2.FLOAT, + 'double': schema_pb2.DOUBLE, + 'bytes': schema_pb2.BYTES, + 'string': schema_pb2.STRING, +} + +BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES = { + v: k + for k, v in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES.items() +} + +_AvroSchemaType = Union[str, List, Dict] + + +def avro_type_to_beam_type(avro_type: _AvroSchemaType) -> schema_pb2.FieldType: + if isinstance(avro_type, str): + return avro_type_to_beam_type({'type': avro_type}) + elif isinstance(avro_type, list): + # Union type + return schemas.typing_to_runner_api(Any) + type_name = avro_type['type'] + if type_name in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES: + return schema_pb2.FieldType( + atomic_type=AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES[type_name]) + elif type_name in ('fixed', 'enum'): + return schema_pb2.FieldType(atomic_type=schema_pb2.STRING) + elif type_name == 'array': + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType( + element_type=avro_type_to_beam_type(avro_type['items']))) + elif type_name == 'map': + return schema_pb2.FieldType( + map_type=schema_pb2.MapType( + key_type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), + value_type=avro_type_to_beam_type(avro_type['values']))) + elif type_name == 'record': + return schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[ + schemas.schema_field( + f['name'], avro_type_to_beam_type(f['type'])) + for f in avro_type['fields'] + ]))) + else: + raise ValueError(f'Unable to convert {avro_type} to a Beam schema.') + + +def avro_schema_to_beam_schema( + avro_schema: _AvroSchemaType) -> schema_pb2.Schema: + beam_type = avro_type_to_beam_type(avro_schema) + if isinstance(avro_schema, dict) and avro_schema['type'] == 'record': + return beam_type.row_type.schema + else: + return schema_pb2.Schema(fields=[schemas.schema_field('record', beam_type)]) + + +def avro_dict_to_beam_row( + avro_schema: _AvroSchemaType, + beam_schema: schema_pb2.Schema) -> Callable[[Any], Any]: + if isinstance(avro_schema, str): + return avro_dict_to_beam_row({'type': avro_schema}) + if 
@@ -483,3 +514,205 @@ def write_record(self, writer, value):
   def close(self, writer):
     writer.flush()
     self.file_handle.close()
+
+
+AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES = {
+    'boolean': schema_pb2.BOOLEAN,
+    'int': schema_pb2.INT32,
+    'long': schema_pb2.INT64,
+    'float': schema_pb2.FLOAT,
+    'double': schema_pb2.DOUBLE,
+    'bytes': schema_pb2.BYTES,
+    'string': schema_pb2.STRING,
+}
+
+BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES = {
+    v: k
+    for k, v in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES.items()
+}
+
+_AvroSchemaType = Union[str, List, Dict]
+
+
+def avro_type_to_beam_type(avro_type: _AvroSchemaType) -> schema_pb2.FieldType:
+  if isinstance(avro_type, str):
+    return avro_type_to_beam_type({'type': avro_type})
+  elif isinstance(avro_type, list):
+    # Union type
+    return schemas.typing_to_runner_api(Any)
+  type_name = avro_type['type']
+  if type_name in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES:
+    return schema_pb2.FieldType(
+        atomic_type=AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES[type_name])
+  elif type_name in ('fixed', 'enum'):
+    return schema_pb2.FieldType(atomic_type=schema_pb2.STRING)
+  elif type_name == 'array':
+    return schema_pb2.FieldType(
+        array_type=schema_pb2.ArrayType(
+            element_type=avro_type_to_beam_type(avro_type['items'])))
+  elif type_name == 'map':
+    return schema_pb2.FieldType(
+        map_type=schema_pb2.MapType(
+            key_type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
+            value_type=avro_type_to_beam_type(avro_type['values'])))
+  elif type_name == 'record':
+    return schema_pb2.FieldType(
+        row_type=schema_pb2.RowType(
+            schema=schema_pb2.Schema(
+                fields=[
+                    schemas.schema_field(
+                        f['name'], avro_type_to_beam_type(f['type']))
+                    for f in avro_type['fields']
+                ])))
+  else:
+    raise ValueError(f'Unable to convert {avro_type} to a Beam schema.')
+
+
+def avro_schema_to_beam_schema(
+    avro_schema: _AvroSchemaType) -> schema_pb2.Schema:
+  beam_type = avro_type_to_beam_type(avro_schema)
+  if isinstance(avro_schema, dict) and avro_schema['type'] == 'record':
+    return beam_type.row_type.schema
+  else:
+    return schema_pb2.Schema(fields=[schemas.schema_field('record', beam_type)])
+
+
+def avro_dict_to_beam_row(
+    avro_schema: _AvroSchemaType,
+    beam_schema: schema_pb2.Schema) -> Callable[[Any], Any]:
+  if isinstance(avro_schema, str):
+    return avro_dict_to_beam_row({'type': avro_schema}, beam_schema)
+  if avro_schema['type'] == 'record':
+    to_row = avro_value_to_beam_value(
+        schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
+  else:
+
+    def to_row(record):
+      return beam.Row(record=record)
+
+  return beam.typehints.with_output_types(
+      schemas.named_tuple_from_schema(beam_schema))(
+          to_row)
+
+
+def avro_value_to_beam_value(
+    beam_type: schema_pb2.FieldType) -> Callable[[Any], Any]:
+  type_info = beam_type.WhichOneof("type_info")
+  if type_info == "atomic_type":
+    return lambda value: value
+  elif type_info == "array_type":
+    element_converter = avro_value_to_beam_value(
+        beam_type.array_type.element_type)
+    return lambda value: [element_converter(e) for e in value]
+  elif type_info == "iterable_type":
+    element_converter = avro_value_to_beam_value(
+        beam_type.iterable_type.element_type)
+    return lambda value: [element_converter(e) for e in value]
+  elif type_info == "map_type":
+    if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
+      raise TypeError(
+          f'Only strings allowed as map keys when converting from AVRO, '
+          f'found {beam_type}')
+    value_converter = avro_value_to_beam_value(beam_type.map_type.value_type)
+    return lambda value: {k: value_converter(v) for (k, v) in value.items()}
+  elif type_info == "row_type":
+    converters = {
+        field.name: avro_value_to_beam_value(field.type)
+        for field in beam_type.row_type.schema.fields
+    }
+    return lambda value: beam.Row(
+        **
+        {name: convert(value[name])
+         for (name, convert) in converters.items()})
+  elif type_info == "logical_type":
+    return lambda value: value
+  else:
+    raise ValueError(f"Unrecognized type_info: {type_info!r}")
+
+
+def beam_schema_to_avro_schema(
+    beam_schema: schema_pb2.Schema) -> _AvroSchemaType:
+  return beam_type_to_avro_type(
+      schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
+
+
+def beam_type_to_avro_type(beam_type: schema_pb2.FieldType) -> _AvroSchemaType:
+  type_info = beam_type.WhichOneof("type_info")
+  if type_info == "atomic_type":
+    return {'type': BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES[beam_type.atomic_type]}
+  elif type_info == "array_type":
+    return {
+        'type': 'array',
+        'items': beam_type_to_avro_type(beam_type.array_type.element_type)
+    }
+  elif type_info == "iterable_type":
+    return {
+        'type': 'array',
+        'items': beam_type_to_avro_type(beam_type.iterable_type.element_type)
+    }
+  elif type_info == "map_type":
+    if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
+      raise TypeError(
+          f'Only strings allowed as map keys when converting to AVRO, '
+          f'found {beam_type}')
+    return {
+        'type': 'map',
+        'values': beam_type_to_avro_type(beam_type.map_type.value_type)
+    }
+  elif type_info == "row_type":
+    return {
+        'type': 'record',
+        'name': beam_type.row_type.schema.id,
+        'fields': [{
+            'name': field.name, 'type': beam_type_to_avro_type(field.type)
+        } for field in beam_type.row_type.schema.fields],
+    }
+  else:
+    raise ValueError(f"Unconvertible type: {beam_type}")
+
+
+def beam_row_to_avro_dict(
+    avro_schema: _AvroSchemaType, beam_schema: schema_pb2.Schema):
+  if isinstance(avro_schema, str):
+    return beam_row_to_avro_dict({'type': avro_schema}, beam_schema)
+  if avro_schema['type'] == 'record':
+    return beam_value_to_avro_value(
+        schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
+  else:
+    convert = beam_value_to_avro_value(beam_schema)
+    return lambda row: convert(row[0])
+
+
+def beam_value_to_avro_value(
+    beam_type: schema_pb2.FieldType) -> Callable[[Any], Any]:
+  type_info = beam_type.WhichOneof("type_info")
+  if type_info == "atomic_type":
+    return lambda value: value
+  elif type_info == "array_type":
+    element_converter = beam_value_to_avro_value(
+        beam_type.array_type.element_type)
+    return lambda value: [element_converter(e) for e in value]
+  elif type_info == "iterable_type":
+    element_converter = beam_value_to_avro_value(
+        beam_type.iterable_type.element_type)
+    return lambda value: [element_converter(e) for e in value]
+  elif type_info == "map_type":
+    if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
+      raise TypeError(
+          f'Only strings allowed as map keys when converting to AVRO, '
+          f'found {beam_type}')
+    value_converter = beam_value_to_avro_value(beam_type.map_type.value_type)
+    return lambda value: {k: value_converter(v) for (k, v) in value.items()}
+  elif type_info == "row_type":
+    converters = {
+        field.name: beam_value_to_avro_value(field.type)
+        for field in beam_type.row_type.schema.fields
+    }
+    return lambda value: {
+        name: convert(getattr(value, name))
+        for (name, convert) in converters.items()
+    }
+  elif type_info == "logical_type":
+    return lambda value: value
+  else:
+    raise ValueError(f"Unrecognized type_info: {type_info!r}")
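
An illustrative sketch of the conversion helpers added above (assuming this patch is applied; the record schema is made up): an Avro record schema maps onto a Beam row schema, and avro_dict_to_beam_row builds a callable from Avro dicts to Beam rows.

from apache_beam.io import avroio

avro_schema = {
    'type': 'record',
    'name': 'User',
    'fields': [
        {'name': 'name', 'type': 'string'},
        {'name': 'scores', 'type': {'type': 'array', 'items': 'long'}},
    ],
}
beam_schema = avroio.avro_schema_to_beam_schema(avro_schema)
to_row = avroio.avro_dict_to_beam_row(avro_schema, beam_schema)
row = to_row({'name': 'ada', 'scores': [1, 2, 3]})
# row supports attribute access: row.name == 'ada', row.scores == [1, 2, 3]
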
type_info == "atomic_type": + return lambda value: value + elif type_info == "array_type": + element_converter = avro_value_to_beam_value( + beam_type.array_type.element_type) + return lambda value: [element_converter(e) for e in value] + elif type_info == "iterable_type": + element_converter = avro_value_to_beam_value( + beam_type.iterable_type.element_type) + return lambda value: [element_converter(e) for e in value] + elif type_info == "map_type": + if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING: + raise TypeError( + f'Only strings allowd as map keys when converting from AVRO, ' + f'found {beam_type}') + value_converter = avro_value_to_beam_value(beam_type.map_type.value_type) + return lambda value: {k: value_converter(v) for (k, v) in value.items()} + elif type_info == "row_type": + converters = { + field.name: avro_value_to_beam_value(field.type) + for field in beam_type.row_type.schema.fields + } + return lambda value: { + name: convert(getattr(value, name)) + for (name, convert) in converters.items() + } + elif type_info == "logical_type": + return lambda value: value + else: + raise ValueError(f"Unrecognized type_info: {type_info!r}") diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index ba35cf5846c0..c54ac40711b1 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -35,8 +35,8 @@ from apache_beam.io import filebasedsource from apache_beam.io import iobase from apache_beam.io import source_test_utils +from apache_beam.io.avroio import _FastAvroSource # For testing from apache_beam.io.avroio import _create_avro_sink # For testing -from apache_beam.io.avroio import _create_avro_source # For testing from apache_beam.io.filesystems import FileSystems from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that @@ -125,7 +125,7 @@ def _write_pattern(self, num_files, return_filenames=False): def _run_avro_test( self, pattern, desired_bundle_size, perform_splitting, expected_result): - source = _create_avro_source(pattern) + source = _FastAvroSource(pattern) if perform_splitting: assert desired_bundle_size @@ -146,6 +146,20 @@ def _run_avro_test( read_records = source_test_utils.read_from_source(source, None, None) self.assertCountEqual(expected_result, read_records) + def test_schema_read_write(self): + with tempfile.TemporaryDirectory() as tmp_dirname: + path = os.path.join(tmp_dirname, 'tmp_filename') + rows = [beam.Row(a=1, b=['x', 'y']), beam.Row(a=2, b=['t', 'u'])] + stable_repr = lambda row: json.dumps(row._asdict()) + with TestPipeline() as p: + _ = p | Create(rows) | avroio.WriteToAvro(path) | beam.Map(print) + with TestPipeline() as p: + readback = ( + p + | avroio.ReadFromAvro(path + '*', as_rows=True) + | beam.Map(stable_repr)) + assert_that(readback, equal_to([stable_repr(r) for r in rows])) + def test_read_without_splitting(self): file_name = self._write_data() expected_result = self.RECORDS @@ -159,7 +173,7 @@ def test_read_with_splitting(self): def test_source_display_data(self): file_name = 'some_avro_source' source = \ - _create_avro_source( + _FastAvroSource( file_name, validate=False, ) @@ -207,6 +221,7 @@ def test_sink_display_data(self): def test_write_display_data(self): file_name = 'some_avro_sink' write = avroio.WriteToAvro(file_name, self.SCHEMA) + write.expand(beam.PCollection(beam.Pipeline())) dd = DisplayData.create_from(write) expected_items = [ DisplayDataItemMatcher('schema', 
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index 229a8af20bb6..ea836430e8e2 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -93,6 +93,7 @@
 from apache_beam.typehints.native_type_compatibility import _match_is_exactly_mapping
 from apache_beam.typehints.native_type_compatibility import _match_is_optional
 from apache_beam.typehints.native_type_compatibility import _safe_issubclass
+from apache_beam.typehints.native_type_compatibility import convert_to_typing_type
 from apache_beam.typehints.native_type_compatibility import extract_optional_type
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
@@ -284,6 +285,9 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
     if row_type_constraint is not None:
       return self.typing_to_runner_api(row_type_constraint)
 
+    if isinstance(type_, typehints.TypeConstraint):
+      type_ = convert_to_typing_type(type_)
+
     # All concrete types (other than NamedTuple sub-classes) should map to
     # a supported primitive type.
     if type_ in PRIMITIVE_TO_ATOMIC_TYPE:
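
A sketch of the effect of the schemas.py change (assuming this patch): Beam typehint constraints are first converted to their typing equivalents, so a constraint such as typehints.List[str] maps to a schema field type instead of falling through the primitive lookup.

from apache_beam import typehints
from apache_beam.typehints import schemas

field_type = schemas.typing_to_runner_api(typehints.List[str])
# field_type is an array FieldType whose element type is the STRING atomic type.
print(field_type)
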
diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml
index 9ad4f53ba1f6..1d6aa5548a82 100644
--- a/sdks/python/apache_beam/yaml/standard_io.yaml
+++ b/sdks/python/apache_beam/yaml/standard_io.yaml
@@ -67,6 +67,8 @@
       'WriteToCsv': 'WriteToCsv'
       'ReadFromJson': 'ReadFromJson'
       'WriteToJson': 'WriteToJson'
+      'ReadFromAvro': 'ReadFromAvro'
+      'WriteToAvro': 'WriteToAvro'
     config:
       mappings:
         'ReadFromCsv':
@@ -77,6 +79,13 @@
           path: 'path'
         'WriteToJson':
           path: 'path'
+        'ReadFromAvro':
+          path: 'file_pattern'
+        'WriteToAvro':
+          path: 'file_path_prefix'
+      defaults:
+        'ReadFromAvro':
+          as_rows: True
     underlying_provider:
       type: python
       transforms:
@@ -84,3 +93,5 @@
         'WriteToCsv': 'apache_beam.io.WriteToCsv'
         'ReadFromJson': 'apache_beam.io.ReadFromJson'
         'WriteToJson': 'apache_beam.io.WriteToJson'
+        'ReadFromAvro': 'apache_beam.io.ReadFromAvro'
+        'WriteToAvro': 'apache_beam.io.WriteToAvro'
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 630e63c31d8a..84399cd597b2 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -684,7 +684,7 @@ def __exit__(self, *args):
 
 @ExternalProvider.register_provider_type('renaming')
 class RenamingProvider(Provider):
-  def __init__(self, transforms, mappings, underlying_provider):
+  def __init__(self, transforms, mappings, underlying_provider, defaults=None):
     if isinstance(underlying_provider, dict):
       underlying_provider = ExternalProvider.provider_from_spec(
           underlying_provider)
@@ -694,6 +694,7 @@ def __init__(self, transforms, mappings, underlying_provider):
       if transform not in mappings:
         raise ValueError(f'Missing transform {transform} in mappings.')
     self._mappings = mappings
+    self._defaults = defaults or {}
 
   def available(self) -> bool:
     return self._underlying_provider.available()
@@ -731,6 +732,9 @@ def create_transform(
         mappings.get(key, key): value
         for key, value in args.items()
     }
+    for key, value in self._defaults.get(typ, {}).items():
+      if key not in remapped_args:
+        remapped_args[key] = value
     return self._underlying_provider.create_transform(
         self._transforms[typ], remapped_args, yaml_create_transform)
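
A standalone sketch of the argument handling added to RenamingProvider.create_transform (simplified names, not the actual class): user-facing keys are renamed per the mappings, then configured defaults are filled in for keys the user did not set, which is how the YAML-level ReadFromAvro picks up as_rows=True.

def remap_with_defaults(args, mappings, defaults):
    # Rename user-facing keys to the underlying transform's parameter names.
    remapped = {mappings.get(key, key): value for key, value in args.items()}
    # Fill in defaults only for parameters the user did not supply.
    for key, value in defaults.items():
        remapped.setdefault(key, value)
    return remapped

print(remap_with_defaults(
    {'path': 'gs://my-bucket/input*.avro'},  # hypothetical user config
    {'path': 'file_pattern'},
    {'as_rows': True}))
# {'file_pattern': 'gs://my-bucket/input*.avro', 'as_rows': True}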