Skip to content

Commit

Permalink
[YAML] Avro format for PubSub. (#28899)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Oct 12, 2023
1 parent e8e3814 commit 9c75db4
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
implementations of the same transforms, the configs must be kept in sync.
"""

import io
import os
from typing import Any
from typing import Callable
Expand All @@ -32,12 +33,14 @@
from typing import Optional
from typing import Tuple

import fastavro
import yaml

import apache_beam as beam
import apache_beam.io as beam_io
from apache_beam.io import ReadFromBigQuery
from apache_beam.io import WriteToBigQuery
from apache_beam.io import avroio
from apache_beam.io.gcp.bigquery import BigQueryDisposition
from apache_beam.portability.api import schema_pb2
from apache_beam.typehints import schemas
Expand Down Expand Up @@ -146,6 +149,13 @@ def _create_parser(
elif format == 'json':
beam_schema = json_utils.json_schema_to_beam_schema(schema)
return beam_schema, json_utils.json_parser(beam_schema)
elif format == 'avro':
beam_schema = avroio.avro_schema_to_beam_schema(schema)
covert_to_row = avroio.avro_dict_to_beam_row(schema, beam_schema)
return (
beam_schema,
lambda record: covert_to_row(
fastavro.schemaless_reader(io.BytesIO(record), schema)))
else:
raise ValueError(f'Unknown format: {format}')

Expand All @@ -162,6 +172,17 @@ def _create_formatter(
return lambda row: getattr(row, field_names[0])
elif format == 'json':
return json_utils.json_formater(beam_schema)
elif format == 'avro':
avro_schema = schema or avroio.beam_schema_to_avro_schema(beam_schema)
from_row = avroio.beam_row_to_avro_dict(avro_schema, beam_schema)

def formatter(row):
buffer = io.BytesIO()
fastavro.schemaless_writer(buffer, avro_schema, from_row(row))
buffer.seek(0)
return buffer.read()

return formatter
else:
raise ValueError(f'Unknown format: {format}')

Expand Down
68 changes: 68 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
# limitations under the License.
#

import io
import json
import logging
import unittest

import fastavro
import mock

import apache_beam as beam
Expand Down Expand Up @@ -167,6 +170,48 @@ def test_read_with_id_attribute(self):
result,
equal_to([beam.Row(payload=b'msg1'), beam.Row(payload=b'msg2')]))

_avro_schema = {
'type': 'record',
'name': 'ec',
'fields': [{
'name': 'label', 'type': 'string'
}, {
'name': 'rank', 'type': 'int'
}]
}

def _encode_avro(self, data):
buffer = io.BytesIO()
fastavro.schemaless_writer(buffer, self._avro_schema, data)
buffer.seek(0)
return buffer.read()

def test_read_avro(self):

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
with mock.patch(
'apache_beam.io.ReadFromPubSub',
FakeReadFromPubSub(
topic='my_topic',
messages=[PubsubMessage(self._encode_avro({'label': '37a',
'rank': 1}), {}),
PubsubMessage(self._encode_avro({'label': '389a',
'rank': 2}), {})])):
result = p | YamlTransform(
'''
type: ReadFromPubSub
config:
topic: my_topic
format: avro
schema: %s
''' % json.dumps(self._avro_schema))
assert_that(
result,
equal_to(
[beam.Row(label='37a', rank=1), # linebreak
beam.Row(label='389a', rank=2)]))

def test_read_json(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
Expand Down Expand Up @@ -346,6 +391,29 @@ def test_write_with_id_attribute(self):
id_attribute: some_attr
'''))

def test_write_avro(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
with mock.patch(
'apache_beam.io.WriteToPubSub',
FakeWriteToPubSub(
topic='my_topic',
messages=[PubsubMessage(self._encode_avro({'label': '37a',
'rank': 1}), {}),
PubsubMessage(self._encode_avro({'label': '389a',
'rank': 2}), {})])):
_ = (
p | beam.Create(
[beam.Row(label='37a', rank=1), beam.Row(label='389a', rank=2)])
| YamlTransform(
'''
type: WriteToPubSub
input: input
config:
topic: my_topic
format: avro
'''))

def test_write_json(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
Expand Down

0 comments on commit 9c75db4

Please sign in to comment.