Skip to content

Commit

Permalink
Fix custom coder not being used in Reshuffle (global window) (apache#…
Browse files Browse the repository at this point in the history
…33339)

* Fix typehint in ReshufflePerKey on global window setting.

* Only update the type hint on global window setting. Need more work in non-global windows.

* Apply yapf

* Fix some failed tests.

* Revert change to setup.py
  • Loading branch information
shunping authored Dec 12, 2024
1 parent 224900c commit 272feb8
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
13 changes: 11 additions & 2 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
Expand Down Expand Up @@ -78,6 +79,7 @@
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
from apache_beam.utils.sharded_key import ShardedKey
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.runners.pipeline_context import PipelineContext
Expand Down Expand Up @@ -953,6 +955,10 @@ def restore_timestamps(element):
window.GlobalWindows.windowed_value((key, value), timestamp)
for (value, timestamp) in values
]

ungrouped = pcoll | Map(reify_timestamps).with_input_types(
Tuple[K, V]).with_output_types(
Tuple[K, Tuple[V, Optional[Timestamp]]])
else:

# typing: All conditional function variants must have identical signatures
Expand All @@ -966,7 +972,9 @@ def restore_timestamps(element):
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]

ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
# TODO(https://github.com/apache/beam/issues/33356): Support reshuffling
# unpicklable objects with a non-global window setting.
ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)

# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
Expand Down Expand Up @@ -1018,7 +1026,8 @@ def expand(self, pcoll):
pcoll | 'AddRandomKeys' >>
Map(lambda t: (random.randrange(0, self.num_buckets), t)
).with_input_types(T).with_output_types(Tuple[int, T])
| ReshufflePerKey()
| ReshufflePerKey().with_input_types(Tuple[int, T]).with_output_types(
Tuple[int, T])
| 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types(
Tuple[int, T]).with_output_types(T))

Expand Down
39 changes: 39 additions & 0 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,45 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
equal_to(expected_data),
label="formatted_after_reshuffle")

def test_reshuffle_unpicklable_in_global_window(self):
global _Unpicklable

class _Unpicklable(object):
def __init__(self, value):
self.value = value

def __getstate__(self):
raise NotImplementedError()

def __setstate__(self, state):
raise NotImplementedError()

class _UnpicklableCoder(beam.coders.Coder):
def encode(self, value):
return str(value.value).encode()

def decode(self, encoded):
return _Unpicklable(int(encoded.decode()))

def to_type_hint(self):
return _Unpicklable

def is_deterministic(self):
return True

beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

with TestPipeline() as pipeline:
data = [_Unpicklable(i) for i in range(5)]
expected_data = [0, 10, 20, 30, 40]
result = (
pipeline
| beam.Create(data)
| beam.WindowInto(GlobalWindows())
| beam.Reshuffle()
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))


class WithKeysTest(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 272feb8

Please sign in to comment.