diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 511fe5617dff..458cf57ca8e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -307,23 +307,17 @@ private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { connectionAndStream -> !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) .forEach( - entry -> { - CompletableFuture ignored = - CompletableFuture.runAsync( - () -> closeStreamSender(entry.getKey(), entry.getValue()), - windmillStreamManager); - }); + entry -> + windmillStreamManager.execute( + () -> closeStreamSender(entry.getKey(), entry.getValue()))); Set newGlobalDataEndpoints = new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); currentBackends.globalDataStreams().values().stream() .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) .forEach( - sender -> { - CompletableFuture ignored = - CompletableFuture.runAsync( - () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager); - }); + sender -> + windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); } private void closeStreamSender(Endpoint endpoint, Closeable sender) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 749e689d93c9..b27ebc8e9eee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -17,15 +17,14 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; -import com.google.auto.value.AutoValue; import java.io.PrintWriter; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import javax.annotation.concurrent.GuardedBy; import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; @@ -122,7 +121,7 @@ private GrpcDirectGetWorkStream( this.getDataClient = getDataClient; this.lastRequest = new AtomicReference<>(); this.budgetTracker = - GetWorkBudgetTracker.create( + new GetWorkBudgetTracker( GetWorkBudget.builder() .setItems(requestHeader.getMaxItems()) .setBytes(requestHeader.getMaxBytes()) @@ -229,18 +228,9 @@ protected boolean hasPendingRequests() { public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. writer.format( - "GetWorkStream: %d buffers, " - + "max budget: %s, " - + "in-flight budget: %s, " - + "total budget requested: %s, " - + "total budget received: %s," - + "last sent request: %s. ", - workItemAssemblers.size(), - budgetTracker.maxGetWorkBudget().get(), - budgetTracker.inFlightBudget(), - budgetTracker.totalRequestedBudget(), - budgetTracker.totalReceivedBudget(), - lastRequest.get()); + "GetWorkStream: %d buffers, " + "last sent request: %s; ", + workItemAssemblers.size(), lastRequest.get()); + writer.print(budgetTracker.debugString()); } @Override @@ -300,49 +290,57 @@ private void executeSafely(Runnable runnable) { * extensions. */ @ThreadSafe - @AutoValue - abstract static class GetWorkBudgetTracker { - - private static GetWorkBudgetTracker create(GetWorkBudget initialMaxGetWorkBudget) { - return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker( - new AtomicReference<>(initialMaxGetWorkBudget), - new AtomicLong(), - new AtomicLong(), - new AtomicLong(), - new AtomicLong()); - } + private static final class GetWorkBudgetTracker { + + @GuardedBy("GetWorkBudgetTracker.this") + private GetWorkBudget maxGetWorkBudget; - abstract AtomicReference maxGetWorkBudget(); + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsRequested = 0; - abstract AtomicLong itemsRequested(); + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesRequested = 0; - abstract AtomicLong bytesRequested(); + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsReceived = 0; - abstract AtomicLong itemsReceived(); + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesReceived = 0; - abstract AtomicLong bytesReceived(); + private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) { + this.maxGetWorkBudget = maxGetWorkBudget; + } private synchronized void reset() { - itemsRequested().set(0); - bytesRequested().set(0); - itemsReceived().set(0); - bytesReceived().set(0); + itemsRequested = 0; + bytesRequested = 0; + itemsReceived = 0; + bytesReceived = 0; + } + + private synchronized String debugString() { + return String.format( + "max budget: %s; " + + "in-flight budget: %s; " + + "total budget requested: %s; " + + "total budget received: %s.", + maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget()); } /** Consumes the new budget and computes an extension based on the new budget. */ private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { - maxGetWorkBudget().set(newBudget); + maxGetWorkBudget = newBudget; return computeBudgetExtension(); } private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { - itemsRequested().addAndGet(budgetRequested.items()); - bytesRequested().addAndGet(budgetRequested.bytes()); + itemsRequested += budgetRequested.items(); + bytesRequested += budgetRequested.bytes(); } - private synchronized void recordBudgetReceived(long bytesReceived) { - itemsReceived().incrementAndGet(); - bytesReceived().addAndGet(bytesReceived); + private synchronized void recordBudgetReceived(long returnedBudget) { + itemsReceived++; + bytesReceived += returnedBudget; } /** @@ -351,11 +349,10 @@ private synchronized void recordBudgetReceived(long bytesReceived) { * without sending too many extension requests. */ private synchronized GetWorkBudget computeBudgetExtension() { - GetWorkBudget maxGetWorkBudget = maxGetWorkBudget().get(); // Expected items and bytes can go negative here, since WorkItems returned might be larger // than the initially requested budget. - long inFlightItems = itemsRequested().get() - itemsReceived().get(); - long inFlightBytes = bytesRequested().get() - bytesReceived().get(); + long inFlightItems = itemsRequested - itemsReceived; + long inFlightBytes = bytesRequested - bytesReceived; // Don't send negative budget extensions. long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes); @@ -366,25 +363,19 @@ private synchronized GetWorkBudget computeBudgetExtension() { : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build(); } - private GetWorkBudget inFlightBudget() { + private synchronized GetWorkBudget inFlightBudget() { return GetWorkBudget.builder() - .setItems(itemsRequested().get() - itemsReceived().get()) - .setBytes(bytesRequested().get() - bytesReceived().get()) + .setItems(itemsRequested - itemsReceived) + .setBytes(bytesRequested - bytesReceived) .build(); } - private GetWorkBudget totalRequestedBudget() { - return GetWorkBudget.builder() - .setItems(itemsRequested().get()) - .setBytes(bytesRequested().get()) - .build(); + private synchronized GetWorkBudget totalRequestedBudget() { + return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build(); } - private GetWorkBudget totalReceivedBudget() { - return GetWorkBudget.builder() - .setItems(itemsReceived().get()) - .setBytes(bytesReceived().get()) - .build(); + private synchronized GetWorkBudget totalReceivedBudget() { + return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build(); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index e78c2c685ffe..8a1ba2556cf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -43,11 +43,11 @@ public void distributeBudget( return; } - GetWorkBudget budgetPerStream = computeDesiredBudgets(budgetSpenders, getWorkBudget); + GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget); budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream)); } - private GetWorkBudget computeDesiredBudgets( + private GetWorkBudget computeDesiredPerStreamBudget( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { return GetWorkBudget.builder() .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index fad8c75ad838..0092fcc7bcd1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -214,7 +214,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); @@ -286,7 +286,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = @@ -334,19 +334,14 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); getWorkBudgetDistributor.expectNumDistributions(1); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } - private void waitForWorkerMetadataToBeConsumed( - TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { - getWorkBudgetDistributor.waitForBudgetDistribution(); - } - private static class GetWorkerMetadataTestStub extends CloudWindmillMetadataServiceV1Alpha1Grpc .CloudWindmillMetadataServiceV1Alpha1ImplBase { @@ -395,9 +390,8 @@ private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } - @SuppressWarnings("ReturnValueIgnored") - private void waitForBudgetDistribution() throws InterruptedException { - getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + private boolean waitForBudgetDistribution() throws InterruptedException { + return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); } private void expectNumDistributions(int numBudgetDistributionsExpected) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java index 9c53763e7694..fd2b30238836 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -20,10 +20,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -58,8 +60,11 @@ @RunWith(JUnit4.class) public class GrpcDirectGetWorkStreamTest { + private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER = + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}; private static final Windmill.JobHeader TEST_JOB_HEADER = Windmill.JobHeader.newBuilder() + .setClientId(1L) .setJobId("test_job") .setWorkerId("test_worker") .setProjectId("test_project") @@ -78,6 +83,22 @@ private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget .build(); } + private static void assertHeader( + Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) { + assertTrue(getWorkRequest.hasRequest()); + assertFalse(getWorkRequest.hasRequestExtension()); + assertThat(getWorkRequest.getRequest()) + .isEqualTo( + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(expectedInitialBudget.items()) + .setMaxBytes(expectedInitialBudget.bytes()) + .build()); + } + @Before public void setUp() throws IOException { Server server = @@ -100,30 +121,6 @@ public void cleanUp() { checkNotNull(stream).shutdown(); } - private GrpcDirectGetWorkStream createGetWorkStream( - GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { - return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); - } - - private GrpcDirectGetWorkStream createGetWorkStream( - GetWorkStreamTestStub testStub, - GetWorkBudget initialGetWorkBudget, - WorkItemScheduler workItemScheduler) { - return createGetWorkStream( - testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); - } - - private GrpcDirectGetWorkStream createGetWorkStream( - GetWorkStreamTestStub testStub, - GetWorkBudget initialGetWorkBudget, - ThrottleTimer throttleTimer) { - return createGetWorkStream( - testStub, - initialGetWorkBudget, - throttleTimer, - (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); - } - private GrpcDirectGetWorkStream createGetWorkStream( GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget, @@ -173,14 +170,17 @@ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); - stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); stream.setBudget(newBudget); - waitForRequests.await(5, TimeUnit.SECONDS); + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); // Header and extension. assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget()); assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) .isEqualTo(extension(newBudget)); } @@ -193,16 +193,19 @@ public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); - stream = createGetWorkStream(testStub, initialBudget); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); stream.setBudget(newBudget); GetWorkBudget diff = newBudget.subtract(initialBudget); - waitForRequests.await(5, TimeUnit.SECONDS); + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); List requests = requestObserver.sent(); // Header and extension. assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); } @@ -213,19 +216,19 @@ public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); stream = createGetWorkStream( - testStub, - GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); - waitForRequests.await(5, TimeUnit.SECONDS); + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); List requests = requestObserver.sent(); // Assert that the extension was never sent, only the header. assertThat(requests).hasSize(expectedRequests); - assertThat(Iterables.getOnlyElement(requests).getRequest()) - .isInstanceOf(Windmill.GetWorkRequest.class); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); } @Test @@ -234,18 +237,19 @@ public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedExcept CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); - stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); stream.shutdown(); stream.setBudget( GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); - waitForRequests.await(5, TimeUnit.SECONDS); + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); List requests = requestObserver.sent(); // Assert that the extension was never sent, only the header. assertThat(requests).hasSize(1); - assertThat(Iterables.getOnlyElement(requests).getRequest()) - .isInstanceOf(Windmill.GetWorkRequest.class); + assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget()); } @Test @@ -260,6 +264,7 @@ public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws Inter createGetWorkStream( testStub, initialBudget, + new ThrottleTimer(), (work, watermarks, processingContext, getWorkStreamLatencies) -> { scheduledWorkItems.add(work); }); @@ -273,13 +278,14 @@ public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws Inter testStub.injectResponse(createResponse(workItem)); - waitForRequests.await(5, TimeUnit.SECONDS); + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); assertThat(scheduledWorkItems).containsExactly(workItem); List requests = requestObserver.sent(); long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); assertThat(Iterables.getLast(requests).getRequestExtension()) .isEqualTo( extension( @@ -297,13 +303,15 @@ public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); Set scheduledWorkItems = new HashSet<>(); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); stream = createGetWorkStream( testStub, - GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(), - (work, watermarks, processingContext, getWorkStreamLatencies) -> { - scheduledWorkItems.add(work); - }); + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> + scheduledWorkItems.add(work)); Windmill.WorkItem workItem = Windmill.WorkItem.newBuilder() .setKey(ByteString.copyFromUtf8("somewhat_long_key")) @@ -314,15 +322,14 @@ public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() testStub.injectResponse(createResponse(workItem)); - waitForRequests.await(5, TimeUnit.SECONDS); + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); assertThat(scheduledWorkItems).containsExactly(workItem); List requests = requestObserver.sent(); // Assert that the extension was never sent, only the header. assertThat(requests).hasSize(expectedRequests); - assertThat(Iterables.getOnlyElement(requests).getRequest()) - .isInstanceOf(Windmill.GetWorkRequest.class); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); } @Test @@ -331,8 +338,11 @@ public void testOnResponse_stopsThrottling() { TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(new CountDownLatch(1)); GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); - stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), throttleTimer); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER); stream.startThrottleTimer(); + assertTrue(throttleTimer.throttled()); testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); assertFalse(throttleTimer.throttled()); } @@ -364,9 +374,11 @@ private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk private static class TestGetWorkRequestObserver implements StreamObserver { - private final List requests = new ArrayList<>(); + private final List requests = + Collections.synchronizedList(new ArrayList<>()); private final CountDownLatch waitForRequests; - private @Nullable StreamObserver responseObserver; + private @Nullable volatile StreamObserver + responseObserver; public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { this.waitForRequests = waitForRequests; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index b966240cfc4a..c76d5a584184 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -57,7 +57,7 @@ public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = spy(createGetWorkBudgetOwner()); + GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner(); GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); @@ -65,46 +65,41 @@ public void testDistributeBudget_doesNothingWithNoBudget() { @Test public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { - long totalItemsAndBytes = 10L; + int totalStreams = 10; + long totalItems = 10L; + long totalBytes = 100L; List streams = new ArrayList<>(); - for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwner())); + for (int i = 0; i < totalStreams; i++) { + streams.add(createGetWorkBudgetOwner()); } GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + .setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build()))); } @Test public void testDistributeBudget_distributesFairlyWhenNotEven() { - long totalItemsAndBytes = 10L; + long totalItems = 10L; + long totalBytes = 19L; List streams = new ArrayList<>(); for (int i = 0; i < 3; i++) { - streams.add(spy(createGetWorkBudgetOwner())); + streams.add(createGetWorkBudgetOwner()); } GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0)); streams.forEach( stream -> verify(stream, times(1)) - .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + .setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build()))); } @Test @@ -112,7 +107,7 @@ public void testDistributeBudget_distributesBudgetEvenly() { long totalItemsAndBytes = 10L; List streams = new ArrayList<>(); for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwner())); + streams.add(createGetWorkBudgetOwner()); } GetWorkBudgetDistributors.distributeEvenly()