Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use other fallback coders for protobuf message base class #33432

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,11 +1039,16 @@ def __hash__(self):

@classmethod
def from_type_hint(cls, typehint, unused_registry):
if issubclass(typehint, proto_utils.message_types):
# The typehint must be a subclass of google.protobuf.message.Message.
shunping marked this conversation as resolved.
Show resolved Hide resolved
# ProtoCoder cannot work with message.Message itself, as required APIs are
# not implemented in the base class. If this occurs, an error is raised
shunping marked this conversation as resolved.
Show resolved Hide resolved
# and the system defaults to other fallback coders.
if (issubclass(typehint, proto_utils.message_types) and
typehint != message.Message):
return cls(typehint)
else:
raise ValueError((
'Expected a subclass of google.protobuf.message.Message'
'Expected a strict subclass of google.protobuf.message.Message'
', but got a %s' % typehint))

def to_type_hint(self):
Expand Down
18 changes: 18 additions & 0 deletions sdks/python/apache_beam/coders/coders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
import unittest

from google.protobuf import message
import proto
import pytest

Expand Down Expand Up @@ -86,6 +87,23 @@ def test_proto_coder(self):
self.assertEqual(ma, real_coder.decode(real_coder.encode(ma)))
self.assertEqual(ma.__class__, real_coder.to_type_hint())

def test_proto_coder_on_protobuf_message_subclasses(self):
# This replicates a scenario where users provide message.Message as the
# output typehint for a Map function, even though the actual output messages
# are subclasses of message.Message.
ma = test_message.MessageA()
mb = ma.field2.add()
mb.field1 = True
ma.field1 = 'hello world'

coder = coders_registry.get_coder(message.Message)
# For messages of google.protobuf.message.Message, the fallback coder will
# be FastPrimitiveCoder other than ProtoCoder.
shunping marked this conversation as resolved.
Show resolved Hide resolved
# See the comment on ProtoCoder.from_type_hint() for further details.
self.assertEqual(coder, coders.FastPrimitivesCoder())

self.assertEqual(ma, coder.decode(coder.encode(ma)))


class DeterministicProtoCoderTest(unittest.TestCase):
def test_deterministic_proto_coder(self):
Expand Down
Loading