From bfc64d5c14adea209364d48b0fe9b4e8ba6eaab5 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Mon, 5 Aug 2024 11:34:45 +0200 Subject: [PATCH] Fix error when ActiveWorkRefresher processed empty heartbeat map. (#32078) --- .../work/refresh/ActiveWorkRefresher.java | 3 ++ .../work/refresh/ActiveWorkRefresherTest.java | 38 +++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java index 499d2e5b6943..781285def020 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java @@ -130,6 +130,9 @@ private void refreshActiveWork() { Instant refreshDeadline = clock.get().minus(Duration.millis(activeWorkRefreshPeriodMillis)); Map heartbeatsBySender = aggregateHeartbeatsBySender(refreshDeadline); + if (heartbeatsBySender.isEmpty()) { + return; + } List> fanOutRefreshActiveWork = new ArrayList<>(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java index 9dce3392c60c..5efb2421fe60 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -194,10 +195,13 @@ public void testActiveWorkRefresh() throws InterruptedException { assertThat(heartbeatRequests) .comparingElementsUsing( Correspondence.from( - (Windmill.HeartbeatRequest h, Work w) -> - h.getWorkToken() == w.getWorkItem().getWorkToken() - && h.getCacheToken() == w.getWorkItem().getWorkToken() - && h.getShardingKey() == w.getWorkItem().getShardingKey(), + (Windmill.HeartbeatRequest h, Work w) -> { + assert h != null; + assert w != null; + return h.getWorkToken() == w.getWorkItem().getWorkToken() + && h.getCacheToken() == w.getWorkItem().getWorkToken() + && h.getShardingKey() == w.getWorkItem().getShardingKey(); + }, "heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys should be equal.")) .containsExactlyElementsIn(work); } @@ -207,6 +211,32 @@ public void testActiveWorkRefresh() throws InterruptedException { workIsProcessed.countDown(); } + @Test + public void testEmptyActiveWorkRefresh() throws InterruptedException { + int activeWorkRefreshPeriodMillis = 100; + + List computations = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + ComputationState computationState = createComputationState(i); + computations.add(computationState); + } + + CountDownLatch heartbeatsSent = new CountDownLatch(1); + TestClock fakeClock = new TestClock(Instant.now()); + ActiveWorkRefresher activeWorkRefresher = + createActiveWorkRefresher( + fakeClock::now, + activeWorkRefreshPeriodMillis, + 0, + () -> computations, + heartbeats -> heartbeatsSent::countDown); + + activeWorkRefresher.start(); + fakeClock.advance(Duration.millis(activeWorkRefreshPeriodMillis * 2)); + assertFalse(heartbeatsSent.await(500, TimeUnit.MILLISECONDS)); + activeWorkRefresher.stop(); + } + @Test public void testInvalidateStuckCommits() throws InterruptedException { int stuckCommitDurationMillis = 100;