From 37f0a7674333c00f5b80699bea1921666f1701ad Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 11 Dec 2024 23:13:05 -0500 Subject: [PATCH] Fix custom coders not being used in reshuffle in non-global windows --- sdks/python/apache_beam/coders/coders.py | 5 ++ sdks/python/apache_beam/coders/typecoders.py | 2 + sdks/python/apache_beam/transforms/util.py | 7 +-- .../apache_beam/transforms/util_test.py | 52 ++++++++++++------- .../python/apache_beam/typehints/typehints.py | 9 ++++ sdks/python/setup.py | 3 +- 6 files changed, 55 insertions(+), 23 deletions(-) diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index a0c55da81800..d53be59f3ada 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -1402,6 +1402,11 @@ def __hash__(self): return hash( (self.wrapped_value_coder, self.timestamp_coder, self.window_coder)) + @classmethod + def from_type_hint(cls, typehint, registry): + # type: (Any, CoderRegistry) -> WindowedValueCoder + return cls(registry.get_coder(typehint.inner_type)) + Coder.register_structured_urn( common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder) diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 1667cb7a916a..abfb9e72a047 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -94,6 +94,8 @@ def register_standard_coders(self, fallback_coder): self._register_coder_internal(str, coders.StrUtf8Coder) self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder) self._register_coder_internal(typehints.DictConstraint, coders.MapCoder) + self._register_coder_internal(typehints.WindowedTypeConstraint, + coders.WindowedValueCoder) # Default fallback coders applied in that order until the first matching # coder found. default_fallback_coders = [ diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 43d4a6c20e94..0787c90650c4 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -74,6 +74,7 @@ from apache_beam.transforms.window import TimestampedValue from apache_beam.typehints import trivial_inference from apache_beam.typehints.decorators import get_signature +from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.typehints.sharded_key_type import ShardedKeyType from apache_beam.utils import shared from apache_beam.utils import windowed_value @@ -972,9 +973,9 @@ def restore_timestamps(element): key, windowed_values = element return [wv.with_value((key, wv.value)) for wv in windowed_values] - # TODO(https://github.com/apache/beam/issues/33356): Support reshuffling - # unpicklable objects with a non-global window setting. - ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any) + ungrouped = pcoll | Map(reify_timestamps).with_input_types( + Tuple[K, V]).with_output_types( + Tuple[K, typehints.WindowedValue[convert_to_beam_type(V)]]) # TODO(https://github.com/apache/beam/issues/19785) Using global window as # one of the standard window. This is to mitigate the Dataflow Java Runner diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 7f166f78ef0a..55f4488beb3a 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1010,32 +1010,32 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): equal_to(expected_data), label="formatted_after_reshuffle") - def test_reshuffle_unpicklable_in_global_window(self): - global _Unpicklable - - class _Unpicklable(object): - def __init__(self, value): - self.value = value + global _Unpicklable + global _UnpicklableCoder + class _Unpicklable(object): + def __init__(self, value): + self.value = value - def __getstate__(self): - raise NotImplementedError() + def __getstate__(self): + raise NotImplementedError() - def __setstate__(self, state): - raise NotImplementedError() + def __setstate__(self, state): + raise NotImplementedError() - class _UnpicklableCoder(beam.coders.Coder): - def encode(self, value): - return str(value.value).encode() + class _UnpicklableCoder(beam.coders.Coder): + def encode(self, value): + return str(value.value).encode() - def decode(self, encoded): - return _Unpicklable(int(encoded.decode())) + def decode(self, encoded): + return _Unpicklable(int(encoded.decode())) - def to_type_hint(self): - return _Unpicklable + def to_type_hint(self): + return _Unpicklable - def is_deterministic(self): - return True + def is_deterministic(self): + return True + def test_reshuffle_unpicklable_in_global_window(self): beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) with TestPipeline() as pipeline: @@ -1049,6 +1049,20 @@ def is_deterministic(self): | beam.Map(lambda u: u.value * 10)) assert_that(result, equal_to(expected_data)) + def test_reshuffle_unpicklable_in_non_global_window(self): + beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) + + with TestPipeline() as pipeline: + data = [_Unpicklable(i) for i in range(5)] + expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40] + result = ( + pipeline + | beam.Create(data) + | beam.WindowInto(window.SlidingWindows(size=3, period=1)) + | beam.Reshuffle() + | beam.Map(lambda u: u.value * 10)) + assert_that(result, equal_to(expected_data)) + class WithKeysTest(unittest.TestCase): def setUp(self): diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 0e18e887c2a0..a65a0f753826 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1213,6 +1213,15 @@ def type_check(self, instance): repr(self.inner_type), instance.value.__class__.__name__)) + def bind_type_variables(self, bindings): + bound_inner_type = bind_type_variables(self.inner_type, bindings) + if bound_inner_type == self.inner_type: + return self + return WindowedValue[bound_inner_type] + + def __repr__(self): + return 'WindowedValue[%s]' % repr(self.inner_type) + class GeneratorHint(IteratorHint): """A Generator type hint. diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 53c7a532e706..505da7f94c16 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -336,7 +336,8 @@ def get_portability_package_data(): *get_portability_package_data() ] }, - ext_modules=extensions, + # ext_modules=extensions, + ext_modules=[], install_requires=[ 'crcmod>=1.7,<2.0', 'orjson>=3.9.7,<4',