Merge pull request #28855 [YAML] Schemify avroio and add a yaml provider.
robertwb authored Oct 6, 2023
2 parents d26a782 + 07e26fd commit ce217d3
Showing 6 changed files with 298 additions and 28 deletions.
271 changes: 252 additions & 19 deletions sdks/python/apache_beam/io/avroio.py
@@ -45,7 +45,13 @@
# pytype: skip-file
import os
from functools import partial
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Union

import fastavro
from fastavro.read import block_reader
from fastavro.write import Writer

@@ -54,8 +60,11 @@
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.iobase import Read
from apache_beam.portability.api import schema_pb2
from apache_beam.transforms import PTransform
from apache_beam.typehints import schemas

__all__ = [
'ReadFromAvro',
@@ -73,7 +82,8 @@ def __init__(
file_pattern=None,
min_bundle_size=0,
validate=True,
use_fastavro=True):
use_fastavro=True,
as_rows=False):
"""Initializes :class:`ReadFromAvro`.
Uses source :class:`~apache_beam.io._AvroSource` to read a set of Avro
@@ -140,13 +150,26 @@ def __init__(
creation time.
use_fastavro (bool): This flag is left for API backwards compatibility
and no longer has an effect. Do not use.
as_rows (bool): Whether to return a schema'd PCollection of Beam rows.
"""
super().__init__()
self._source = _create_avro_source(
self._source = _FastAvroSource(
file_pattern, min_bundle_size, validate=validate)
if as_rows:
path = FileSystems.match([file_pattern], [1])[0].metadata_list[0].path
with FileSystems.open(path) as fin:
avro_schema = fastavro.reader(fin).writer_schema
beam_schema = avro_schema_to_beam_schema(avro_schema)
self._post_process = avro_dict_to_beam_row(avro_schema, beam_schema)
else:
self._post_process = None

def expand(self, pvalue):
return pvalue.pipeline | Read(self._source)
records = pvalue.pipeline | Read(self._source)
if self._post_process:
return records | beam.Map(self._post_process)
else:
return records

def display_data(self):
return {'source_dd': self._source}
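
# A minimal usage sketch for the new `as_rows` flag: records come back as
# schema'd Beam rows instead of dicts. The path and the `id` field are
# hypothetical.
import apache_beam as beam
from apache_beam.io.avroio import ReadFromAvro

with beam.Pipeline() as p:
  ids = (
      p
      | ReadFromAvro('/path/to/input*.avro', as_rows=True)
      | beam.Map(lambda row: row.id))  # fields are attributes on each row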
@@ -184,8 +207,7 @@ def __init__(
name and the value being the actual data. If False, it only returns
the data.
"""
source_from_file = partial(
_create_avro_source, min_bundle_size=min_bundle_size)
source_from_file = partial(_FastAvroSource, min_bundle_size=min_bundle_size)
self._read_all_files = filebasedsource.ReadAllFiles(
True,
CompressionTypes.AUTO,
@@ -280,15 +302,6 @@ def advance_file_past_next_sync_marker(f, sync_marker):
data = f.read(buf_size)


def _create_avro_source(file_pattern=None, min_bundle_size=0, validate=False):
return \
_FastAvroSource(
file_pattern=file_pattern,
min_bundle_size=min_bundle_size,
validate=validate
)


class _FastAvroSource(filebasedsource.FileBasedSource):
"""A source for reading Avro files using the `fastavro` library.
@@ -338,12 +351,15 @@ def split_points_unclaimed(stop_position):
yield record


_create_avro_source = _FastAvroSource


class WriteToAvro(beam.transforms.PTransform):
"""A ``PTransform`` for writing avro files."""
def __init__(
self,
file_path_prefix,
schema,
schema=None,
codec='deflate',
file_name_suffix='',
num_shards=0,
@@ -382,17 +398,32 @@ def __init__(
Returns:
A WriteToAvro transform usable for writing.
"""
self._sink = _create_avro_sink(
self._schema = schema
self._sink_provider = lambda avro_schema: _create_avro_sink(
file_path_prefix,
schema,
avro_schema,
codec,
file_name_suffix,
num_shards,
shard_name_template,
mime_type)

def expand(self, pcoll):
return pcoll | beam.io.iobase.Write(self._sink)
if self._schema:
avro_schema = self._schema
records = pcoll
else:
try:
beam_schema = schemas.schema_from_element_type(pcoll.element_type)
except TypeError as exn:
raise ValueError(
"An explicit schema is required to write non-schema'd PCollections."
) from exn
avro_schema = beam_schema_to_avro_schema(beam_schema)
records = pcoll | beam.Map(
beam_row_to_avro_dict(avro_schema, beam_schema))
self._sink = self._sink_provider(avro_schema)
return records | beam.io.iobase.Write(self._sink)

def display_data(self):
return {'sink_dd': self._sink}
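
# A sketch of the new schema-inference path: writing a schema'd PCollection
# without an explicit Avro schema. The output prefix and field names are
# hypothetical.
import apache_beam as beam
from apache_beam.io.avroio import WriteToAvro

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([beam.Row(id=1, name='a'), beam.Row(id=2, name='b')])
      | WriteToAvro('/path/to/output'))  # Avro schema derived from element_type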
@@ -406,7 +437,7 @@ def _create_avro_sink(
num_shards,
shard_name_template,
mime_type):
if "class \'avro.schema" in str(type(schema)):
if "class 'avro.schema" in str(type(schema)):
raise ValueError(
'You are using Avro IO with fastavro (default with Beam on '
'Python 3), but supplying a schema parsed by avro-python3. '
@@ -483,3 +514,205 @@ def write_record(self, writer, value):
def close(self, writer):
writer.flush()
self.file_handle.close()


AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES = {
'boolean': schema_pb2.BOOLEAN,
'int': schema_pb2.INT32,
'long': schema_pb2.INT64,
'float': schema_pb2.FLOAT,
'double': schema_pb2.DOUBLE,
'bytes': schema_pb2.BYTES,
'string': schema_pb2.STRING,
}

BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES = {
v: k
for k, v in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES.items()
}
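
# For instance, Avro 'long' maps to Beam INT64 and back:
#   AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES['long'] == schema_pb2.INT64
#   BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES[schema_pb2.INT64] == 'long'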

_AvroSchemaType = Union[str, List, Dict]


def avro_type_to_beam_type(avro_type: _AvroSchemaType) -> schema_pb2.FieldType:
if isinstance(avro_type, str):
return avro_type_to_beam_type({'type': avro_type})
elif isinstance(avro_type, list):
# Union type
return schemas.typing_to_runner_api(Any)
type_name = avro_type['type']
if type_name in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES:
return schema_pb2.FieldType(
atomic_type=AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES[type_name])
elif type_name in ('fixed', 'enum'):
return schema_pb2.FieldType(atomic_type=schema_pb2.STRING)
elif type_name == 'array':
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(
element_type=avro_type_to_beam_type(avro_type['items'])))
elif type_name == 'map':
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(
key_type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
value_type=avro_type_to_beam_type(avro_type['values'])))
elif type_name == 'record':
return schema_pb2.FieldType(
row_type=schema_pb2.RowType(
schema=schema_pb2.Schema(
fields=[
schemas.schema_field(
f['name'], avro_type_to_beam_type(f['type']))
for f in avro_type['fields']
])))
else:
raise ValueError(f'Unable to convert {avro_type} to a Beam schema.')
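
# An illustrative check of the conversion above; the 'Example' record and
# its fields are hypothetical.
_example_beam_type = avro_type_to_beam_type({
    'type': 'record',
    'name': 'Example',
    'fields': [
        {'name': 'id', 'type': 'long'},
        {'name': 'tags', 'type': {'type': 'array', 'items': 'string'}},
    ],
})
# _example_beam_type.row_type.schema has an INT64 field 'id' and an
# ARRAY<STRING> field 'tags'; a union such as ['null', 'long'] would map to
# the catch-all Any type instead.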


def avro_schema_to_beam_schema(
avro_schema: _AvroSchemaType) -> schema_pb2.Schema:
beam_type = avro_type_to_beam_type(avro_schema)
if isinstance(avro_schema, dict) and avro_schema['type'] == 'record':
return beam_type.row_type.schema
else:
return schema_pb2.Schema(fields=[schemas.schema_field('record', beam_type)])


def avro_dict_to_beam_row(
avro_schema: _AvroSchemaType,
beam_schema: schema_pb2.Schema) -> Callable[[Any], Any]:
if isinstance(avro_schema, str):
    return avro_dict_to_beam_row({'type': avro_schema}, beam_schema)
if avro_schema['type'] == 'record':
to_row = avro_value_to_beam_value(
schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
else:

def to_row(record):
return beam.Row(record=record)

return beam.typehints.with_output_types(
schemas.named_tuple_from_schema(beam_schema))(
to_row)
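
# A sketch of the non-record case above: a scalar Avro schema is wrapped in a
# single-field row named 'record', mirroring avro_schema_to_beam_schema.
_scalar_beam_schema = avro_schema_to_beam_schema('string')
_to_row = avro_dict_to_beam_row('string', _scalar_beam_schema)
# _to_row('hello') yields a row equivalent to beam.Row(record='hello').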


def avro_value_to_beam_value(
beam_type: schema_pb2.FieldType) -> Callable[[Any], Any]:
type_info = beam_type.WhichOneof("type_info")
if type_info == "atomic_type":
return lambda value: value
elif type_info == "array_type":
element_converter = avro_value_to_beam_value(
beam_type.array_type.element_type)
return lambda value: [element_converter(e) for e in value]
elif type_info == "iterable_type":
element_converter = avro_value_to_beam_value(
beam_type.iterable_type.element_type)
return lambda value: [element_converter(e) for e in value]
elif type_info == "map_type":
if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
raise TypeError(
        f'Only strings allowed as map keys when converting from AVRO, '
f'found {beam_type}')
value_converter = avro_value_to_beam_value(beam_type.map_type.value_type)
return lambda value: {k: value_converter(v) for (k, v) in value.items()}
elif type_info == "row_type":
converters = {
field.name: avro_value_to_beam_value(field.type)
for field in beam_type.row_type.schema.fields
}
return lambda value: beam.Row(
**
{name: convert(value[name])
for (name, convert) in converters.items()})
elif type_info == "logical_type":
return lambda value: value
else:
raise ValueError(f"Unrecognized type_info: {type_info!r}")


def beam_schema_to_avro_schema(
beam_schema: schema_pb2.Schema) -> _AvroSchemaType:
return beam_type_to_avro_type(
schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
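
# An illustrative round trip through the two schema converters. Note that
# Avro 'fixed' and 'enum' types become plain strings in Beam, so the round
# trip is not always lossless.
_rt_beam_schema = avro_schema_to_beam_schema({
    'type': 'record',
    'name': 'Example',
    'fields': [{'name': 'id', 'type': 'long'}],
})
_rt_avro_schema = beam_schema_to_avro_schema(_rt_beam_schema)
# _rt_avro_schema is a 'record' dict whose single field 'id' has
# type {'type': 'long'}.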


def beam_type_to_avro_type(beam_type: schema_pb2.FieldType) -> _AvroSchemaType:
type_info = beam_type.WhichOneof("type_info")
if type_info == "atomic_type":
return {'type': BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES[beam_type.atomic_type]}
elif type_info == "array_type":
return {
'type': 'array',
'items': beam_type_to_avro_type(beam_type.array_type.element_type)
}
elif type_info == "iterable_type":
return {
'type': 'array',
'items': beam_type_to_avro_type(beam_type.iterable_type.element_type)
}
elif type_info == "map_type":
if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
raise TypeError(
        f'Only strings allowed as map keys when converting to AVRO, '
f'found {beam_type}')
return {
'type': 'map',
      'values': beam_type_to_avro_type(beam_type.map_type.value_type)
}
elif type_info == "row_type":
return {
'type': 'record',
'name': beam_type.row_type.schema.id,
'fields': [{
'name': field.name, 'type': beam_type_to_avro_type(field.type)
} for field in beam_type.row_type.schema.fields],
}
else:
raise ValueError(f"Unconvertale type: {beam_type}")


def beam_row_to_avro_dict(
avro_schema: _AvroSchemaType, beam_schema: schema_pb2.Schema):
if isinstance(avro_schema, str):
return beam_row_to_avro_dict({'type': avro_schema}, beam_schema)
if avro_schema['type'] == 'record':
return beam_value_to_avro_value(
schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
else:
    convert = beam_value_to_avro_value(beam_schema.fields[0].type)
return lambda row: convert(row[0])


def beam_value_to_avro_value(
beam_type: schema_pb2.FieldType) -> Callable[[Any], Any]:
type_info = beam_type.WhichOneof("type_info")
if type_info == "atomic_type":
return lambda value: value
elif type_info == "array_type":
    element_converter = beam_value_to_avro_value(
        beam_type.array_type.element_type)
return lambda value: [element_converter(e) for e in value]
elif type_info == "iterable_type":
    element_converter = beam_value_to_avro_value(
        beam_type.iterable_type.element_type)
return lambda value: [element_converter(e) for e in value]
elif type_info == "map_type":
if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
raise TypeError(
        f'Only strings allowed as map keys when converting to AVRO, '
f'found {beam_type}')
    value_converter = beam_value_to_avro_value(beam_type.map_type.value_type)
return lambda value: {k: value_converter(v) for (k, v) in value.items()}
elif type_info == "row_type":
converters = {
        field.name: beam_value_to_avro_value(field.type)
for field in beam_type.row_type.schema.fields
}
return lambda value: {
name: convert(getattr(value, name))
for (name, convert) in converters.items()
}
elif type_info == "logical_type":
return lambda value: value
else:
raise ValueError(f"Unrecognized type_info: {type_info!r}")