From 0a71499f3f5e20a8214e164f3f505c874cfdfbda Mon Sep 17 00:00:00 2001 From: claudevdm <33973061+claudevdm@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:51:27 -0400 Subject: [PATCH] Fix counter metrics for ParDo#with_exception_handling(timeout). (#32571) Co-authored-by: Claude --- sdks/python/apache_beam/transforms/core.py | 18 ++++++++++--- .../apache_beam/transforms/ptransform_test.py | 26 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index e7180bc093b0..91ca4c8e33c3 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2611,11 +2611,23 @@ def __getattribute__(self, name): def process(self, *args, **kwargs): if self._pool is None: self._pool = concurrent.futures.ThreadPoolExecutor(10) + + # Import here to avoid circular dependency + from apache_beam.runners.worker.statesampler import get_current_tracker, set_current_tracker + + # State sampler/tracker is stored as a thread local variable, and is used + # when incrementing counter metrics. + dispatching_thread_state_sampler = get_current_tracker() + + def wrapped_process(): + """Makes the dispatching thread local state sampler available to child + thread""" + set_current_tracker(dispatching_thread_state_sampler) + return list(self._fn.process(*args, **kwargs)) + # Ensure we iterate over the entire output list in the given amount of time. try: - return self._pool.submit( - lambda: list(self._fn.process(*args, **kwargs))).result( - self._timeout) + return self._pool.submit(wrapped_process).result(self._timeout) except TimeoutError: self._pool.shutdown(wait=False) self._pool = None diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 2fdec14651f1..d760ef74fb14 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2780,6 +2780,32 @@ def test_timeout(self): ('slow', 'TimeoutError()')]), label='CheckBad') + def test_increment_counter(self): + # Counters are not currently supported for + # ParDo#with_exception_handling(use_subprocess=True). + if (self.use_subprocess): + return + + class CounterDoFn(beam.DoFn): + def __init__(self): + self.records_counter = Metrics.counter(self.__class__, 'recordsCounter') + + def process(self, element): + self.records_counter.inc() + + with TestPipeline() as p: + _, _ = ( + (p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn()) + .with_exception_handling( + use_subprocess=self.use_subprocess, timeout=1)) + results = p.result + metric_results = results.metrics().query( + MetricsFilter().with_name("recordsCounter")) + records_counter = metric_results['counters'][0] + + self.assertEqual(records_counter.key.metric.name, 'recordsCounter') + self.assertEqual(records_counter.result, 3) + def test_lifecycle(self): die = type(self).die