Skip to content

Commit

Permalink
Deepcopy combine_fn in PrecombineFn and PostCombineFn. (#32598)
Browse files Browse the repository at this point in the history
Co-authored-by: Claude <[email protected]>
  • Loading branch information
claudevdm and Claude authored Oct 1, 2024
1 parent 987c4c9 commit eaf53e5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
11 changes: 7 additions & 4 deletions sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,18 @@ def test_combining_value_state(self):


@parameterized_class([
{'runner': direct_runner.BundleBasedDirectRunner},
{'runner': fn_api_runner.FnApiRunner},
]) # yapf: disable
{'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'},
{'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'},
{'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'},
{'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'},
]) # yapf: disable
class LocalCombineFnLifecycleTest(unittest.TestCase):
def tearDown(self):
CallSequenceEnforcingCombineFn.instances.clear()

def test_combine(self):
run_combine(TestPipeline(runner=self.runner()))
test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
run_combine(TestPipeline(runner=self.runner(), options=test_options))
self._assert_teardown_called()

def test_non_liftable_combine(self):
Expand Down
43 changes: 25 additions & 18 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3158,33 +3158,40 @@ def process(self, element):
yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value))

class PreCombineFn(CombineFn):
def __init__(self):
# Deepcopy of the combine_fn to avoid sharing state between lifted
# stages when using cloudpickle.
self._combine_fn_copy = copy.deepcopy(combine_fn)
self.setup = self._combine_fn_copy.setup
self.create_accumulator = self._combine_fn_copy.create_accumulator
self.add_input = self._combine_fn_copy.add_input
self.merge_accumulators = self._combine_fn_copy.merge_accumulators
self.compact = self._combine_fn_copy.compact
self.teardown = self._combine_fn_copy.teardown

@staticmethod
def extract_output(accumulator):
# Boolean indicates this is an accumulator.
return (True, accumulator)

setup = combine_fn.setup
create_accumulator = combine_fn.create_accumulator
add_input = combine_fn.add_input
merge_accumulators = combine_fn.merge_accumulators
compact = combine_fn.compact
teardown = combine_fn.teardown

class PostCombineFn(CombineFn):
@staticmethod
def add_input(accumulator, element):
def __init__(self):
# Deepcopy of the combine_fn to avoid sharing state between lifted
# stages when using cloudpickle.
self._combine_fn_copy = copy.deepcopy(combine_fn)
self.setup = self._combine_fn_copy.setup
self.create_accumulator = self._combine_fn_copy.create_accumulator
self.merge_accumulators = self._combine_fn_copy.merge_accumulators
self.compact = self._combine_fn_copy.compact
self.extract_output = self._combine_fn_copy.extract_output
self.teardown = self._combine_fn_copy.teardown

def add_input(self, accumulator, element):
is_accumulator, value = element
if is_accumulator:
return combine_fn.merge_accumulators([accumulator, value])
return self._combine_fn_copy.merge_accumulators([accumulator, value])
else:
return combine_fn.add_input(accumulator, value)

setup = combine_fn.setup
create_accumulator = combine_fn.create_accumulator
merge_accumulators = combine_fn.merge_accumulators
compact = combine_fn.compact
extract_output = combine_fn.extract_output
teardown = combine_fn.teardown
return self._combine_fn_copy.add_input(accumulator, value)

def StripNonce(nonce_key_value):
(_, key), value = nonce_key_value
Expand Down

0 comments on commit eaf53e5

Please sign in to comment.