diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index 3321644ded57..b2bf150fa558 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -23,6 +23,7 @@ implementations of the same transforms, the configs must be kept in sync. """ +import io import os from typing import Any from typing import Callable @@ -32,12 +33,14 @@ from typing import Optional from typing import Tuple +import fastavro import yaml import apache_beam as beam import apache_beam.io as beam_io from apache_beam.io import ReadFromBigQuery from apache_beam.io import WriteToBigQuery +from apache_beam.io import avroio from apache_beam.io.gcp.bigquery import BigQueryDisposition from apache_beam.portability.api import schema_pb2 from apache_beam.typehints import schemas @@ -146,6 +149,13 @@ def _create_parser( elif format == 'json': beam_schema = json_utils.json_schema_to_beam_schema(schema) return beam_schema, json_utils.json_parser(beam_schema) + elif format == 'avro': + beam_schema = avroio.avro_schema_to_beam_schema(schema) + convert_to_row = avroio.avro_dict_to_beam_row(schema, beam_schema) + return ( + beam_schema, + lambda record: convert_to_row( + fastavro.schemaless_reader(io.BytesIO(record), schema))) else: raise ValueError(f'Unknown format: {format}') @@ -162,6 +172,17 @@ def _create_formatter( return lambda row: getattr(row, field_names[0]) elif format == 'json': return json_utils.json_formater(beam_schema) + elif format == 'avro': + avro_schema = schema or avroio.beam_schema_to_avro_schema(beam_schema) + from_row = avroio.beam_row_to_avro_dict(avro_schema, beam_schema) + + def formatter(row): + buffer = io.BytesIO() + fastavro.schemaless_writer(buffer, avro_schema, from_row(row)) + buffer.seek(0) + return buffer.read() + + return formatter else: raise ValueError(f'Unknown format: {format}') diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py index 
72675da278b0..7071860a7bf1 100644 --- a/sdks/python/apache_beam/yaml/yaml_io_test.py +++ b/sdks/python/apache_beam/yaml/yaml_io_test.py @@ -15,9 +15,12 @@ # limitations under the License. # +import io +import json import logging import unittest +import fastavro import mock import apache_beam as beam @@ -167,6 +170,48 @@ def test_read_with_id_attribute(self): result, equal_to([beam.Row(payload=b'msg1'), beam.Row(payload=b'msg2')])) + _avro_schema = { + 'type': 'record', + 'name': 'ec', + 'fields': [{ + 'name': 'label', 'type': 'string' + }, { + 'name': 'rank', 'type': 'int' + }] + } + + def _encode_avro(self, data): + buffer = io.BytesIO() + fastavro.schemaless_writer(buffer, self._avro_schema, data) + buffer.seek(0) + return buffer.read() + + def test_read_avro(self): + + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + with mock.patch( + 'apache_beam.io.ReadFromPubSub', + FakeReadFromPubSub( + topic='my_topic', + messages=[PubsubMessage(self._encode_avro({'label': '37a', + 'rank': 1}), {}), + PubsubMessage(self._encode_avro({'label': '389a', + 'rank': 2}), {})])): + result = p | YamlTransform( + ''' + type: ReadFromPubSub + config: + topic: my_topic + format: avro + schema: %s + ''' % json.dumps(self._avro_schema)) + assert_that( + result, + equal_to( + [beam.Row(label='37a', rank=1), # linebreak + beam.Row(label='389a', rank=2)])) + def test_read_json(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: @@ -346,6 +391,29 @@ def test_write_with_id_attribute(self): id_attribute: some_attr ''')) + def test_write_avro(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + with mock.patch( + 'apache_beam.io.WriteToPubSub', + FakeWriteToPubSub( + topic='my_topic', + messages=[PubsubMessage(self._encode_avro({'label': '37a', + 'rank': 1}), {}), + 
PubsubMessage(self._encode_avro({'label': '389a', + 'rank': 2}), {})])): + _ = ( + p | beam.Create( + [beam.Row(label='37a', rank=1), beam.Row(label='389a', rank=2)]) + | YamlTransform( + ''' + type: WriteToPubSub + input: input + config: + topic: my_topic + format: avro + ''')) + def test_write_json(self): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle')) as p: