diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 8501a41410f6..511fe5617dff 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -30,6 +30,7 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -198,7 +199,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( return fanOutStreamingEngineWorkProvider; } - @SuppressWarnings("ReturnValueIgnored") @Override public synchronized void start() { Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); @@ -234,9 +234,29 @@ private GetDataStream getGlobalDataStream(String globalDataKey) { @Override public synchronized void shutdown() { Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); - Preconditions.checkNotNull(getWorkerMetadataStream).halfClose(); + Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); workerMetadataConsumer.shutdownNow(); + closeStreamsNotIn(WindmillEndpoints.none()); channelCachingStubFactory.shutdown(); + + try { + Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e); + } + + windmillStreamManager.shutdown(); + boolean isStreamManagerShutdown = false; + try { + isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e); + } + if (!isStreamManagerShutdown) { + windmillStreamManager.shutdownNow(); + } } private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) { @@ -265,7 +285,7 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi newWindmillEndpoints, activeMetadataVersion, newWindmillEndpoints.version()); - closeStaleStreams(newWindmillEndpoints); + closeStreamsNotIn(newWindmillEndpoints); ImmutableMap newStreams = createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join(); StreamingEngineBackends newBackends = @@ -280,29 +300,30 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi } /** Close the streams that are no longer valid asynchronously. */ - @SuppressWarnings("FutureReturnValueIgnored") - private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) { + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { StreamingEngineBackends currentBackends = backends.get(); - ImmutableMap currentWindmillStreams = - currentBackends.windmillStreams(); - currentWindmillStreams.entrySet().stream() + currentBackends.windmillStreams().entrySet().stream() .filter( connectionAndStream -> !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) .forEach( - entry -> - CompletableFuture.runAsync( - () -> closeStreamSender(entry.getKey(), entry.getValue()), - windmillStreamManager)); + entry -> { + CompletableFuture ignored = + CompletableFuture.runAsync( + () -> closeStreamSender(entry.getKey(), entry.getValue()), + windmillStreamManager); + }); Set newGlobalDataEndpoints = new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); currentBackends.globalDataStreams().values().stream() .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) .forEach( - sender -> - CompletableFuture.runAsync( - () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)); + sender -> { + CompletableFuture ignored = + CompletableFuture.runAsync( + () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager); + }); } private void closeStreamSender(Endpoint endpoint, Closeable sender) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index 8c1ceada741a..eb269eef848f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -41,6 +41,14 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + public static WindmillEndpoints none() { + return WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); + } + public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { ImmutableMap globalDataServers = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index dc8d4e633982..f26c56b14ec2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -56,10 +56,10 @@ public interface WindmillStream { @ThreadSafe interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ - void setBudget(long newItems, long newBytes); + void setBudget(GetWorkBudget newBudget); - default void setBudget(GetWorkBudget newBudget) { - setBudget(newBudget.items(), newBudget.bytes()); + default void setBudget(long newItems, long newBytes) { + setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index fc64da8b2e41..749e689d93c9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -26,6 +26,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -70,7 +71,6 @@ final class GrpcDirectGetWorkStream .build()) .build(); - private final AtomicReference maxGetWorkBudget; private final GetWorkBudgetTracker budgetTracker; private final GetWorkRequest requestHeader; private final WorkItemScheduler workItemScheduler; @@ -120,14 +120,13 @@ private GrpcDirectGetWorkStream( this.heartbeatSender = heartbeatSender; this.workCommitter = workCommitter; this.getDataClient = getDataClient; - this.maxGetWorkBudget = - new AtomicReference<>( + this.lastRequest = new AtomicReference<>(); + this.budgetTracker = + GetWorkBudgetTracker.create( GetWorkBudget.builder() .setItems(requestHeader.getMaxItems()) .setBytes(requestHeader.getMaxBytes()) .build()); - this.lastRequest = new AtomicReference<>(); - this.budgetTracker = GetWorkBudgetTracker.create(); } static GrpcDirectGetWorkStream create( @@ -146,19 +145,22 @@ static GrpcDirectGetWorkStream create( GetDataClient getDataClient, WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { - return new GrpcDirectGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - heartbeatSender, - getDataClient, - workCommitter, - workItemScheduler); + GrpcDirectGetWorkStream getWorkStream = + new GrpcDirectGetWorkStream( + backendWorkerToken, + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + heartbeatSender, + getDataClient, + workCommitter, + workItemScheduler); + getWorkStream.startStream(); + return getWorkStream; } private static Watermarks createWatermarks( @@ -188,7 +190,11 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { .build(); lastRequest.set(request); budgetTracker.recordBudgetRequested(extension); - send(request); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was closed. + } }); } } @@ -198,8 +204,7 @@ protected synchronized void onNewStream() { workItemAssemblers.clear(); if (!isShutdown()) { budgetTracker.reset(); - GetWorkBudget initialGetWorkBudget = - budgetTracker.computeBudgetExtension(maxGetWorkBudget.get()); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); StreamingGetWorkRequest request = StreamingGetWorkRequest.newBuilder() .setRequest( @@ -231,7 +236,7 @@ public void appendSpecificHtml(PrintWriter writer) { + "total budget received: %s," + "last sent request: %s. ", workItemAssemblers.size(), - maxGetWorkBudget.get(), + budgetTracker.maxGetWorkBudget().get(), budgetTracker.inFlightBudget(), budgetTracker.totalRequestedBudget(), budgetTracker.totalReceivedBudget(), @@ -262,7 +267,7 @@ private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { createProcessingContext(metadata.computationId()), assembledWorkItem.latencyAttributions()); budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize()); - GetWorkBudget extension = budgetTracker.computeBudgetExtension(maxGetWorkBudget.get()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(); maybeSendRequestExtension(extension); } @@ -277,26 +282,38 @@ protected void startThrottleTimer() { } @Override - public void setBudget(long newItems, long newBytes) { - GetWorkBudget currentMaxGetWorkBudget = - maxGetWorkBudget.updateAndGet( - ignored -> GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); - GetWorkBudget extension = budgetTracker.computeBudgetExtension(currentMaxGetWorkBudget); + public void setBudget(GetWorkBudget newBudget) { + GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget); maybeSendRequestExtension(extension); } + private void executeSafely(Runnable runnable) { + try { + executor().execute(runnable); + } catch (RejectedExecutionException e) { + LOG.debug("{} has been shutdown.", getClass()); + } + } + /** - * Tracks sent and received GetWorkBudget and uses this information to generate request + * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request * extensions. */ + @ThreadSafe @AutoValue abstract static class GetWorkBudgetTracker { - private static GetWorkBudgetTracker create() { + private static GetWorkBudgetTracker create(GetWorkBudget initialMaxGetWorkBudget) { return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker( - new AtomicLong(), new AtomicLong(), new AtomicLong(), new AtomicLong()); + new AtomicReference<>(initialMaxGetWorkBudget), + new AtomicLong(), + new AtomicLong(), + new AtomicLong(), + new AtomicLong()); } + abstract AtomicReference maxGetWorkBudget(); + abstract AtomicLong itemsRequested(); abstract AtomicLong bytesRequested(); @@ -305,19 +322,25 @@ private static GetWorkBudgetTracker create() { abstract AtomicLong bytesReceived(); - private void reset() { + private synchronized void reset() { itemsRequested().set(0); bytesRequested().set(0); itemsReceived().set(0); bytesReceived().set(0); } - private void recordBudgetRequested(GetWorkBudget budgetRequested) { + /** Consumes the new budget and computes an extension based on the new budget. */ + private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { + maxGetWorkBudget().set(newBudget); + return computeBudgetExtension(); + } + + private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { itemsRequested().addAndGet(budgetRequested.items()); bytesRequested().addAndGet(budgetRequested.bytes()); } - private void recordBudgetReceived(long bytesReceived) { + private synchronized void recordBudgetReceived(long bytesReceived) { itemsReceived().incrementAndGet(); bytesReceived().addAndGet(bytesReceived); } @@ -327,7 +350,8 @@ private void recordBudgetReceived(long bytesReceived) { * GetWorkExtension. The goal is to keep the limits relatively close to their maximum values * without sending too many extension requests. */ - private GetWorkBudget computeBudgetExtension(GetWorkBudget maxGetWorkBudget) { + private synchronized GetWorkBudget computeBudgetExtension() { + GetWorkBudget maxGetWorkBudget = maxGetWorkBudget().get(); // Expected items and bytes can go negative here, since WorkItems returned might be larger // than the initially requested budget. long inFlightItems = itemsRequested().get() - itemsReceived().get(); @@ -363,14 +387,4 @@ private GetWorkBudget totalReceivedBudget() { .build(); } } - - private void executeSafely(Runnable runnable) { - try { - executor().execute(runnable); - } catch (RejectedExecutionException e) { - LOG.debug("{} has been shutdown.", getClass()); - } catch (IllegalStateException e) { - // Stream was closed. - } - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 0e9a0c6316ee..c99e05a77074 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -59,7 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcGetDataStream +final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index dcfadbbde136..a368f3fec235 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -33,6 +33,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; @@ -193,7 +194,7 @@ protected void startThrottleTimer() { } @Override - public void setBudget(long newItems, long newBytes) { + public void setBudget(GetWorkBudget newBudget) { // no-op } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 25587811f9be..f0ea2f550a74 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -64,7 +64,7 @@ public static ManagedChannel remoteChannel( windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + " WindmillServiceAddresses."); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index 38419e0fb034..e78c2c685ffe 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -49,8 +49,6 @@ public void distributeBudget( private GetWorkBudget computeDesiredBudgets( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { - // TODO: Fix possibly non-deterministic handing out of budgets. - // Rounding up here will drift upwards over the lifetime of the streams. return GetWorkBudget.builder() .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING)) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index e50083a66818..90ffb3d3fbcf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -66,6 +66,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; @@ -244,7 +245,7 @@ public void halfClose() { } @Override - public void setBudget(long newItems, long newBytes) { + public void setBudget(GetWorkBudget newBudget) { // no-op. } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java new file mode 100644 index 000000000000..9c53763e7694 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) { + return createGetWorkStream(testStub, initialGetWorkBudget, new ThrottleTimer()); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + WorkItemScheduler workItemScheduler) { + return createGetWorkStream( + testStub, initialGetWorkBudget, new ThrottleTimer(), workItemScheduler); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer) { + return createGetWorkStream( + testStub, + initialGetWorkBudget, + throttleTimer, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = createGetWorkStream(testStub, initialBudget); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List requests = requestObserver.sent(); + // Header and extension. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget()); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + waitForRequests.await(5, TimeUnit.SECONDS); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(1); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + waitForRequests.await(5, TimeUnit.SECONDS); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertThat(Iterables.getOnlyElement(requests).getRequest()) + .isInstanceOf(Windmill.GetWorkRequest.class); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(), throttleTimer); + stream.startThrottleTimer(); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) { + checkNotNull(responseObserver).onNext(responseChunk); + } + } + + private static class TestGetWorkRequestObserver + implements StreamObserver { + private final List requests = new ArrayList<>(); + private final CountDownLatch waitForRequests; + private @Nullable StreamObserver responseObserver; + + public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { + this.waitForRequests = waitForRequests; + } + + @Override + public void onNext(Windmill.StreamingGetWorkRequest request) { + requests.add(request); + waitForRequests.countDown(); + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + + List sent() { + return requests; + } + } +}