Skip to content

Commit

Permalink
numpy.int64 type is not serialized correctly in Python 3.11 and Pytho…
Browse files Browse the repository at this point in the history
…n 3.12 (#33137)

* Added the validation to __getstate__

* fix the comments

* formatting

* fixed the lint

* fixes isort
  • Loading branch information
liferoad authored Nov 18, 2024
1 parent 0d894a7 commit 22ea62c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
3 changes: 3 additions & 0 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@ def encode_special_deterministic(self, value, stream):
stream.write_byte(NESTED_STATE_TYPE)
self.encode_type(type(value), stream)
state_value = value.__getstate__()
if value is not None and state_value is None:
# https://github.com/apache/beam/issues/33020
raise TypeError(self._deterministic_encoding_error_msg(value))
try:
self.encode_to_stream(state_value, stream, True)
except Exception as e:
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/coders/coders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import proto
import pytest

import apache_beam as beam
from apache_beam import typehints
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import coders
from apache_beam.coders.avro_record import AvroRecord
from apache_beam.coders.typecoders import registry as coders_registry
from apache_beam.testing.test_pipeline import TestPipeline


class PickleCoderTest(unittest.TestCase):
Expand Down Expand Up @@ -242,6 +244,20 @@ def test_to_type_hint(self):
assert coder.to_type_hint() is bytes


class NumpyIntAsKeyTest(unittest.TestCase):
def test_numpy_int(self):
# this type is not supported as the key
import numpy as np

with self.assertRaises(TypeError):
with TestPipeline() as p:
indata = p | "Create" >> beam.Create([(a, int(a))
for a in np.arange(3)])

# Apply CombinePerkey to sum values for each key.
_ = indata | "CombinePerKey" >> beam.CombinePerKey(sum)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()

0 comments on commit 22ea62c

Please sign in to comment.