diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 57d8197a3a00..0f2a42686854 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -55,6 +55,7 @@ import google.protobuf.wrappers_pb2 import proto +from google.protobuf import message from apache_beam.coders import coder_impl from apache_beam.coders.avro_record import AvroRecord @@ -65,7 +66,6 @@ from apache_beam.utils import proto_utils if TYPE_CHECKING: - from google.protobuf import message # pylint: disable=ungrouped-imports from apache_beam.coders.typecoders import CoderRegistry from apache_beam.runners.pipeline_context import PipelineContext @@ -1039,11 +1039,18 @@ def __hash__(self): @classmethod def from_type_hint(cls, typehint, unused_registry): - if issubclass(typehint, proto_utils.message_types): + # The typehint must be a strict subclass of google.protobuf.message.Message. + # ProtoCoder cannot work with message.Message itself, as deserialization of + # a serialized proto requires knowledge of the desired concrete proto + # subclass which is not stored in the encoded bytes themselves. If this + # occurs, an error is raised and the system defaults to other fallback + # coders. + if (issubclass(typehint, proto_utils.message_types) and + typehint != message.Message): return cls(typehint) else: raise ValueError(( - 'Expected a subclass of google.protobuf.message.Message' + 'Expected a strict subclass of google.protobuf.message.Message' ', but got a %s' % typehint)) def to_type_hint(self): diff --git a/sdks/python/apache_beam/coders/coders_test.py b/sdks/python/apache_beam/coders/coders_test.py index dc9780e36be3..5e5debca36e6 100644 --- a/sdks/python/apache_beam/coders/coders_test.py +++ b/sdks/python/apache_beam/coders/coders_test.py @@ -22,6 +22,7 @@ import proto import pytest +from google.protobuf import message import apache_beam as beam from apache_beam import typehints @@ -86,6 +87,23 @@ def test_proto_coder(self): self.assertEqual(ma, real_coder.decode(real_coder.encode(ma))) self.assertEqual(ma.__class__, real_coder.to_type_hint()) + def test_proto_coder_on_protobuf_message_subclasses(self): + # This replicates a scenario where users provide message.Message as the + # output typehint for a Map function, even though the actual output messages + # are subclasses of message.Message. + ma = test_message.MessageA() + mb = ma.field2.add() + mb.field1 = True + ma.field1 = 'hello world' + + coder = coders_registry.get_coder(message.Message) + # For messages of google.protobuf.message.Message, the fallback coder will + # be FastPrimitivesCoder rather than ProtoCoder. + # See the comment on ProtoCoder.from_type_hint() for further details. + self.assertEqual(coder, coders.FastPrimitivesCoder()) + + self.assertEqual(ma, coder.decode(coder.encode(ma))) + class DeterministicProtoCoderTest(unittest.TestCase): def test_deterministic_proto_coder(self):