diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 5cb5bd04dbe5..683f94eb71ee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -72,6 +72,9 @@ public final class GrpcDirectGetWorkStream .build()) .build(); + private final AtomicReference inFlightBudget; + private final AtomicReference nextBudgetAdjustment; + private final AtomicReference pendingResponseBudget; private final GetWorkRequest request; private final WorkItemProcessor workItemProcessorFn; private final ThrottleTimer getWorkThrottleTimer; @@ -84,10 +87,6 @@ public final class GrpcDirectGetWorkStream */ private final ConcurrentMap workItemBuffers; - private final AtomicReference inflightBudget; - private final AtomicReference nextBudgetAdjustment; - private final AtomicReference pendingResponseBudget; - private GrpcDirectGetWorkStream( Function< StreamObserver, @@ -112,8 +111,7 @@ private GrpcDirectGetWorkStream( // stream. this.getDataStream = Suppliers.memoize(getDataStream::get); this.commitWorkStream = Suppliers.memoize(commitWorkStream::get); - - this.inflightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); } @@ -178,10 +176,11 @@ private void sendRequestExtension() { @Override protected synchronized void onNewStream() { workItemBuffers.clear(); - // Add the current inflight budget to the next adjustment. Only positive values are allowed here + // Add the current in-flight budget to the next adjustment. Only positive values are allowed + // here // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. - GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inflightBudget.get()); - inflightBudget.set(budgetAdjustment); + GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); + inFlightBudget.set(budgetAdjustment); send( StreamingGetWorkRequest.newBuilder() .setRequest( @@ -205,7 +204,7 @@ public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. writer.format( "GetWorkStream: %d buffers, %s inflight budget allowed.", - workItemBuffers.size(), inflightBudget.get()); + workItemBuffers.size(), inFlightBudget.get()); } @Override @@ -224,7 +223,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) { if (chunk.getRemainingBytesForWorkItem() == 0) { workItemBuffer.runAndReset(); // Record the fact that there are now fewer outstanding messages and bytes on the stream. - inflightBudget.updateAndGet(budget -> budget.subtract(1, workItemBuffer.bufferedSize())); + inFlightBudget.updateAndGet(budget -> budget.subtract(1, workItemBuffer.bufferedSize())); } } @@ -244,18 +243,11 @@ public GetWorkBudget remainingBudget() { // Snapshot the current budgets. GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inflightBudget.get(); + GetWorkBudget currentInflightBudget = inFlightBudget.get(); - return GetWorkBudget.builder() - .setItems( - currentNextBudgetAdjustment.items() - + currentPendingResponseBudget.items() - + currentInflightBudget.items()) - .setBytes( - currentNextBudgetAdjustment.bytes() - + currentPendingResponseBudget.bytes() - + currentInflightBudget.bytes()) - .build(); + return currentPendingResponseBudget + .apply(currentNextBudgetAdjustment) + .apply(currentInflightBudget); } private synchronized void updatePendingResponseBudget(long itemsDelta, long bytesDelta) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java index 80f0bef6306b..8a2c643a5b76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java @@ -196,7 +196,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - Thread streamingEngineClientThread = new Thread(this::waitForFirstWorkerMetadata); WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) @@ -205,7 +204,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { .putAllGlobalDataEndpoints(DEFAULT) .build(); - streamingEngineClientThread.start(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(1); @@ -240,9 +238,6 @@ public void testScheduledBudgetRefresh() throws InterruptedException { newStreamingEngineClient( GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), noOpProcessWorkItemFn()); - Thread streamingEngineClientThread = new Thread(this::waitForFirstWorkerMetadata); - - streamingEngineClientThread.start(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata( WorkerMetadataResponse.newBuilder() @@ -266,7 +261,6 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() String workerToken2 = "workerToken2"; String workerToken3 = "workerToken3"; - Thread streamingEngineClientThread = new Thread(this::waitForFirstWorkerMetadata); WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) @@ -284,7 +278,6 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() .putAllGlobalDataEndpoints(DEFAULT) .build(); - streamingEngineClientThread.start(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); @@ -314,7 +307,6 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce String workerToken2 = "workerToken2"; String workerToken3 = "workerToken3"; - Thread streamingEngineClientThread = new Thread(this::waitForFirstWorkerMetadata); WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) @@ -337,7 +329,6 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .putAllGlobalDataEndpoints(DEFAULT) .build(); - streamingEngineClientThread.start(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); Thread.sleep(50); @@ -366,10 +357,6 @@ private StreamingEngineConnectionState waitForWorkerMetadataToBeConsumed( return connections.get(); } - private void waitForFirstWorkerMetadata() { - while (!Preconditions.checkNotNull(streamingEngineClient).isWorkerMetadataReady()) {} - } - private static class GetWorkerMetadataTestStub extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { private static final WorkerMetadataResponse CLOSE_ALL_STREAMS =