diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index ec5122a8732a..a12a5075c5ee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -26,8 +26,12 @@ public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } - public WorkItemCancelledException(Throwable e) { - super(e); + public WorkItemCancelledException(String message, Throwable cause) { + super(message, cause); + } + + public WorkItemCancelledException(Throwable cause) { + super(cause); } /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. 
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 458cf57ca8e7..115142f98b9c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,7 +20,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.io.Closeable; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; import java.util.Map.Entry; import java.util.NoSuchElementException; @@ -34,6 +34,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -65,6 +66,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; @@ -112,7 +114,7 
@@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker private boolean started; @GuardedBy("this") - private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream = null; private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, @@ -143,7 +145,6 @@ private FanOutStreamingEngineWorkerHarness( this.totalGetWorkBudget = totalGetWorkBudget; this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - this.getWorkerMetadataStream = null; } /** @@ -204,9 +205,10 @@ public synchronized void start() { Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); getWorkerMetadataStream = streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), + dispatcherClient::getWindmillMetadataServiceStubBlocking, getWorkerMetadataThrottleTimer, this::consumeWorkerMetadata); + getWorkerMetadataStream.start(); started = true; } @@ -225,7 +227,7 @@ public ImmutableSet currentWindmillEndpoints() { */ private GetDataStream getGlobalDataStream(String globalDataKey) { return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) - .map(GlobalDataStreamSender::get) + .map(GlobalDataStreamSender::stream) .orElseThrow( () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @@ -236,7 +238,8 @@ public synchronized void shutdown() { Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); workerMetadataConsumer.shutdownNow(); - closeStreamsNotIn(WindmillEndpoints.none()); + // Close all the streams blocking until this completes to not leak resources. 
+ closeStreamsNotIn(WindmillEndpoints.none()).join(); channelCachingStubFactory.shutdown(); try { @@ -300,27 +303,38 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi } /** Close the streams that are no longer valid asynchronously. */ - private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + @CanIgnoreReturnValue + private CompletableFuture closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { StreamingEngineBackends currentBackends = backends.get(); - currentBackends.windmillStreams().entrySet().stream() - .filter( - connectionAndStream -> - !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) - .forEach( - entry -> - windmillStreamManager.execute( - () -> closeStreamSender(entry.getKey(), entry.getValue()))); + Stream> closeStreamFutures = + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints + .windmillEndpoints() + .contains(connectionAndStream.getKey())) + .map( + entry -> + CompletableFuture.runAsync( + () -> closeStreamSender(entry.getKey(), entry.getValue()), + windmillStreamManager)); Set newGlobalDataEndpoints = new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); - currentBackends.globalDataStreams().values().stream() - .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) - .forEach( - sender -> - windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); + Stream> closeGlobalDataStreamFutures = + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .map( + sender -> + CompletableFuture.runAsync( + () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)); + + return CompletableFuture.allOf( + Streams.concat(closeStreamFutures, closeGlobalDataStreamFutures) + .toArray(CompletableFuture[]::new)); } - private void closeStreamSender(Endpoint endpoint, Closeable 
sender) { + private void closeStreamSender(Endpoint endpoint, StreamSender sender) { LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); try { sender.close(); @@ -346,13 +360,14 @@ private void closeStreamSender(Endpoint endpoint, Closeable sender) { private CompletionStage> getOrCreateWindmillStreamSenderFuture( Endpoint endpoint, ImmutableMap currentStreams) { - return MoreFutures.supplyAsync( - () -> - Pair.of( - endpoint, - Optional.ofNullable(currentStreams.get(endpoint)) - .orElseGet(() -> createAndStartWindmillStreamSender(endpoint))), - windmillStreamManager); + return Optional.ofNullable(currentStreams.get(endpoint)) + .map(backend -> CompletableFuture.completedFuture(Pair.of(endpoint, backend))) + .orElseGet( + () -> + MoreFutures.supplyAsync( + () -> Pair.of(endpoint, createAndStartWindmillStreamSender(endpoint)), + windmillStreamManager) + .toCompletableFuture()); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. */ @@ -393,9 +408,8 @@ private GlobalDataStreamSender getOrCreateGlobalDataSteam( .orElseGet( () -> new GlobalDataStreamSender( - () -> - streamFactory.createGetDataStream( - createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), keyedEndpoint.getValue())); } @@ -416,7 +430,7 @@ private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoin StreamGetDataClient.create( getDataStream, this::getGlobalDataStream, getDataMetricTracker), workCommitterFactory); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); return windmillStreamSender; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java 
index ce5f3a7b6bfc..d590e69c17d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -17,44 +17,45 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; -import java.io.Closeable; -import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; @Internal @ThreadSafe -// TODO (m-trieu): replace Supplier with Stream after github.com/apache/beam/pull/32774/ is -// merged -final class GlobalDataStreamSender implements Closeable, Supplier { +final class GlobalDataStreamSender implements StreamSender { private final Endpoint endpoint; - private final Supplier delegate; + private final GetDataStream delegate; private volatile boolean started; - GlobalDataStreamSender(Supplier delegate, Endpoint endpoint) { - // Ensures that the Supplier is thread-safe - this.delegate = Suppliers.memoize(delegate::get); + GlobalDataStreamSender(GetDataStream delegate, Endpoint endpoint) { + this.delegate = delegate; this.started = false; this.endpoint = endpoint; } - @Override - public GetDataStream get() { + GetDataStream stream() { if (!started) { - started = true; + // Starting the stream possibly perform IO. Start the stream lazily since not all pipeline + // implementations need to fetch global/side input data. + startStream(); } - return delegate.get(); + return delegate; + } + + private synchronized void startStream() { + // Check started again after we acquire the lock. 
+ if (!started) { + delegate.start(); + started = true; + } } @Override public void close() { - if (started) { - delegate.get().shutdown(); - } + delegate.shutdown(); } Endpoint endpoint() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java index bc93e6d89c41..22fba91e170a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java @@ -33,7 +33,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.RpcException; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillRpcException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; @@ -199,7 +199,7 @@ private void applianceDispatchLoop(Supplier getWorkFn) if (workResponse.getWorkCount() > 0) { break; } - } catch (RpcException e) { + } catch (WindmillRpcException e) { LOG.warn("GetWork failed, retrying:", e); } sleepUninterruptibly(backoff, TimeUnit.MILLISECONDS); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java new file mode 100644 index 000000000000..40a63571620f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +interface StreamSender { + void close(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 744c3d74445f..2a2f49dff846 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,11 +17,14 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; -import java.io.Closeable; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; @@ -37,20 +40,13 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetSpender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.FixedStreamHeartbeatSender; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; /** * Owns and maintains a set of streams used to communicate with a 
specific Windmill worker. - * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is - * called, a stream that is already started is returned. - * - *

Holds references to {@link - * Supplier} because - * initializing the streams automatically start them, and we want to do so lazily here once the - * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link - * #close()} ()}. + * #close()}. * *

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -60,14 +56,16 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { +final class WindmillStreamSender implements GetWorkBudgetSpender, StreamSender { + private static final String STREAM_STARTER_THREAD_NAME = "StartWindmillStreamThread-%d"; private final AtomicBoolean started; private final AtomicReference getWorkBudget; - private final Supplier getWorkStream; - private final Supplier getDataStream; - private final Supplier commitWorkStream; - private final Supplier workCommitter; + private final GetWorkStream getWorkStream; + private final GetDataStream getDataStream; + private final CommitWorkStream commitWorkStream; + private final WorkCommitter workCommitter; private final StreamingEngineThrottleTimers streamingEngineThrottleTimers; + private final ExecutorService streamStarter; private WindmillStreamSender( WindmillConnection connection, @@ -81,33 +79,28 @@ private WindmillStreamSender( this.getWorkBudget = getWorkBudget; this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create(); - // All streams are memoized/cached since they are expensive to create and some implementations - // perform side effects on construction (i.e. sending initial requests to the stream server to - // initiate the streaming RPC connection). Stream instances connect/reconnect internally, so we - // can reuse the same instance through the entire lifecycle of WindmillStreamSender. + // Stream instances connect/reconnect internally, so we can reuse the same instance through the + // entire lifecycle of WindmillStreamSender. 
this.getDataStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createGetDataStream( - connection.stub(), streamingEngineThrottleTimers.getDataThrottleTimer())); + streamingEngineStreamFactory.createDirectGetDataStream( + connection, streamingEngineThrottleTimers.getDataThrottleTimer()); this.commitWorkStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createCommitWorkStream( - connection.stub(), streamingEngineThrottleTimers.commitWorkThrottleTimer())); - this.workCommitter = - Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get())); + streamingEngineStreamFactory.createDirectCommitWorkStream( + connection, streamingEngineThrottleTimers.commitWorkThrottleTimer()); + this.workCommitter = workCommitterFactory.apply(commitWorkStream); this.getWorkStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createDirectGetWorkStream( - connection, - withRequestBudget(getWorkRequest, getWorkBudget.get()), - streamingEngineThrottleTimers.getWorkThrottleTimer(), - FixedStreamHeartbeatSender.create(getDataStream.get()), - getDataClientFactory.apply(getDataStream.get()), - workCommitter.get(), - workItemScheduler)); + streamingEngineStreamFactory.createDirectGetWorkStream( + connection, + withRequestBudget(getWorkRequest, getWorkBudget.get()), + streamingEngineThrottleTimers.getWorkThrottleTimer(), + FixedStreamHeartbeatSender.create(getDataStream), + getDataClientFactory.apply(getDataStream), + workCommitter, + workItemScheduler); + // 3 threads, 1 for each stream type (GetWork, GetData, CommitWork). 
+ this.streamStarter = + Executors.newFixedThreadPool( + 3, new ThreadFactoryBuilder().setNameFormat(STREAM_STARTER_THREAD_NAME).build()); } static WindmillStreamSender create( @@ -132,34 +125,36 @@ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkB return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build(); } - @SuppressWarnings("ReturnValueIgnored") - void startStreams() { - getWorkStream.get(); - getDataStream.get(); - commitWorkStream.get(); - workCommitter.get().start(); - // *stream.get() is all memoized in a threadsafe manner. - started.set(true); + synchronized void start() { + if (!started.get()) { + checkState(!streamStarter.isShutdown(), "WindmillStreamSender has already been shutdown."); + + // Start these 3 streams in parallel since they each may perform blocking IO. + CompletableFuture.allOf( + CompletableFuture.runAsync(getWorkStream::start, streamStarter), + CompletableFuture.runAsync(getDataStream::start, streamStarter), + CompletableFuture.runAsync(commitWorkStream::start, streamStarter)) + .join(); + workCommitter.start(); + started.set(true); + } } @Override - public void close() { - // Supplier.get() starts the stream which is an expensive operation as it initiates the - // streaming RPCs by possibly making calls over the network. Do not close the streams unless - // they have already been started. 
- if (started.get()) { - getWorkStream.get().shutdown(); - getDataStream.get().shutdown(); - workCommitter.get().stop(); - commitWorkStream.get().shutdown(); - } + public synchronized void close() { + streamStarter.shutdownNow(); + getWorkStream.shutdown(); + getDataStream.shutdown(); + workCommitter.stop(); + commitWorkStream.shutdown(); } @Override public void setBudget(long items, long bytes) { - getWorkBudget.set(getWorkBudget.get().apply(items, bytes)); + GetWorkBudget budget = GetWorkBudget.builder().setItems(items).setBytes(bytes).build(); + getWorkBudget.set(budget); if (started.get()) { - getWorkStream.get().setBudget(items, bytes); + getWorkStream.setBudget(budget); } } @@ -168,6 +163,6 @@ long getAndResetThrottleTime() { } long getCurrentActiveCommitBytes() { - return started.get() ? workCommitter.get().currentActiveCommitBytes() : 0; + return workCommitter.currentActiveCommitBytes(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index eb269eef848f..13b3ea954198 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -40,13 +40,15 @@ @AutoValue public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + private static final WindmillEndpoints NO_ENDPOINTS = + WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); public static WindmillEndpoints none() { - return WindmillEndpoints.builder() - .setVersion(Long.MAX_VALUE) - 
.setWindmillEndpoints(ImmutableSet.of()) - .setGlobalDataEndpoints(ImmutableMap.of()) - .build(); + return NO_ENDPOINTS; } public static WindmillEndpoints from( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java index cd753cb8ec91..2ae97087fec7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java @@ -30,10 +30,16 @@ public abstract class WindmillServerStub @Override public void appendSummaryHtml(PrintWriter writer) {} - /** Generic Exception type for implementors to use to represent errors while making RPCs. */ - public static final class RpcException extends RuntimeException { - public RpcException(Throwable cause) { + /** + * Generic Exception type for implementors to use to represent errors while making Windmill RPCs. 
+ */ + public static final class WindmillRpcException extends RuntimeException { + public WindmillRpcException(Throwable cause) { super(cause); } + + public WindmillRpcException(String message, Throwable cause) { + super(message, cause); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 58aecfc71e00..8b48459eba94 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -17,30 +17,27 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.IOException; import java.io.PrintWriter; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; +import 
org.apache.beam.vendor.grpc.v1p60p1.com.google.api.client.util.Sleeper; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.StatusRuntimeException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Base class for persistent streams connecting to Windmill. @@ -49,46 +46,55 @@ * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on a * broken stream. * - *

Subclasses should override onResponse to handle responses from the server, and onNewStream to - * perform any work that must be done when a new stream is created, such as sending headers or - * retrying requests. + *

Subclasses should override {@link #onResponse(ResponseT)} to handle responses from the server, + * and {@link #onNewStream()} to perform any work that must be done when a new stream is created, + * such as sending headers or retrying requests. * - *

send and startStream should not be called from onResponse; use executor() instead. + *

{@link #trySend(RequestT)} and {@link #startStream()} should not be called from {@link + * #onResponse(ResponseT)}; use {@link #executeSafely(Runnable)} instead. * *

Synchronization on this is used to synchronize the gRpc stream state and internal data * structures. Since grpc channel operations may block, synchronization on this stream may also * block. This is generally not a problem since streams are used in a single-threaded manner. * However, some accessors used for status page and other debugging need to take care not to require * synchronizing on this. + * + *

{@link #start()} and {@link #shutdown()} are called once in the lifetime of the stream. Once + * {@link #shutdown()}, a stream in considered invalid and cannot be restarted/reused. */ public abstract class AbstractWindmillStream implements WindmillStream { - public static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; - private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class); - protected final AtomicBoolean clientClosed; - private final AtomicBoolean isShutdown; - private final AtomicLong lastSendTimeMs; - private final Executor executor; + // Indicates that the logical stream has been half-closed and is waiting for clean server + // shutdown. + private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); + private static final String NEVER_RECEIVED_RESPONSE_LOG_STRING = "never received response"; + private static final String NOT_SHUTDOWN = "not shutdown"; + protected final Sleeper sleeper; + + private final Logger logger; + private final ExecutorService executor; private final BackOff backoff; - private final AtomicLong startTimeMs; - private final AtomicLong lastResponseTimeMs; - private final AtomicInteger errorCount; - private final AtomicReference lastError; - private final AtomicReference lastErrorTime; - private final AtomicLong sleepUntil; private final CountDownLatch finishLatch; private final Set> streamRegistry; private final int logEveryNStreamFailures; - private final Supplier> requestObserverSupplier; - // Indicates if the current stream in requestObserver is closed by calling close() method - private final AtomicBoolean streamClosed; private final String backendWorkerToken; - private @Nullable StreamObserver requestObserver; + private final 
ResettableThrowingStreamObserver requestObserver; + private final StreamDebugMetrics debugMetrics; + + @GuardedBy("this") + protected boolean clientClosed; + + @GuardedBy("this") + protected boolean isShutdown; + + @GuardedBy("this") + private boolean started; protected AbstractWindmillStream( + Logger logger, String debugStreamType, Function, StreamObserver> clientFactory, BackOff backoff, @@ -106,21 +112,20 @@ protected AbstractWindmillStream( this.backoff = backoff; this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; - this.clientClosed = new AtomicBoolean(); - this.streamClosed = new AtomicBoolean(); - this.startTimeMs = new AtomicLong(); - this.lastSendTimeMs = new AtomicLong(); - this.lastResponseTimeMs = new AtomicLong(); - this.errorCount = new AtomicInteger(); - this.lastError = new AtomicReference<>(); - this.lastErrorTime = new AtomicReference<>(); - this.sleepUntil = new AtomicLong(); + this.clientClosed = false; + this.isShutdown = false; + this.started = false; this.finishLatch = new CountDownLatch(1); - this.isShutdown = new AtomicBoolean(false); - this.requestObserverSupplier = - () -> - streamObserverFactory.from( - clientFactory, new AbstractWindmillStream.ResponseObserver()); + this.logger = logger; + this.requestObserver = + new ResettableThrowingStreamObserver<>( + () -> + streamObserverFactory.from( + clientFactory, + new AbstractWindmillStream.ResponseObserver()), + logger); + this.sleeper = Sleeper.DEFAULT; + this.debugMetrics = StreamDebugMetrics.create(); } private static String createThreadName(String streamType, String backendWorkerToken) { @@ -129,18 +134,11 @@ private static String createThreadName(String streamType, String backendWorkerTo : String.format("%s-WindmillStream-thread", streamType); } - private static long debugDuration(long nowMs, long startMs) { - if (startMs <= 0) { - return -1; - } - return Math.max(0, nowMs - startMs); - } - /** Called on each response from the server. 
*/ protected abstract void onResponse(ResponseT response); /** Called when a new underlying stream to the server has been opened. */ - protected abstract void onNewStream(); + protected abstract void onNewStream() throws WindmillStreamShutdownException; /** Returns whether there are any pending requests that should be retried on a stream break. */ protected abstract boolean hasPendingRequests(); @@ -152,114 +150,161 @@ private static long debugDuration(long nowMs, long startMs) { */ protected abstract void startThrottleTimer(); - /** Reflects that {@link #shutdown()} was explicitly called. */ - protected boolean isShutdown() { - return isShutdown.get(); - } - - private StreamObserver requestObserver() { - if (requestObserver == null) { - throw new NullPointerException( - "requestObserver cannot be null. Missing a call to startStream() to initialize."); + /** Try to send a request to the server. Returns true if the request was successfully sent. */ + @CanIgnoreReturnValue + protected final synchronized boolean trySend(RequestT request) + throws WindmillStreamShutdownException { + debugMetrics.recordSend(); + try { + requestObserver.onNext(request); + return true; + } catch (ResettableThrowingStreamObserver.StreamClosedException e) { + // Stream was broken, requests may be retried when stream is reopened. } - return requestObserver; + return false; } - /** Send a request to the server. */ - protected final void send(RequestT request) { - lastSendTimeMs.set(Instant.now().getMillis()); + @Override + public final void start() { + boolean shouldStartStream = false; synchronized (this) { - if (streamClosed.get()) { - throw new IllegalStateException("Send called on a client closed stream."); + if (!isShutdown && !started) { + started = true; + shouldStartStream = true; } + } - requestObserver().onNext(request); + if (shouldStartStream) { + startStream(); } } /** Starts the underlying stream. 
*/ - protected final void startStream() { + private void startStream() { // Add the stream to the registry after it has been fully constructed. streamRegistry.add(this); while (true) { try { synchronized (this) { - startTimeMs.set(Instant.now().getMillis()); - lastResponseTimeMs.set(0); - streamClosed.set(false); - // lazily initialize the requestObserver. Gets reset whenever the stream is reopened. - requestObserver = requestObserverSupplier.get(); + debugMetrics.recordStart(); + requestObserver.reset(); onNewStream(); - if (clientClosed.get()) { + if (clientClosed) { halfClose(); } return; } + } catch (WindmillStreamShutdownException e) { + // shutdown() is responsible for cleaning up pending requests. + logger.debug("Stream was shutdown while creating new stream.", e); + break; } catch (Exception e) { - LOG.error("Failed to create new stream, retrying: ", e); + logger.error("Failed to create new stream, retrying: ", e); try { long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException | IOException i) { + debugMetrics.recordSleep(sleep); + sleeper.sleep(sleep); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + logger.info( + "Interrupted during {} creation backoff. The stream will not be created.", + getClass()); + // Shutdown the stream to clean up any dangling resources and pending requests. + shutdown(); + break; + } catch (IOException ioe) { // Keep trying to create the stream. } } } + + // We were never able to start the stream, remove it from the stream registry. Otherwise, it is + // removed when closed. + streamRegistry.remove(this); } - protected final Executor executor() { - return executor; + /** + * Execute the runnable using the {@link #executor} handling the executor being in a shutdown + * state. 
+ */ + protected final void executeSafely(Runnable runnable) { + try { + executor.execute(runnable); + } catch (RejectedExecutionException e) { + logger.debug("{}-{} has been shutdown.", getClass(), backendWorkerToken); + } } public final synchronized void maybeSendHealthCheck(Instant lastSendThreshold) { - if (lastSendTimeMs.get() < lastSendThreshold.getMillis() && !clientClosed.get()) { + if (!clientClosed && debugMetrics.getLastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); - } catch (RuntimeException e) { - LOG.debug("Received exception sending health check.", e); + } catch (Exception e) { + logger.debug("Received exception sending health check.", e); } } } - protected abstract void sendHealthCheck(); + protected abstract void sendHealthCheck() throws WindmillStreamShutdownException; - // Care is taken that synchronization on this is unnecessary for all status page information. - // Blocking sends are made beneath this stream object's lock which could block status page - // rendering. + /** + * @implNote Care is taken that synchronization on this is unnecessary for all status page + * information. Blocking sends are made beneath this stream object's lock which could block + * status page rendering. 
+ */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - if (errorCount.get() > 0) { - writer.format( - ", %d errors, last error [ %s ] at [%s]", - errorCount.get(), lastError.get(), lastErrorTime.get()); - } - if (clientClosed.get()) { + StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics(); + summaryMetrics + .restartMetrics() + .ifPresent( + metrics -> + writer.format( + ", %d restarts, last restart reason [ %s ] at [%s], %d errors", + metrics.restartCount(), + metrics.lastRestartReason(), + metrics.lastRestartTime().orElse(null), + metrics.errorCount())); + + if (summaryMetrics.isClientClosed()) { writer.write(", client closed"); } - long nowMs = Instant.now().getMillis(); - long sleepLeft = sleepUntil.get() - nowMs; - if (sleepLeft > 0) { - writer.format(", %dms backoff remaining", sleepLeft); + + if (summaryMetrics.sleepLeft() > 0) { + writer.format(", %dms backoff remaining", summaryMetrics.sleepLeft()); } + writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s", - debugDuration(nowMs, startTimeMs.get()), - debugDuration(nowMs, lastSendTimeMs.get()), - debugDuration(nowMs, lastResponseTimeMs.get()), - streamClosed.get()); + ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " + + "shutdown time: %s", + summaryMetrics.streamAge(), + summaryMetrics.timeSinceLastSend(), + summaryMetrics.timeSinceLastResponse(), + requestObserver.isClosed(), + summaryMetrics.shutdownTime().map(DateTime::toString).orElse(NOT_SHUTDOWN)); } - // Don't require synchronization on stream, see the appendSummaryHtml comment. + /** + * @implNote Don't require synchronization on stream, see the {@link + * #appendSummaryHtml(PrintWriter)} comment. + */ protected abstract void appendSpecificHtml(PrintWriter writer); @Override public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. 
- clientClosed.set(true); - requestObserver().onCompleted(); - streamClosed.set(true); + debugMetrics.recordHalfClose(); + clientClosed = true; + try { + requestObserver.onCompleted(); + } catch (ResettableThrowingStreamObserver.StreamClosedException e) { + logger.warn("Stream was previously closed."); + } catch (WindmillStreamShutdownException e) { + logger.warn("Stream was previously shutdown."); + } catch (IllegalStateException e) { + logger.warn("Unexpected error when trying to close stream", e); + } } @Override @@ -269,7 +314,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte @Override public final Instant startTime() { - return new Instant(startTimeMs.get()); + return new Instant(debugMetrics.getStartTimeMs()); } @Override @@ -278,22 +323,31 @@ public String backendWorkerToken() { } @Override - public void shutdown() { - if (isShutdown.compareAndSet(false, true)) { - requestObserver() - .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + public final void shutdown() { + // Don't lock on "this" before poisoning the request observer since otherwise the observer may + // be blocking in send(). + requestObserver.poison(); + synchronized (this) { + if (!isShutdown) { + isShutdown = true; + debugMetrics.recordShutdown(); + shutdownInternal(); + } } } - private void setLastError(String error) { - lastError.set(error); - lastErrorTime.set(DateTime.now()); - } + protected abstract void shutdownInternal(); - public static class WindmillStreamShutdownException extends RuntimeException { - public WindmillStreamShutdownException(String message) { - super(message); + /** Returns true if the stream was torn down and should not be restarted internally. 
*/ + private synchronized boolean maybeTearDownStream() { + if (isShutdown || (clientClosed && !hasPendingRequests())) { + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + executor.shutdownNow(); + return true; } + + return false; } private class ResponseObserver implements StreamObserver { @@ -305,77 +359,83 @@ public void onNext(ResponseT response) { } catch (IOException e) { // Ignore. } - lastResponseTimeMs.set(Instant.now().getMillis()); + debugMetrics.recordResponse(); onResponse(response); } @Override public void onError(Throwable t) { - onStreamFinished(t); + if (maybeTearDownStream()) { + return; + } + + Status errorStatus = Status.fromThrowable(t); + recordStreamStatus(errorStatus); + + // If the stream was stopped due to a resource exhausted error then we are throttled. + if (errorStatus.getCode() == Status.Code.RESOURCE_EXHAUSTED) { + startThrottleTimer(); + } + + try { + long sleep = backoff.nextBackOffMillis(); + debugMetrics.recordSleep(sleep); + sleeper.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } catch (IOException e) { + // Ignore. + } + + executeSafely(AbstractWindmillStream.this::startStream); } @Override public void onCompleted() { - onStreamFinished(null); + if (maybeTearDownStream()) { + return; + } + recordStreamStatus(OK_STATUS); + executeSafely(AbstractWindmillStream.this::startStream); } - private void onStreamFinished(@Nullable Throwable t) { - synchronized (this) { - if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) { - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - return; - } - } - if (t != null) { - Status status = null; - if (t instanceof StatusRuntimeException) { - status = ((StatusRuntimeException) t).getStatus(); - } - String statusError = status == null ? 
"" : status.toString(); - setLastError(statusError); - if (errorCount.getAndIncrement() % logEveryNStreamFailures == 0) { + private void recordStreamStatus(Status status) { + int currentRestartCount = debugMetrics.incrementAndGetRestarts(); + if (status.isOk()) { + String restartReason = + "Stream completed successfully but did not complete requested operations, " + + "recreating"; + logger.warn(restartReason); + debugMetrics.recordRestartReason(restartReason); + } else { + int currentErrorCount = debugMetrics.incrementAndGetErrors(); + debugMetrics.recordRestartReason(status.toString()); + Throwable t = status.getCause(); + if (t instanceof StreamObserverCancelledException) { + logger.error( + "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}", + getClass(), + backendWorkerToken, + t.getStackTrace(), + t); + } else if (currentRestartCount % logEveryNStreamFailures == 0) { + // Don't log every restart since it will get noisy, and many errors transient. long nowMillis = Instant.now().getMillis(); - String responseDebug; - if (lastResponseTimeMs.get() == 0) { - responseDebug = "never received response"; - } else { - responseDebug = - "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; - } - LOG.debug( - "{} streaming Windmill RPC errors for {}, last was: {} with status {}." - + " created {}ms ago, {}. This is normal with autoscaling.", + logger.debug( + "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" + + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.", AbstractWindmillStream.this.getClass(), - errorCount.get(), + currentRestartCount, + currentErrorCount, t, - statusError, - nowMillis - startTimeMs.get(), - responseDebug); - } - // If the stream was stopped due to a resource exhausted error then we are throttled. 
- if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { - startThrottleTimer(); - } - - try { - long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (IOException e) { - // Ignore. + status, + nowMillis - debugMetrics.getStartTimeMs(), + debugMetrics + .responseDebugString(nowMillis) + .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING)); } - } else { - errorCount.incrementAndGet(); - String error = - "Stream completed successfully but did not complete requested operations, " - + "recreating"; - LOG.warn(error); - setLastError(error); } - executor.execute(AbstractWindmillStream.this::startStream); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java new file mode 100644 index 000000000000..1db6d8de791d --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.slf4j.Logger; + +/** + * Request observer that allows resetting its internal delegate using the given {@link + * #streamObserverFactory}. + * + * @implNote {@link StreamObserver}s generated by {@link #streamObserverFactory} are expected to be + * {@link ThreadSafe}. Has same methods declared in {@link StreamObserver}, but they throw + * {@link StreamClosedException} and {@link WindmillStreamShutdownException}, which much be + * handled by callers. + */ +@ThreadSafe +@Internal +final class ResettableThrowingStreamObserver { + private final Supplier> streamObserverFactory; + private final Logger logger; + + @GuardedBy("this") + private @Nullable TerminatingStreamObserver delegateStreamObserver; + + @GuardedBy("this") + private boolean isPoisoned = false; + + /** + * Indicates that the current delegate is closed via {@link #poison() or {@link #onCompleted()}}. 
+ * If not poisoned, a call to {@link #reset()} is required to perform future operations on the + * StreamObserver. + */ + @GuardedBy("this") + private boolean isCurrentStreamClosed = true; + + ResettableThrowingStreamObserver( + Supplier> streamObserverFactory, Logger logger) { + this.streamObserverFactory = streamObserverFactory; + this.logger = logger; + this.delegateStreamObserver = null; + } + + private synchronized StreamObserver delegate() + throws WindmillStreamShutdownException, StreamClosedException { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Stream is already shutdown."); + } + + if (isCurrentStreamClosed) { + throw new StreamClosedException( + "Current stream is closed, requires reset() for future stream operations."); + } + + return Preconditions.checkNotNull(delegateStreamObserver, "requestObserver cannot be null."); + } + + /** Creates a new delegate to use for future {@link StreamObserver} methods. */ + synchronized void reset() throws WindmillStreamShutdownException { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Stream is already shutdown."); + } + + delegateStreamObserver = streamObserverFactory.get(); + isCurrentStreamClosed = false; + } + + /** + * Indicates that the request observer should no longer be used. Attempts to perform operations on + * the request observer will throw an {@link WindmillStreamShutdownException}. + */ + synchronized void poison() { + if (!isPoisoned) { + isPoisoned = true; + if (delegateStreamObserver != null) { + delegateStreamObserver.terminate( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + delegateStreamObserver = null; + isCurrentStreamClosed = true; + } + } + } + + public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownException { + // Make sure onNext and onError below to be called on the same StreamObserver instance. 
+ StreamObserver delegate = delegate(); + try { + // Do NOT lock while sending message over the stream as this will block other StreamObserver + // operations. + delegate.onNext(t); + } catch (StreamObserverCancelledException cancellationException) { + synchronized (this) { + if (isPoisoned) { + logger.debug("Stream was shutdown during send.", cancellationException); + return; + } + } + + try { + delegate.onError(cancellationException); + } catch (IllegalStateException onErrorException) { + // If the delegate above was already terminated via onError or onComplete from another + // thread. + logger.warn( + "StreamObserver was already cancelled {} due to error.", + onErrorException, + cancellationException); + } catch (RuntimeException onErrorException) { + logger.warn( + "Encountered unexpected error {} when cancelling due to error.", + onErrorException, + cancellationException); + } + } + } + + public synchronized void onError(Throwable throwable) + throws StreamClosedException, WindmillStreamShutdownException { + delegate().onError(throwable); + isCurrentStreamClosed = true; + } + + public synchronized void onCompleted() + throws StreamClosedException, WindmillStreamShutdownException { + delegate().onCompleted(); + isCurrentStreamClosed = true; + } + + synchronized boolean isClosed() { + return isCurrentStreamClosed; + } + + /** + * Indicates that the current stream was closed and the {@link StreamObserver} has finished via + * {@link StreamObserver#onCompleted()}. 
The stream may perform + */ + static final class StreamClosedException extends Exception { + StreamClosedException(String s) { + super(s); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java new file mode 100644 index 000000000000..4cda12a85ea2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import com.google.auto.value.AutoValue; +import java.util.Optional; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.joda.time.DateTime; +import org.joda.time.Instant; + +/** Records stream metrics for debugging. 
*/ +@ThreadSafe +final class StreamDebugMetrics { + private final Supplier clock; + + @GuardedBy("this") + private int errorCount = 0; + + @GuardedBy("this") + private int restartCount = 0; + + @GuardedBy("this") + private long sleepUntil = 0; + + @GuardedBy("this") + private String lastRestartReason = ""; + + @GuardedBy("this") + private @Nullable DateTime lastRestartTime = null; + + @GuardedBy("this") + private long lastResponseTimeMs = 0; + + @GuardedBy("this") + private long lastSendTimeMs = 0; + + @GuardedBy("this") + private long startTimeMs = 0; + + @GuardedBy("this") + private @Nullable DateTime shutdownTime = null; + + @GuardedBy("this") + private boolean clientClosed = false; + + private StreamDebugMetrics(Supplier clock) { + this.clock = clock; + } + + static StreamDebugMetrics create() { + return new StreamDebugMetrics(Instant::now); + } + + @VisibleForTesting + static StreamDebugMetrics forTesting(Supplier fakeClock) { + return new StreamDebugMetrics(fakeClock); + } + + private static long debugDuration(long nowMs, long startMs) { + return startMs <= 0 ? 
-1 : Math.max(0, nowMs - startMs); + } + + private long nowMs() { + return clock.get().getMillis(); + } + + synchronized void recordSend() { + lastSendTimeMs = nowMs(); + } + + synchronized void recordStart() { + startTimeMs = nowMs(); + lastResponseTimeMs = 0; + } + + synchronized void recordResponse() { + lastResponseTimeMs = nowMs(); + } + + synchronized void recordRestartReason(String error) { + lastRestartReason = error; + lastRestartTime = clock.get().toDateTime(); + } + + synchronized long getStartTimeMs() { + return startTimeMs; + } + + synchronized long getLastSendTimeMs() { + return lastSendTimeMs; + } + + synchronized void recordSleep(long sleepMs) { + sleepUntil = nowMs() + sleepMs; + } + + synchronized int incrementAndGetRestarts() { + return restartCount++; + } + + synchronized int incrementAndGetErrors() { + return errorCount++; + } + + synchronized void recordShutdown() { + shutdownTime = clock.get().toDateTime(); + } + + synchronized void recordHalfClose() { + clientClosed = true; + } + + synchronized Optional responseDebugString(long nowMillis) { + return lastResponseTimeMs == 0 + ? 
Optional.empty() + : Optional.of("received response " + (nowMillis - lastResponseTimeMs) + "ms ago"); + } + + private synchronized Optional getRestartMetrics() { + if (restartCount > 0) { + return Optional.of( + RestartMetrics.create(restartCount, lastRestartReason, lastRestartTime, errorCount)); + } + + return Optional.empty(); + } + + synchronized Snapshot getSummaryMetrics() { + long nowMs = clock.get().getMillis(); + return Snapshot.create( + debugDuration(nowMs, startTimeMs), + debugDuration(nowMs, lastSendTimeMs), + debugDuration(nowMs, lastResponseTimeMs), + getRestartMetrics(), + sleepUntil - nowMs(), + shutdownTime, + clientClosed); + } + + @AutoValue + abstract static class Snapshot { + private static Snapshot create( + long streamAge, + long timeSinceLastSend, + long timeSinceLastResponse, + Optional restartMetrics, + long sleepLeft, + @Nullable DateTime shutdownTime, + boolean isClientClosed) { + return new AutoValue_StreamDebugMetrics_Snapshot( + streamAge, + timeSinceLastSend, + timeSinceLastResponse, + restartMetrics, + sleepLeft, + Optional.ofNullable(shutdownTime), + isClientClosed); + } + + abstract long streamAge(); + + abstract long timeSinceLastSend(); + + abstract long timeSinceLastResponse(); + + abstract Optional restartMetrics(); + + abstract long sleepLeft(); + + abstract Optional shutdownTime(); + + abstract boolean isClientClosed(); + } + + @AutoValue + abstract static class RestartMetrics { + private static RestartMetrics create( + int restartCount, + String restartReason, + @Nullable DateTime lastRestartTime, + int errorCount) { + return new AutoValue_StreamDebugMetrics_RestartMetrics( + restartCount, restartReason, Optional.ofNullable(lastRestartTime), errorCount); + } + + abstract int restartCount(); + + abstract String lastRestartReason(); + + abstract Optional lastRestartTime(); + + abstract int errorCount(); + } +} diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index f26c56b14ec2..51bc03e8e0e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -34,6 +34,12 @@ @ThreadSafe public interface WindmillStream { + /** + * Start the stream, opening a connection to the backend server. A call to start() is required for + * any further interactions on the stream. + */ + void start(); + /** An identifier for the backend worker where the stream is sending/receiving RPCs. */ String backendWorkerToken(); @@ -47,8 +53,9 @@ public interface WindmillStream { Instant startTime(); /** - * Shutdown the stream. There should be no further interactions with the stream once this has been - * called. + * Shuts down the stream. No further interactions should be made with the stream, and the stream + * will no longer try to connect internally. Any pending retries or in-flight requests will be + * cancelled and all responses dropped and considered invalid. */ void shutdown(); @@ -68,13 +75,16 @@ default void setBudget(long newItems, long newBytes) { interface GetDataStream extends WindmillStream { /** Issues a keyed GetData fetch, blocking until the result is ready. */ Windmill.KeyedGetDataResponse requestKeyedData( - String computation, Windmill.KeyedGetDataRequest request); + String computation, Windmill.KeyedGetDataRequest request) + throws WindmillStreamShutdownException; /** Issues a global GetData fetch, blocking until the result is ready. 
*/ - Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request); + Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) + throws WindmillStreamShutdownException; /** Tells windmill processing is ongoing for the given keys. */ - void refreshActiveWork(Map> heartbeats); + void refreshActiveWork(Map> heartbeats) + throws WindmillStreamShutdownException; void onHeartbeatResponse(List responses); } @@ -86,7 +96,7 @@ interface CommitWorkStream extends WindmillStream { * Returns a builder that can be used for sending requests. Each builder is not thread-safe but * different builders for the same stream may be used simultaneously. */ - CommitWorkStream.RequestBatcher batcher(); + RequestBatcher batcher(); @NotThreadSafe interface RequestBatcher extends Closeable { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java new file mode 100644 index 000000000000..566c15c58036 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +/** + * Thrown when operations are requested on a {@link WindmillStream} has been shutdown. Future + * operations on the stream are not allowed and will throw an {@link + * WindmillStreamShutdownException}. + */ +public final class WindmillStreamShutdownException extends Exception { + public WindmillStreamShutdownException(String message) { + super(message); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index c8e058e7e230..ab12946ad18b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -21,8 +21,8 @@ import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import 
org.apache.beam.sdk.annotations.Internal; /** {@link GetDataClient} that fetches data directly from a specific {@link GetDataStream}. */ @@ -61,7 +61,7 @@ public Windmill.KeyedGetDataResponse getStateData( String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { return getDataStream.requestKeyedData(computationId, request); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { throw new WorkItemCancelledException(request.getShardingKey()); } catch (Exception e) { throw new GetDataException( @@ -86,7 +86,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) sideInputGetDataStreamFactory.apply(request.getDataId().getTag()); try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { throw new WorkItemCancelledException(e); } catch (Exception e) { throw new GetDataException( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java index 98545a429461..b15f73645dee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java @@ -134,6 +134,12 @@ public void close() throws IOException { stream.close(); } + static class InvalidInputStreamStateException extends 
IllegalStateException { + public InvalidInputStreamStateException() { + super("Got poison pill or timeout but stream is not done."); + } + } + @SuppressWarnings("NullableProblems") private class InputStreamEnumeration implements Enumeration { // The first stream is eagerly read on SequenceInputStream creation. For this reason @@ -159,7 +165,7 @@ public boolean hasMoreElements() { if (complete.get()) { return false; } - throw new IllegalStateException("Got poison pill or timeout but stream is not done."); + throw new InvalidInputStreamStateException(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new CancellationException(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 053843a8af25..2dd069b9c443 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -19,14 +19,19 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import com.google.auto.value.AutoValue; import java.io.PrintWriter; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; 
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -35,22 +40,25 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcCommitWorkStream +final class GrpcCommitWorkStream extends AbstractWindmillStream implements CommitWorkStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStream.class); private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE; - private final Map pending; + private final ConcurrentMap pending; private final AtomicLong idGenerator; private final JobHeader jobHeader; private final ThrottleTimer commitWorkThrottleTimer; @@ -69,6 +77,7 @@ private GrpcCommitWorkStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( + LOG, "CommitWorkStream", startCommitWorkRpcFn, backoff, @@ -83,7 +92,7 @@ private GrpcCommitWorkStream( this.streamingRpcBatchLimit = streamingRpcBatchLimit; } - public static GrpcCommitWorkStream create( + static GrpcCommitWorkStream 
create( String backendWorkerToken, Function, StreamObserver> startCommitWorkRpcFn, @@ -95,20 +104,17 @@ public static GrpcCommitWorkStream create( JobHeader jobHeader, AtomicLong idGenerator, int streamingRpcBatchLimit) { - GrpcCommitWorkStream commitWorkStream = - new GrpcCommitWorkStream( - backendWorkerToken, - startCommitWorkRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - commitWorkThrottleTimer, - jobHeader, - idGenerator, - streamingRpcBatchLimit); - commitWorkStream.startStream(); - return commitWorkStream; + return new GrpcCommitWorkStream( + backendWorkerToken, + startCommitWorkRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit); } @Override @@ -117,8 +123,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - protected synchronized void onNewStream() { - send(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); try (Batcher resendBatcher = new Batcher()) { for (Map.Entry entry : pending.entrySet()) { if (!resendBatcher.canAccept(entry.getValue().getBytes())) { @@ -144,11 +150,11 @@ protected boolean hasPendingRequests() { } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws WindmillStreamShutdownException { if (hasPendingRequests()) { StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); - send(builder.build()); + trySend(builder.build()); } } @@ -156,29 +162,49 @@ public void sendHealthCheck() { protected void onResponse(StreamingCommitResponse response) { commitWorkThrottleTimer.stop(); - RuntimeException finalException = null; + CommitCompletionFailureHandler 
failureHandler = new CommitCompletionFailureHandler(); for (int i = 0; i < response.getRequestIdCount(); ++i) { long requestId = response.getRequestId(i); if (requestId == HEARTBEAT_REQUEST_ID) { continue; } - PendingRequest done = pending.remove(requestId); - if (done == null) { - LOG.error("Got unknown commit request ID: {}", requestId); + + // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may + // be omitted. + CommitStatus commitStatus = + i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK; + + @Nullable PendingRequest pendingRequest = pending.remove(requestId); + if (pendingRequest == null) { + synchronized (this) { + if (!isShutdown) { + // Missing responses are expected after shutdown() because it removes them. + LOG.error("Got unknown commit request ID: {}", requestId); + } + } } else { try { - done.onDone.accept( - (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK); + pendingRequest.completeWithStatus(commitStatus); } catch (RuntimeException e) { // Catch possible exceptions to ensure that an exception for one commit does not prevent - // other commits from being processed. + // other commits from being processed. Aggregate all the failures to throw after + // processing the response if they exist. 
LOG.warn("Exception while processing commit response.", e); - finalException = e; + failureHandler.addError(commitStatus, e); } } } - if (finalException != null) { - throw finalException; + + failureHandler.throwIfNonEmpty(); + } + + @Override + protected void shutdownInternal() { + Iterator pendingRequests = pending.values().iterator(); + while (pendingRequests.hasNext()) { + PendingRequest pendingRequest = pendingRequests.next(); + pendingRequest.abort(); + pendingRequests.remove(); } } @@ -187,13 +213,15 @@ protected void startThrottleTimer() { commitWorkThrottleTimer.start(); } - private void flushInternal(Map requests) { + private void flushInternal(Map requests) + throws WindmillStreamShutdownException { if (requests.isEmpty()) { return; } + if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request.getSerializedSize() + if (elem.getValue().request().getSerializedSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -204,100 +232,171 @@ private void flushInternal(Map requests) { } } - private void issueSingleRequest(final long id, PendingRequest pendingRequest) { + private void issueSingleRequest(long id, PendingRequest pendingRequest) + throws WindmillStreamShutdownException { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() - .setComputationId(pendingRequest.computation) + .setComputationId(pendingRequest.computationId()) .setRequestId(id) - .setShardingKey(pendingRequest.request.getShardingKey()) - .setSerializedWorkItemCommit(pendingRequest.request.toByteString()); + .setShardingKey(pendingRequest.shardingKey()) + .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { - pending.put(id, pendingRequest); - try { - send(chunk); - } catch (IllegalStateException e) 
{ - // Stream was broken, request will be retried when stream is reopened. + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; } + trySend(chunk); } } - private void issueBatchedRequest(Map requests) { + private void issueBatchedRequest(Map requests) + throws WindmillStreamShutdownException { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); String lastComputation = null; for (Map.Entry entry : requests.entrySet()) { PendingRequest request = entry.getValue(); StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computation)) { - chunkBuilder.setComputationId(request.computation); - lastComputation = request.computation; + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); } - chunkBuilder.setRequestId(entry.getKey()); - chunkBuilder.setShardingKey(request.request.getShardingKey()); - chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString()); + chunkBuilder + .setRequestId(entry.getKey()) + .setShardingKey(request.shardingKey()) + .setSerializedWorkItemCommit(request.serializedCommit()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { - pending.putAll(requests); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
+ if (!prepareForSend(requests)) { + requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; } + trySend(request); } } - private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { - checkNotNull(pendingRequest.computation); - final ByteString serializedCommit = pendingRequest.request.toByteString(); - + private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) + throws WindmillStreamShutdownException { + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); + ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { - pending.put(id, pendingRequest); + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + for (int i = 0; i < serializedCommit.size(); i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - StreamingCommitRequestChunk.Builder chunkBuilder = StreamingCommitRequestChunk.newBuilder() .setRequestId(id) .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computation) - .setShardingKey(pendingRequest.request.getShardingKey()); + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); } - StreamingCommitWorkRequest requestChunk = StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. + + if (!trySend(requestChunk)) { + // The stream broke, don't try to send the rest of the chunks here. break; } } } } - private static class PendingRequest { + /** Returns true if prepare for send succeeded. 
*/ + private synchronized boolean prepareForSend(long id, PendingRequest request) { + if (!isShutdown) { + pending.put(id, request); + return true; + } + return false; + } + + /** Returns true if prepare for send succeeded. */ + private synchronized boolean prepareForSend(Map requests) { + if (!isShutdown) { + pending.putAll(requests); + return true; + } + return false; + } + + @AutoValue + abstract static class PendingRequest { + + private static PendingRequest create( + String computationId, WorkItemCommitRequest request, Consumer onDone) { + return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone); + } + + abstract String computationId(); + + abstract WorkItemCommitRequest request(); + + abstract Consumer onDone(); + + private long getBytes() { + return (long) request().getSerializedSize() + computationId().length(); + } + + private ByteString serializedCommit() { + return request().toByteString(); + } + + private void completeWithStatus(CommitStatus commitStatus) { + onDone().accept(commitStatus); + } + + private long shardingKey() { + return request().getShardingKey(); + } + + private void abort() { + completeWithStatus(CommitStatus.ABORTED); + } + } + + private static class CommitCompletionException extends RuntimeException { + private CommitCompletionException(String message) { + super(message); + } + } + + private static class CommitCompletionFailureHandler { + private static final int MAX_PRINTABLE_ERRORS = 10; + private final Map>, Integer> errorCounter; + private final EvictingQueue detailedErrors; - private final String computation; - private final WorkItemCommitRequest request; - private final Consumer onDone; + private CommitCompletionFailureHandler() { + this.errorCounter = new HashMap<>(); + this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); + } - PendingRequest( - String computation, WorkItemCommitRequest request, Consumer onDone) { - this.computation = computation; - this.request = request; - 
this.onDone = onDone; + private void addError(CommitStatus commitStatus, Throwable error) { + errorCounter.compute( + Pair.of(commitStatus, error.getClass()), + (ignored, current) -> current == null ? 1 : current + 1); + detailedErrors.add(error); } - long getBytes() { - return (long) request.getSerializedSize() + computation.length(); + private void throwIfNonEmpty() { + if (!errorCounter.isEmpty()) { + String errorMessage = + String.format( + "Exception while processing commit response. ErrorCounter: %s; Details: %s", + errorCounter, detailedErrors); + throw new CommitCompletionException(errorMessage); + } } } @@ -317,7 +416,8 @@ public boolean commitWorkItem( if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { return false; } - PendingRequest request = new PendingRequest(computation, commitRequest, onDone); + + PendingRequest request = PendingRequest.create(computation, commitRequest, onDone); add(idGenerator.incrementAndGet(), request); return true; } @@ -325,13 +425,18 @@ public boolean commitWorkItem( /** Flushes any pending work items to the wire. 
*/ @Override public void flush() { - flushInternal(queue); - queuedBytes = 0; - queue.clear(); + try { + flushInternal(queue); + } catch (WindmillStreamShutdownException e) { + queue.forEach((ignored, request) -> request.abort()); + } finally { + queuedBytes = 0; + queue.clear(); + } } void add(long id, PendingRequest request) { - assert (canAccept(request.getBytes())); + Preconditions.checkState(canAccept(request.getBytes())); queuedBytes += request.getBytes(); queue.put(id, request); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index b27ebc8e9eee..27f457900e6c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -21,7 +21,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import javax.annotation.concurrent.GuardedBy; @@ -35,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GetWorkResponseChunkAssembler.AssembledWorkItem; @@ -105,6 +105,7 @@ private GrpcDirectGetWorkStream( WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( + LOG, "GetWorkStream", startGetWorkRpcFn, backoff, @@ -144,22 +145,19 @@ static GrpcDirectGetWorkStream create( GetDataClient getDataClient, WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { - GrpcDirectGetWorkStream getWorkStream = - new GrpcDirectGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - heartbeatSender, - getDataClient, - workCommitter, - workItemScheduler); - getWorkStream.startStream(); - return getWorkStream; + return new GrpcDirectGetWorkStream( + backendWorkerToken, + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + heartbeatSender, + getDataClient, + workCommitter, + workItemScheduler); } private static Watermarks createWatermarks( @@ -174,7 +172,7 @@ private static Watermarks createWatermarks( /** * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message * which can deadlock since we send on the stream beneath the synchronization. {@link - * AbstractWindmillStream#send(Object)} is synchronized so the sends are already guarded. + * AbstractWindmillStream#trySend(Object)} is synchronized so the sends are already guarded. 
*/ private void maybeSendRequestExtension(GetWorkBudget extension) { if (extension.items() > 0 || extension.bytes() > 0) { @@ -190,8 +188,8 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { lastRequest.set(request); budgetTracker.recordBudgetRequested(extension); try { - send(request); - } catch (IllegalStateException e) { + trySend(request); + } catch (WindmillStreamShutdownException e) { // Stream was closed. } }); @@ -199,24 +197,22 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { workItemAssemblers.clear(); - if (!isShutdown()) { - budgetTracker.reset(); - GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); - StreamingGetWorkRequest request = - StreamingGetWorkRequest.newBuilder() - .setRequest( - requestHeader - .toBuilder() - .setMaxItems(initialGetWorkBudget.items()) - .setMaxBytes(initialGetWorkBudget.bytes()) - .build()) - .build(); - lastRequest.set(request); - budgetTracker.recordBudgetRequested(initialGetWorkBudget); - send(request); - } + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequest( + requestHeader + .toBuilder() + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + trySend(request); } @Override @@ -234,10 +230,13 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() { - send(HEALTH_CHECK_REQUEST); + public void sendHealthCheck() throws WindmillStreamShutdownException { + trySend(HEALTH_CHECK_REQUEST); } + @Override + protected void shutdownInternal() {} + @Override protected void 
onResponse(StreamingGetWorkResponseChunk chunk) { getWorkThrottleTimer.stop(); @@ -277,14 +276,6 @@ public void setBudget(GetWorkBudget newBudget) { maybeSendRequestExtension(extension); } - private void executeSafely(Runnable runnable) { - try { - executor().execute(runnable); - } catch (RejectedExecutionException e) { - LOG.debug("{} has been shutdown.", getClass()); - } - } - /** * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request * extensions. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index c99e05a77074..b5b49c8ee976 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -31,10 +31,11 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest; @@ -49,6 +50,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedBatch; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; @@ -59,12 +61,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +@ThreadSafe final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); + private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = + StreamingGetDataRequest.newBuilder().build(); + /** @implNote {@link QueuedBatch} objects in the queue are guarded by {@code this} */ private final Deque batches; + private final Map pending; private final AtomicLong idGenerator; private final ThrottleTimer getDataThrottleTimer; @@ -90,6 +97,7 @@ private GrpcGetDataStream( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { super( + LOG, "GetDataStream", startGetDataRpcFn, backoff, @@ -107,7 +115,7 @@ private GrpcGetDataStream( this.processHeartbeatResponses = processHeartbeatResponses; } - public static GrpcGetDataStream create( + static GrpcGetDataStream create( String backendWorkerToken, Function, StreamObserver> startGetDataRpcFn, @@ -121,32 +129,44 @@ public static GrpcGetDataStream create( int streamingRpcBatchLimit, boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { - GrpcGetDataStream getDataStream = - new GrpcGetDataStream( - backendWorkerToken, - startGetDataRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getDataThrottleTimer, - jobHeader, - idGenerator, - streamingRpcBatchLimit, - 
sendKeyedGetDataRequests, - processHeartbeatResponses); - getDataStream.startStream(); - return getDataStream; + return new GrpcGetDataStream( + backendWorkerToken, + startGetDataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit, + sendKeyedGetDataRequests, + processHeartbeatResponses); + } + + private static WindmillStreamShutdownException shutdownExceptionFor(QueuedBatch batch) { + return new WindmillStreamShutdownException( + "Stream was closed when attempting to send " + batch.requestsCount() + " requests."); + } + + private static WindmillStreamShutdownException shutdownExceptionFor(QueuedRequest request) { + return new WindmillStreamShutdownException( + "Cannot send request=[" + request + "] on closed stream."); + } + + private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest) + throws WindmillStreamShutdownException { + trySend(getDataRequest); } @Override - protected synchronized void onNewStream() { - send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed.get()) { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + if (clientClosed) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
- verify(!hasPendingRequests()); + verify(!hasPendingRequests(), "Pending requests not expected if we've half-closed."); } else { for (AppendableInputStream responseStream : pending.values()) { responseStream.cancel(); @@ -160,7 +180,6 @@ protected boolean hasPendingRequests() { } @Override - @SuppressWarnings("dereference.of.nullable") protected void onResponse(StreamingGetDataResponse chunk) { checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); @@ -168,8 +187,15 @@ protected void onResponse(StreamingGetDataResponse chunk) { onHeartbeatResponse(chunk.getComputationHeartbeatResponseList()); for (int i = 0; i < chunk.getRequestIdCount(); ++i) { - AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - verify(responseStream != null, "No pending response stream"); + @Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); + if (responseStream == null) { + synchronized (this) { + // shutdown()/shutdownInternal() cleans up pending, else we expect a pending + // responseStream for every response. 
+ verify(isShutdown, "No pending response stream"); + } + continue; + } responseStream.append(chunk.getSerializedResponse(i).newInput()); if (chunk.getRemainingBytesForResponse() == 0) { responseStream.complete(); @@ -187,23 +213,22 @@ private long uniqueId() { } @Override - public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { + public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) + throws WindmillStreamShutdownException { return issueRequest( QueuedRequest.forComputation(uniqueId(), computation, request), KeyedGetDataResponse::parseFrom); } @Override - public GlobalData requestGlobalData(GlobalDataRequest request) { + public GlobalData requestGlobalData(GlobalDataRequest request) + throws WindmillStreamShutdownException { return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); } @Override - public void refreshActiveWork(Map> heartbeats) { - if (isShutdown()) { - throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); - } - + public void refreshActiveWork(Map> heartbeats) + throws WindmillStreamShutdownException { StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); if (sendKeyedGetDataRequests) { long builderBytes = 0; @@ -214,7 +239,7 @@ public void refreshActiveWork(Map> heartbea if (builderBytes > 0 && (builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE || builder.getRequestIdCount() >= streamingRpcBatchLimit)) { - send(builder.build()); + sendIgnoringClosed(builder.build()); builderBytes = 0; builder.clear(); } @@ -233,7 +258,7 @@ public void refreshActiveWork(Map> heartbea } if (builderBytes > 0) { - send(builder.build()); + sendIgnoringClosed(builder.build()); } } else { // No translation necessary, but we must still respect `RPC_STREAM_CHUNK_SIZE`. 
@@ -248,7 +273,7 @@ public void refreshActiveWork(Map> heartbea if (computationHeartbeatBuilder.getHeartbeatRequestsCount() > 0) { builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build()); } - send(builder.build()); + sendIgnoringClosed(builder.build()); builderBytes = 0; builder.clear(); computationHeartbeatBuilder.clear().setComputationId(entry.getKey()); @@ -260,7 +285,7 @@ public void refreshActiveWork(Map> heartbea } if (builderBytes > 0) { - send(builder.build()); + sendIgnoringClosed(builder.build()); } } } @@ -271,12 +296,26 @@ public void onHeartbeatResponse(List resp } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws WindmillStreamShutdownException { if (hasPendingRequests()) { - send(StreamingGetDataRequest.newBuilder().build()); + trySend(HEALTH_CHECK_REQUEST); } } + @Override + protected synchronized void shutdownInternal() { + // Stream has been explicitly closed. Drain pending input streams and request batches. + // Future calls to send RPCs will fail. + pending.values().forEach(AppendableInputStream::cancel); + pending.clear(); + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); + } + @Override public void appendSpecificHtml(PrintWriter writer) { writer.format( @@ -301,20 +340,23 @@ public void appendSpecificHtml(PrintWriter writer) { writer.append("]"); } - private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { + private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) + throws WindmillStreamShutdownException { while (true) { request.resetResponseStream(); try { queueRequestAndWait(request); return parseFn.parse(request.getResponseStream()); - } catch (CancellationException e) { - // Retry issuing the request since the response stream was cancelled. 
- continue; + } catch (AppendableInputStream.InvalidInputStreamStateException | CancellationException e) { + throwIfShutdown(request, e); + if (!(e instanceof CancellationException)) { + throw e; + } } catch (IOException e) { LOG.error("Parsing GetData response failed: ", e); - continue; } catch (InterruptedException e) { Thread.currentThread().interrupt(); + throwIfShutdown(request, e); throw new RuntimeException(e); } finally { pending.remove(request.id()); @@ -322,18 +364,32 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn= streamingRpcBatchLimit + || batch.requestsCount() >= streamingRpcBatchLimit || batch.byteSize() + request.byteSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { if (batch != null) { - waitForSendLatch = batch.getLatch(); + prevBatch = batch; } batch = new QueuedBatch(); batches.addLast(batch); @@ -342,64 +398,80 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept batch.addRequest(request); } if (responsibleForSend) { - if (waitForSendLatch == null) { + if (prevBatch == null) { // If there was not a previous batch wait a little while to improve // batching. - Thread.sleep(1); + sleeper.sleep(1); } else { - waitForSendLatch.await(); + prevBatch.waitForSendOrFailNotification(); } // Finalize the batch so that no additional requests will be added. Leave the batch in the // queue so that a subsequent batch will wait for its completion. - synchronized (batches) { - verify(batch == batches.peekFirst()); + synchronized (this) { + if (isShutdown) { + throw shutdownExceptionFor(batch); + } + + verify(batch == batches.peekFirst(), "GetDataStream request batch removed before send()."); batch.markFinalized(); } - sendBatch(batch.requests()); - synchronized (batches) { - verify(batch == batches.pollFirst()); + trySendBatch(batch); + } else { + // Wait for this batch to be sent before parsing the response. 
+ batch.waitForSendOrFailNotification(); + } + } + + void trySendBatch(QueuedBatch batch) throws WindmillStreamShutdownException { + try { + sendBatch(batch); + synchronized (this) { + if (isShutdown) { + throw shutdownExceptionFor(batch); + } + + verify( + batch == batches.pollFirst(), + "Sent GetDataStream request batch removed before send() was complete."); } // Notify all waiters with requests in this batch as well as the sender // of the next batch (if one exists). - batch.countDown(); - } else { - // Wait for this batch to be sent before parsing the response. - batch.await(); + batch.notifySent(); + } catch (Exception e) { + // Free waiters if the send() failed. + batch.notifyFailed(); + // Propagate the exception to the calling thread. + throw e; } } - @SuppressWarnings("NullableProblems") - private void sendBatch(List requests) { - StreamingGetDataRequest batchedRequest = flushToBatch(requests); + private void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException { + if (batch.isEmpty()) { + return; + } + + // Synchronization of pending inserts is necessary with send to ensure duplicates are not + // sent on stream reconnect. synchronized (this) { - // Synchronization of pending inserts is necessary with send to ensure duplicates are not - // sent on stream reconnect. - for (QueuedRequest request : requests) { + if (isShutdown) { + throw shutdownExceptionFor(batch); + } + + for (QueuedRequest request : batch.requestsReadOnly()) { // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. - verify(pending.put(request.id(), request.getResponseStream()) == null); + verify( + pending.put(request.id(), request.getResponseStream()) == null, + "Request already sent."); } - try { - send(batchedRequest); - } catch (IllegalStateException e) { + + if (!trySend(batch.asGetDataRequest())) { // The stream broke before this call went through; onNewStream will retry the fetch. 
- LOG.warn("GetData stream broke before call started.", e); + LOG.warn("GetData stream broke before call started."); } } } - @SuppressWarnings("argument") - private StreamingGetDataRequest flushToBatch(List requests) { - // Put all global data requests first because there is only a single repeated field for - // request ids and the initial ids correspond to global data requests if they are present. - requests.sort(QueuedRequest.globalRequestsFirst()); - StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); - for (QueuedRequest request : requests) { - request.addToStreamingGetDataRequest(builder); - } - return builder.build(); - } - @FunctionalInterface private interface ParseFn { ResponseT parse(InputStream input) throws IOException; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index cda9537127d9..ef7f5b20bb07 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -17,20 +17,35 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; + import com.google.auto.value.AutoOneOf; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.stream.Stream; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Utility data classes for {@link GrpcGetDataStream}. */ final class GrpcGetDataStreamRequests { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStreamRequests.class); + private static final int STREAM_CANCELLED_ERROR_LOG_LIMIT = 3; + private GrpcGetDataStreamRequests() {} + private static String debugFormat(long value) { + return String.format("%016x", value); + } + static class QueuedRequest { private final long id; private final ComputationOrGlobalDataRequest dataRequest; @@ -81,6 +96,10 @@ void resetResponseStream() { this.responseStream = new AppendableInputStream(); } + public ComputationOrGlobalDataRequest getDataRequest() { + return dataRequest; + } + void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder builder) { builder.addRequestId(id); if (dataRequest.isForComputation()) { @@ -89,20 +108,51 @@ void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder build builder.addGlobalDataRequest(dataRequest.global()); } } + + @Override + public final String toString() { + return "QueuedRequest{" + "dataRequest=" + dataRequest + ", id=" + id + '}'; + } } + /** + * Represents a batch of queued requests. Methods are not thread-safe unless commented otherwise. 
+ */ static class QueuedBatch { private final List requests = new ArrayList<>(); private final CountDownLatch sent = new CountDownLatch(1); private long byteSize = 0; - private boolean finalized = false; + private volatile boolean finalized = false; + private volatile boolean failed = false; - CountDownLatch getLatch() { - return sent; + /** Returns a read-only view of requests. */ + List requestsReadOnly() { + return Collections.unmodifiableList(requests); } - List requests() { - return requests; + /** + * Converts the batch to a {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest}. + */ + Windmill.StreamingGetDataRequest asGetDataRequest() { + Windmill.StreamingGetDataRequest.Builder builder = + Windmill.StreamingGetDataRequest.newBuilder(); + + requests.stream() + // Put all global data requests first because there is only a single repeated field for + // request ids and the initial ids correspond to global data requests if they are present. + .sorted(QueuedRequest.globalRequestsFirst()) + .forEach(request -> request.addToStreamingGetDataRequest(builder)); + + return builder.build(); + } + + boolean isEmpty() { + return requests.isEmpty(); + } + + int requestsCount() { + return requests.size(); } long byteSize() { @@ -117,17 +167,83 @@ void markFinalized() { finalized = true; } + /** Adds a request to the batch. */ void addRequest(QueuedRequest request) { requests.add(request); byteSize += request.byteSize(); } - void countDown() { + /** + * Let waiting for threads know that the request has been successfully sent. + * + * @implNote Thread safe. + */ + void notifySent() { + sent.countDown(); + } + + /** + * Let waiting for threads know that a failure occurred. + * + * @implNote Thread safe. 
+ */ + void notifyFailed() { + failed = true; sent.countDown(); } - void await() throws InterruptedException { + /** + * Block until notified of a successful send via {@link #notifySent()} or a non-retryable + * failure via {@link #notifyFailed()}. On failure, throw an exception for waiters. + * + * @implNote Thread safe. + */ + void waitForSendOrFailNotification() + throws InterruptedException, WindmillStreamShutdownException { sent.await(); + if (failed) { + ImmutableList cancelledRequests = createStreamCancelledErrorMessages(); + if (!cancelledRequests.isEmpty()) { + LOG.error("Requests failed for the following batches: {}", cancelledRequests); + throw new WindmillStreamShutdownException( + "Requests failed for batch containing " + + String.join(", ", cancelledRequests) + + " ... requests. This is most likely due to the stream being explicitly closed" + + " which happens when the work is marked as invalid on the streaming" + + " backend when key ranges shuffle around. This is transient and corresponding" + + " work will eventually be retried."); + } + + throw new WindmillStreamShutdownException("Stream was shutdown while waiting for send."); + } + } + + private ImmutableList createStreamCancelledErrorMessages() { + return requests.stream() + .flatMap( + request -> { + switch (request.getDataRequest().getKind()) { + case GLOBAL: + return Stream.of("GetSideInput=" + request.getDataRequest().global()); + case COMPUTATION: + return request.getDataRequest().computation().getRequestsList().stream() + .map( + keyedRequest -> + "KeyedGetState=[" + + "shardingKey=" + + debugFormat(keyedRequest.getShardingKey()) + + "cacheToken=" + + debugFormat(keyedRequest.getCacheToken()) + + "workToken" + + debugFormat(keyedRequest.getWorkToken()) + + "]"); + default: + // Will never happen switch is exhaustive. 
+ throw new IllegalStateException(); + } + }) + .limit(STREAM_CANCELLED_ERROR_LOG_LIMIT) + .collect(toImmutableList()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index a368f3fec235..fcfefab71c8c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GetWorkResponseChunkAssembler.AssembledWorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; @@ -36,11 +37,15 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; final class GrpcGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkStream.class); + private final GetWorkRequest request; private final WorkItemReceiver receiver; private final 
ThrottleTimer getWorkThrottleTimer; @@ -62,6 +67,7 @@ private GrpcGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { super( + LOG, "GetWorkStream", startGetWorkRpcFn, backoff, @@ -90,19 +96,16 @@ public static GrpcGetWorkStream create( int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { - GrpcGetWorkStream getWorkStream = - new GrpcGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - receiver); - getWorkStream.startStream(); - return getWorkStream; + return new GrpcGetWorkStream( + backendWorkerToken, + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + receiver); } private void sendRequestExtension(long moreItems, long moreBytes) { @@ -114,25 +117,27 @@ private void sendRequestExtension(long moreItems, long moreBytes) { .setMaxBytes(moreBytes)) .build(); - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + executeSafely( + () -> { + try { + trySend(extension); + } catch (WindmillStreamShutdownException e) { + // Stream was closed. 
+ } + }); } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { workItemAssemblers.clear(); inflightMessages.set(request.getMaxItems()); inflightBytes.set(request.getMaxBytes()); - send(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); + trySend(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); } + @Override + protected void shutdownInternal() {} + @Override protected boolean hasPendingRequests() { return false; @@ -147,8 +152,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() { - send( + public void sendHealthCheck() throws WindmillStreamShutdownException { + trySend( StreamingGetWorkRequest.newBuilder() .setRequestExtension( StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 44e21a9b18ed..4ce2f651f0b7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; @@ -47,9 +48,6 @@ public final class GrpcGetWorkerMetadataStream private final Consumer serverMappingConsumer; private final Object metadataLock; - @GuardedBy("metadataLock") - private long metadataVersion; - @GuardedBy("metadataLock") private WorkerMetadataResponse latestResponse; @@ -61,10 +59,10 @@ private GrpcGetWorkerMetadataStream( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - long metadataVersion, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingConsumer) { super( + LOG, "GetWorkerMetadataStream", startGetWorkerMetadataRpcFn, backoff, @@ -73,7 +71,6 @@ private GrpcGetWorkerMetadataStream( logEveryNStreamFailures, ""); this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build(); - this.metadataVersion = metadataVersion; this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer; this.serverMappingConsumer = serverMappingConsumer; this.latestResponse = WorkerMetadataResponse.getDefaultInstance(); @@ -88,23 +85,17 @@ public static GrpcGetWorkerMetadataStream create( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - int metadataVersion, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingUpdater) { - GrpcGetWorkerMetadataStream getWorkerMetadataStream = - new GrpcGetWorkerMetadataStream( - startGetWorkerMetadataRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - jobHeader, - metadataVersion, - getWorkerMetadataThrottleTimer, - serverMappingUpdater); - LOG.info("Started GetWorkerMetadataStream. 
{}", getWorkerMetadataStream); - getWorkerMetadataStream.startStream(); - return getWorkerMetadataStream; + return new GrpcGetWorkerMetadataStream( + startGetWorkerMetadataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + jobHeader, + getWorkerMetadataThrottleTimer, + serverMappingUpdater); } /** @@ -118,25 +109,23 @@ protected void onResponse(WorkerMetadataResponse response) { /** * Acquires the {@link #metadataLock} Returns {@link Optional} if the - * metadataVersion in the response is not stale (older or equal to {@link #metadataVersion}), else - * returns empty {@link Optional}. + * metadataVersion in the response is not stale (older or equal to current {@link + * WorkerMetadataResponse#getMetadataVersion()}), else returns empty {@link Optional}. */ private Optional extractWindmillEndpointsFrom( WorkerMetadataResponse response) { synchronized (metadataLock) { - if (response.getMetadataVersion() > this.metadataVersion) { - this.metadataVersion = response.getMetadataVersion(); + if (response.getMetadataVersion() > latestResponse.getMetadataVersion()) { this.latestResponse = response; return Optional.of(WindmillEndpoints.from(response)); } else { // If the currentMetadataVersion is greater than or equal to one in the response, the // response data is stale, and we do not want to do anything. - LOG.info( - "Received WorkerMetadataResponse={}; Received metadata version={}; Current metadata version={}. " + LOG.debug( + "Received metadata version={}; Current metadata version={}. 
" + "Skipping update because received stale metadata", - response, response.getMetadataVersion(), - this.metadataVersion); + latestResponse.getMetadataVersion()); } } @@ -144,10 +133,13 @@ private Optional extractWindmillEndpointsFrom( } @Override - protected synchronized void onNewStream() { - send(workerMetadataRequest); + protected void onNewStream() throws WindmillStreamShutdownException { + trySend(workerMetadataRequest); } + @Override + protected void shutdownInternal() {} + @Override protected boolean hasPendingRequests() { return false; @@ -159,16 +151,16 @@ protected void startThrottleTimer() { } @Override - protected void sendHealthCheck() { - send(HEALTH_CHECK_REQUEST); + protected void sendHealthCheck() throws WindmillStreamShutdownException { + trySend(HEALTH_CHECK_REQUEST); } @Override protected void appendSpecificHtml(PrintWriter writer) { synchronized (metadataLock) { writer.format( - "GetWorkerMetadataStream: version=[%d] , job_header=[%s], latest_response=[%s]", - this.metadataVersion, workerMetadataRequest.getHeader(), this.latestResponse); + "GetWorkerMetadataStream: job_header=[%s], current_metadata=[%s]", + workerMetadataRequest.getHeader(), latestResponse); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index 310495982679..f35b9b23d091 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -290,13 +290,13 @@ private ResponseT callWithBackoff(Supplier function) { e.getStatus()); } if (!BackOffUtils.next(Sleeper.DEFAULT, backoff)) { - throw new 
RpcException(e); + throw new WindmillRpcException(e); } } catch (IOException | InterruptedException i) { if (i instanceof InterruptedException) { Thread.currentThread().interrupt(); } - RpcException rpcException = new RpcException(e); + WindmillRpcException rpcException = new WindmillRpcException(e); rpcException.addSuppressed(i); throw rpcException; } @@ -310,7 +310,7 @@ public GetWorkResponse getWork(GetWorkRequest request) { return callWithBackoff(() -> syncApplianceStub.getWork(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("GetWork")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("GetWork")); } @Override @@ -319,7 +319,7 @@ public GetDataResponse getData(GetDataRequest request) { return callWithBackoff(() -> syncApplianceStub.getData(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("GetData")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("GetData")); } @Override @@ -327,32 +327,53 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) { if (syncApplianceStub != null) { return callWithBackoff(() -> syncApplianceStub.commitWork(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); } + /** + * @implNote Returns a {@link GetWorkStream} in the started state (w/ the initial header already + * sent). 
+ */ @Override public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { - return windmillStreamFactory.createGetWorkStream( - dispatcherClient.getWindmillServiceStub(), - GetWorkRequest.newBuilder(request) - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build(), - throttleTimers.getWorkThrottleTimer(), - receiver); + GetWorkStream getWorkStream = + windmillStreamFactory.createGetWorkStream( + dispatcherClient.getWindmillServiceStub(), + GetWorkRequest.newBuilder(request) + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .build(), + throttleTimers.getWorkThrottleTimer(), + receiver); + getWorkStream.start(); + return getWorkStream; } + /** + * @implNote Returns a {@link GetDataStream} in the started state (w/ the initial header already + * sent). + */ @Override public GetDataStream getDataStream() { - return windmillStreamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer()); + GetDataStream getDataStream = + windmillStreamFactory.createGetDataStream( + dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer()); + getDataStream.start(); + return getDataStream; } + /** + * @implNote Returns a {@link CommitWorkStream} in the started state (w/ the initial header + * already sent). 
+ */ @Override public CommitWorkStream commitWorkStream() { - return windmillStreamFactory.createCommitWorkStream( - dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); + CommitWorkStream commitWorkStream = + windmillStreamFactory.createCommitWorkStream( + dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); + commitWorkStream.start(); + return commitWorkStream; } @Override @@ -361,7 +382,7 @@ public GetConfigResponse getConfig(GetConfigRequest request) { return callWithBackoff(() -> syncApplianceStub.getConfig(request)); } - throw new RpcException( + throw new WindmillRpcException( new UnsupportedOperationException("GetConfig not supported in Streaming Engine.")); } @@ -371,7 +392,7 @@ public ReportStatsResponse reportStats(ReportStatsRequest request) { return callWithBackoff(() -> syncApplianceStub.reportStats(request)); } - throw new RpcException( + throw new WindmillRpcException( new UnsupportedOperationException("ReportStats not supported in Streaming Engine.")); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 9e6a02d135e2..df69af207899 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -17,10 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; +import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; import com.google.auto.value.AutoBuilder; import java.io.PrintWriter; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.Timer; @@ -29,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; @@ -55,7 +57,9 @@ import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.AbstractStub; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Duration; import org.joda.time.Instant; @@ -66,6 +70,8 @@ @ThreadSafe @Internal public class GrpcWindmillStreamFactory implements StatusDataProvider { + + private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30); private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1; @@ -73,6 +79,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider { private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1; private static final int NO_HEALTH_CHECKS = -1; private static final String NO_BACKEND_WORKER_TOKEN = ""; + private static final String DISPATCHER_DEBUG_NAME = "Dispatcher"; private final JobHeader jobHeader; private final int logEveryNStreamFailures; @@ -173,8 +180,20 @@ public static 
GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { private static > T withDefaultDeadline(T stub) { // Deadlines are absolute points in time, so generate a new one everytime this function is // called. - return stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + return stub.withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + } + + private static void printSummaryHtmlForWorker( + String workerToken, Collection> streams, PrintWriter writer) { + writer.write( + "" + (workerToken.isEmpty() ? DISPATCHER_DEBUG_NAME : workerToken) + ""); + writer.write("
"); + streams.forEach( + stream -> { + stream.appendSummaryHtml(writer); + writer.write("
"); + }); + writer.write("
"); } public GetWorkStream createGetWorkStream( @@ -204,7 +223,7 @@ public GetWorkStream createDirectGetWorkStream( WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), - responseObserver -> withDefaultDeadline(connection.stub()).getWorkStream(responseObserver), + responseObserver -> connection.stub().getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), @@ -234,6 +253,23 @@ public GetDataStream createGetDataStream( processHeartbeatResponses); } + public GetDataStream createDirectGetDataStream( + WindmillConnection connection, ThrottleTimer getDataThrottleTimer) { + return GrpcGetDataStream.create( + connection.backendWorkerToken(), + responseObserver -> connection.stub().getDataStream(responseObserver), + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + streamIdGenerator, + streamingRpcBatchLimit, + sendKeyedGetDataRequests, + processHeartbeatResponses); + } + public CommitWorkStream createCommitWorkStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer commitWorkThrottleTimer) { return GrpcCommitWorkStream.create( @@ -249,18 +285,32 @@ public CommitWorkStream createCommitWorkStream( streamingRpcBatchLimit); } + public CommitWorkStream createDirectCommitWorkStream( + WindmillConnection connection, ThrottleTimer commitWorkThrottleTimer) { + return GrpcCommitWorkStream.create( + connection.backendWorkerToken(), + responseObserver -> connection.stub().commitWorkStream(responseObserver), + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + streamIdGenerator, + streamingRpcBatchLimit); + } + public GetWorkerMetadataStream createGetWorkerMetadataStream( - CloudWindmillMetadataServiceV1Alpha1Stub stub, + Supplier stub, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer onNewWindmillEndpoints) { return 
GrpcGetWorkerMetadataStream.create( - responseObserver -> withDefaultDeadline(stub).getWorkerMetadata(responseObserver), + responseObserver -> withDefaultDeadline(stub.get()).getWorkerMetadata(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, jobHeader, - 0, getWorkerMetadataThrottleTimer, onNewWindmillEndpoints); } @@ -273,10 +323,17 @@ private StreamObserverFactory newStreamObserverFactory() { @Override public void appendSummaryHtml(PrintWriter writer) { writer.write("Active Streams:
"); - for (AbstractWindmillStream stream : streamRegistry) { - stream.appendSummaryHtml(writer); - writer.write("
"); - } + streamRegistry.stream() + .collect( + toImmutableListMultimap( + AbstractWindmillStream::backendWorkerToken, Function.identity())) + .asMap() + .forEach((workerToken, streams) -> printSummaryHtmlForWorker(workerToken, streams, writer)); + } + + @VisibleForTesting + final ImmutableSet> streamRegistry() { + return ImmutableSet.copyOf(streamRegistry); } @Internal diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 9d57df1af317..8710d66d2c80 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -22,8 +22,10 @@ import java.util.concurrent.TimeoutException; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillRpcException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,27 +39,33 @@ * becomes ready. 
*/ @ThreadSafe -public final class DirectStreamObserver implements StreamObserver { +final class DirectStreamObserver implements TerminatingStreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); - private final Phaser phaser; + private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; + private final Phaser isReadyNotifier; + private final long deadlineSeconds; + private final int messagesBetweenIsReadyChecks; private final Object lock = new Object(); @GuardedBy("lock") private final CallStreamObserver outboundObserver; - private final long deadlineSeconds; - private final int messagesBetweenIsReadyChecks; + @GuardedBy("lock") + private boolean isOutboundObserverClosed = false; + + @GuardedBy("lock") + private boolean isUserClosed = false; @GuardedBy("lock") private int messagesSinceReady = 0; - public DirectStreamObserver( - Phaser phaser, + DirectStreamObserver( + Phaser isReadyNotifier, CallStreamObserver outboundObserver, long deadlineSeconds, int messagesBetweenIsReadyChecks) { - this.phaser = phaser; + this.isReadyNotifier = isReadyNotifier; this.outboundObserver = outboundObserver; this.deadlineSeconds = deadlineSeconds; // We always let the first message pass through without blocking because it is performed under @@ -66,6 +74,12 @@ public DirectStreamObserver( this.messagesBetweenIsReadyChecks = Math.max(1, messagesBetweenIsReadyChecks); } + /** + * @throws StreamObserverCancelledException if the StreamObserver was closed via {@link + * #onError(Throwable)}, {@link #onCompleted()}, or {@link #terminate(Throwable)} while + * waiting for {@code outboundObserver#isReady}. + * @throws WindmillRpcException if we time out for waiting for {@code outboundObserver#isReady}. 
+ */ @Override public void onNext(T value) { int awaitPhase = -1; @@ -74,6 +88,24 @@ public void onNext(T value) { while (true) { try { synchronized (lock) { + int currentPhase = isReadyNotifier.getPhase(); + // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted + // are synchronized after terminating the phaser if we observe that the phaser is not + // terminated the onNext calls below are guaranteed to not be called on a closed observer. + if (currentPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } + + // Closing is performed under "lock" after terminating, so if termination was not observed + // above, the observer should not be closed. + assert !isOutboundObserverClosed; + + // If we awaited previously and timed out, wait for the same phase. Otherwise we're + // careful to observe the phase before observing isReady. + if (awaitPhase < 0) { + awaitPhase = currentPhase; + } + // We only check isReady periodically to effectively allow for increasing the outbound // buffer periodically. This reduces the overhead of blocking while still restricting // memory because there is a limited # of streams, and we have a max messages size of 2MB. @@ -81,25 +113,40 @@ public void onNext(T value) { outboundObserver.onNext(value); return; } - // If we awaited previously and timed out, wait for the same phase. Otherwise we're - // careful to observe the phase before observing isReady. - if (awaitPhase < 0) { - awaitPhase = phaser.getPhase(); - } + if (outboundObserver.isReady()) { messagesSinceReady = 0; outboundObserver.onNext(value); return; } } + // A callback has been registered to advance the phaser whenever the observer // transitions to is ready. Since we are waiting for a phase observed before the // outboundObserver.isReady() returned false, we expect it to advance after the // channel has become ready. 
This doesn't always seem to be the case (despite // documentation stating otherwise) so we poll periodically and enforce an overall // timeout related to the stream deadline. - phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); + int nextPhase = + isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); + // If nextPhase is a value less than 0, the phaser has been terminated. + if (nextPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } + synchronized (lock) { + int currentPhase = isReadyNotifier.getPhase(); + // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted + // are synchronized after terminating the phaser if we observe that the phaser is not + // terminated the onNext calls below are guaranteed to not be called on a closed observer. + if (currentPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } + + // Closing is performed under "lock" after terminating, so if termination was not observed + // above, the observer should not be closed. 
+ assert !isOutboundObserverClosed; + messagesSinceReady = 0; outboundObserver.onNext(value); return; @@ -107,36 +154,78 @@ public void onNext(T value) { } catch (TimeoutException e) { totalSecondsWaited += waitSeconds; if (totalSecondsWaited > deadlineSeconds) { - LOG.error( - "Exceeded timeout waiting for the outboundObserver to become ready meaning " - + "that the stream deadline was not respected."); - throw new RuntimeException(e); + String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited); + LOG.error(errorMessage); + throw new WindmillRpcException(errorMessage, e); } - if (totalSecondsWaited > 30) { + + if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) { LOG.info( "Output channel stalled for {}s, outbound thread {}.", totalSecondsWaited, Thread.currentThread().getName()); } + waitSeconds = waitSeconds * 2; } catch (InterruptedException e) { Thread.currentThread().interrupt(); - throw new RuntimeException(e); + throw new StreamObserverCancelledException(e); } } } + /** @throws IllegalStateException if called multiple times or after {@link #onCompleted()}. */ @Override public void onError(Throwable t) { + isReadyNotifier.forceTermination(); synchronized (lock) { - outboundObserver.onError(t); + Preconditions.checkState(!isUserClosed); + isUserClosed = true; + if (!isOutboundObserverClosed) { + outboundObserver.onError(t); + isOutboundObserverClosed = true; + } } } + /** + * @throws IllegalStateException if called multiple times or after {@link #onError(Throwable)}. + */ @Override public void onCompleted() { + isReadyNotifier.forceTermination(); synchronized (lock) { - outboundObserver.onCompleted(); + Preconditions.checkState(!isUserClosed); + isUserClosed = true; + if (!isOutboundObserverClosed) { + outboundObserver.onCompleted(); + isOutboundObserverClosed = true; + } } } + + @Override + public void terminate(Throwable terminationException) { + // Free the blocked threads in onNext(). 
+ isReadyNotifier.forceTermination(); + synchronized (lock) { + if (!isOutboundObserverClosed) { + outboundObserver.onError(terminationException); + isOutboundObserverClosed = true; + } + } + } + + private String constructStreamCancelledErrorMessage(long totalSecondsWaited) { + return deadlineSeconds > 0 + ? "Waited " + + totalSecondsWaited + + "s which exceeds given deadline of " + + deadlineSeconds + + "s for the outboundObserver to become ready meaning " + + "that the stream deadline was not respected." + : "Output channel has been blocked for " + + totalSecondsWaited + + "s. Restarting stream internally."; + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java index 4ea209f31b1d..70fd3497a37f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java @@ -21,11 +21,15 @@ @Internal public final class StreamObserverCancelledException extends RuntimeException { - public StreamObserverCancelledException(Throwable cause) { + StreamObserverCancelledException(Throwable cause) { super(cause); } - public StreamObserverCancelledException(String message, Throwable cause) { + StreamObserverCancelledException(String message, Throwable cause) { super(message, cause); } + + StreamObserverCancelledException(String message) { + super(message); + } } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java index cb4415bdab18..01e854492bf9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java @@ -33,7 +33,7 @@ public static StreamObserverFactory direct( return new Direct(deadlineSeconds, messagesBetweenIsReadyChecks); } - public abstract StreamObserver from( + public abstract TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver responseObserver); @@ -47,7 +47,7 @@ private static class Direct extends StreamObserverFactory { } @Override - public StreamObserver from( + public TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver inboundObserver) { AdvancingPhaser phaser = new AdvancingPhaser(1); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java new file mode 100644 index 000000000000..5fb4f95e3e1e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; + +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; + +@Internal +public interface TerminatingStreamObserver extends StreamObserver { + + /** + * Terminates the StreamObserver. + * + * @implSpec Different then {@link #onError(Throwable)} and {@link #onCompleted()} which can only + * be called once during the lifetime of each {@link StreamObserver}, terminate() + * implementations are meant to be idempotent and can be called multiple times as well as + * being interleaved with other stream operations. 
+ */ + void terminate(Throwable terminationException); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java index 33a55d1927f8..ed5f2db7f480 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java @@ -20,8 +20,8 @@ import java.util.Objects; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.sdk.annotations.Internal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,7 +61,7 @@ public void sendHeartbeats(Heartbeats heartbeats) { Thread.currentThread().setName(originalThreadName + "-" + backendWorkerToken); } getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap()); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { LOG.warn( "Trying to refresh work w/ {} heartbeats on stream={} after work has moved off of worker." 
+ " heartbeats", diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java index 071bf7fa3d43..6d768e8a972c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java @@ -21,12 +21,14 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; /** Heartbeat requests and the work that was used to generate the heartbeat requests. 
*/ +@Internal @AutoValue -abstract class Heartbeats { +public abstract class Heartbeats { static Heartbeats.Builder builder() { return new AutoValue_Heartbeats.Builder(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 90ffb3d3fbcf..1da48bd2b7ce 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -236,6 +236,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} @@ -299,6 +302,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} @@ -380,6 +386,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index 0092fcc7bcd1..bba6cad5529a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -40,6 +40,8 @@ import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import 
org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; @@ -79,6 +81,7 @@ @RunWith(JUnit4.class) public class FanOutStreamingEngineWorkerHarnessTest { + private static final String CHANNEL_NAME = "FanOutStreamingEngineWorkerHarnessTest"; private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS = WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST, 443)); private static final ImmutableMap DEFAULT = @@ -105,9 +108,7 @@ public class FanOutStreamingEngineWorkerHarnessTest { spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); private final ChannelCachingStubFactory stubFactory = new FakeWindmillStubFactory( - () -> - grpcCleanup.register( - WindmillChannelFactory.inProcessChannel("StreamingEngineClientTest"))); + () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(CHANNEL_NAME))); private final GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class), @@ -148,7 +149,7 @@ public void setUp() throws IOException { stubFactory.shutdown(); fakeStreamingEngineServer = grpcCleanup.register( - InProcessServerBuilder.forName("StreamingEngineClientTest") + InProcessServerBuilder.forName(CHANNEL_NAME) .fallbackHandlerRegistry(serviceRegistry) .executor(Executors.newFixedThreadPool(1)) .build()); @@ -156,18 +157,18 @@ public void setUp() throws IOException { fakeStreamingEngineServer.start(); 
dispatcherClient.consumeWindmillDispatcherEndpoints( ImmutableSet.of( - HostAndPort.fromString( - new InProcessSocketAddress("StreamingEngineClientTest").toString()))); + HostAndPort.fromString(new InProcessSocketAddress(CHANNEL_NAME).toString()))); getWorkerMetadataReady = new CountDownLatch(1); fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); serviceRegistry.addService(fakeGetWorkerMetadataStub); + serviceRegistry.addService(new WindmillServiceFakeStub()); } @After public void cleanUp() { Preconditions.checkNotNull(fanOutStreamingEngineWorkProvider).shutdown(); - fakeStreamingEngineServer.shutdownNow(); stubFactory.shutdown(); + fakeStreamingEngineServer.shutdownNow(); } private FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( @@ -241,16 +242,15 @@ public void testStreamsStartCorrectly() throws InterruptedException { any(), eq(noOpProcessWorkItemFn())); - verify(streamFactory, times(2)).createGetDataStream(any(), any()); - verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); + verify(streamFactory, times(2)).createDirectGetDataStream(any(), any()); + verify(streamFactory, times(2)).createDirectCommitWorkStream(any(), any()); } @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { - int metadataCount = 2; TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(metadataCount)); + spy(new TestGetWorkBudgetDistributor(1)); fanOutStreamingEngineWorkProvider = newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), @@ -285,6 +285,8 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + getWorkBudgetDistributor.expectNumDistributions(1); 
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); @@ -342,6 +344,54 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } + private static class WindmillServiceFakeStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingGetDataRequest getDataRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingGetWorkRequest getWorkRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + } + private static class GetWorkerMetadataTestStub extends CloudWindmillMetadataServiceV1Alpha1Grpc .CloudWindmillMetadataServiceV1Alpha1ImplBase { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index 
32d1f5738086..aa2767d5472d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -17,13 +17,13 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; @@ -96,7 +96,7 @@ public void testStartStream_startsAllStreams() { newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); verify(streamFactory) .createDirectGetWorkStream( @@ -113,8 +113,8 @@ public void testStartStream_startsAllStreams() { any(), eq(workItemScheduler)); - verify(streamFactory).createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); - verify(streamFactory).createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + verify(streamFactory).createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); + verify(streamFactory).createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -126,9 +126,9 @@ public void testStartStream_onlyStartsStreamsOnce() { newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - windmillStreamSender.startStreams(); - windmillStreamSender.startStreams(); - windmillStreamSender.startStreams(); + 
windmillStreamSender.start(); + windmillStreamSender.start(); + windmillStreamSender.start(); verify(streamFactory, times(1)) .createDirectGetWorkStream( @@ -146,9 +146,9 @@ public void testStartStream_onlyStartsStreamsOnce() { eq(workItemScheduler)); verify(streamFactory, times(1)) - .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); verify(streamFactory, times(1)) - .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -160,10 +160,10 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - Thread startStreamThread = new Thread(windmillStreamSender::startStreams); + Thread startStreamThread = new Thread(windmillStreamSender::start); startStreamThread.start(); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); startStreamThread.join(); @@ -183,23 +183,52 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted eq(workItemScheduler)); verify(streamFactory, times(1)) - .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); verify(streamFactory, times(1)) - .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test - public void testCloseAllStreams_doesNotCloseUnstartedStreams() { + public void testCloseAllStreams_closesAllStreams() { + long itemBudget = 1L; + long byteBudget = 1L; + GetWorkRequest getWorkRequestWithBudget = + GET_WORK_REQUEST.toBuilder().setMaxItems(itemBudget).setMaxBytes(byteBudget).build(); + GrpcWindmillStreamFactory mockStreamFactory = mock(GrpcWindmillStreamFactory.class); + GetWorkStream 
mockGetWorkStream = mock(GetWorkStream.class); + GetDataStream mockGetDataStream = mock(GetDataStream.class); + CommitWorkStream mockCommitWorkStream = mock(CommitWorkStream.class); + + when(mockStreamFactory.createDirectGetWorkStream( + eq(connection), + eq(getWorkRequestWithBudget), + any(ThrottleTimer.class), + any(), + any(), + any(), + eq(workItemScheduler))) + .thenReturn(mockGetWorkStream); + + when(mockStreamFactory.createDirectGetDataStream(eq(connection), any(ThrottleTimer.class))) + .thenReturn(mockGetDataStream); + when(mockStreamFactory.createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class))) + .thenReturn(mockCommitWorkStream); + WindmillStreamSender windmillStreamSender = - newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), + mockStreamFactory); + windmillStreamSender.start(); windmillStreamSender.close(); - verifyNoInteractions(streamFactory); + verify(mockGetWorkStream).shutdown(); + verify(mockGetDataStream).shutdown(); + verify(mockCommitWorkStream).shutdown(); } @Test - public void testCloseAllStreams_closesAllStreams() { + public void testCloseAllStreams_doesNotStartStreamsAfterClose() { long itemBudget = 1L; long byteBudget = 1L; GetWorkRequest getWorkRequestWithBudget = @@ -219,9 +248,9 @@ public void testCloseAllStreams_closesAllStreams() { eq(workItemScheduler))) .thenReturn(mockGetWorkStream); - when(mockStreamFactory.createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class))) + when(mockStreamFactory.createDirectGetDataStream(eq(connection), any(ThrottleTimer.class))) .thenReturn(mockGetDataStream); - when(mockStreamFactory.createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class))) + when(mockStreamFactory.createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class))) .thenReturn(mockCommitWorkStream); WindmillStreamSender windmillStreamSender = @@ -229,14 
+258,30 @@ public void testCloseAllStreams_closesAllStreams() { GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), mockStreamFactory); - windmillStreamSender.startStreams(); windmillStreamSender.close(); + verify(mockGetWorkStream, times(0)).start(); + verify(mockGetDataStream, times(0)).start(); + verify(mockCommitWorkStream, times(0)).start(); + verify(mockGetWorkStream).shutdown(); verify(mockGetDataStream).shutdown(); verify(mockCommitWorkStream).shutdown(); } + @Test + public void testStartStream_afterCloseThrows() { + long itemBudget = 1L; + long byteBudget = 1L; + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); + + windmillStreamSender.close(); + assertThrows(IllegalStateException.class, windmillStreamSender::start); + } + private WindmillStreamSender newWindmillStreamSender(GetWorkBudget budget) { return newWindmillStreamSender(budget, streamFactory); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java new file mode 100644 index 000000000000..05fbc6f969df --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import java.io.PrintWriter; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class AbstractWindmillStreamTest { + private static final long DEADLINE_SECONDS = 10; + private final Set> streamRegistry = ConcurrentHashMap.newKeySet(); + private final StreamObserverFactory streamObserverFactory = + StreamObserverFactory.direct(DEADLINE_SECONDS, 1); + + @Before + public void setUp() { + streamRegistry.clear(); + } + + private TestStream newStream( 
+ Function, StreamObserver> clientFactory) { + return new TestStream(clientFactory, streamRegistry, streamObserverFactory); + } + + @Test + public void testShutdown_notBlockedBySend() throws InterruptedException, ExecutionException { + CountDownLatch sendBlocker = new CountDownLatch(1); + Function, StreamObserver> clientFactory = + ignored -> + new CallStreamObserver() { + @Override + public void onNext(Integer integer) { + try { + sendBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setOnReadyHandler(Runnable runnable) {} + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int i) {} + + @Override + public void setMessageCompression(boolean b) {} + }; + + TestStream testStream = newStream(clientFactory); + testStream.start(); + ExecutorService sendExecutor = Executors.newSingleThreadExecutor(); + Future sendFuture = + sendExecutor.submit( + () -> + assertThrows(WindmillStreamShutdownException.class, () -> testStream.testSend(1))); + testStream.shutdown(); + + // Sleep a bit to give sendExecutor time to execute the send(). 
+ Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + + sendBlocker.countDown(); + assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class); + } + + private static class TestStream extends AbstractWindmillStream { + private final AtomicInteger numStarts = new AtomicInteger(); + + private TestStream( + Function, StreamObserver> clientFactory, + Set> streamRegistry, + StreamObserverFactory streamObserverFactory) { + super( + LoggerFactory.getLogger(AbstractWindmillStreamTest.class), + "Test", + clientFactory, + FluentBackoff.DEFAULT.backoff(), + streamObserverFactory, + streamRegistry, + 1, + "Test"); + } + + @Override + protected void onResponse(Integer response) {} + + @Override + protected void onNewStream() { + numStarts.incrementAndGet(); + } + + @Override + protected boolean hasPendingRequests() { + return false; + } + + @Override + protected void startThrottleTimer() {} + + public void testSend(Integer i) + throws ResettableThrowingStreamObserver.StreamClosedException, + WindmillStreamShutdownException { + trySend(i); + } + + @Override + protected void sendHealthCheck() {} + + @Override + protected void appendSpecificHtml(PrintWriter writer) {} + + @Override + protected void shutdownInternal() {} + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java new file mode 100644 index 000000000000..790c155d94d6 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class ResettableThrowingStreamObserverTest { + private final TerminatingStreamObserver delegate = newDelegate(); + + private static TerminatingStreamObserver newDelegate() { + return spy( + new TerminatingStreamObserver() { + @Override + public void onNext(Integer integer) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public void terminate(Throwable terminationException) {} + }); + } + + @Test + public void 
testPoison_beforeDelegateSet() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + verifyNoInteractions(delegate); + } + + @Test + public void testPoison_afterDelegateSet() throws WindmillStreamShutdownException { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.reset(); + observer.poison(); + verify(delegate).terminate(isA(WindmillStreamShutdownException.class)); + } + + @Test + public void testReset_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::reset); + } + + @Test + public void testOnNext_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, () -> observer.onNext(1)); + } + + @Test + public void testOnError_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows( + WindmillStreamShutdownException.class, + () -> observer.onError(new RuntimeException("something bad happened."))); + } + + @Test + public void testOnCompleted_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::onCompleted); + } + + @Test + public void testReset_usesNewDelegate() + throws WindmillStreamShutdownException, + ResettableThrowingStreamObserver.StreamClosedException { + List> delegates = new ArrayList<>(); + ResettableThrowingStreamObserver observer = + newStreamObserver( + () -> { + TerminatingStreamObserver delegate = newDelegate(); + delegates.add(delegate); + return delegate; + }); + observer.reset(); + observer.onNext(1); + observer.reset(); + observer.onNext(2); + + StreamObserver firstObserver = 
delegates.get(0); + StreamObserver secondObserver = delegates.get(1); + + verify(firstObserver).onNext(eq(1)); + verify(secondObserver).onNext(eq(2)); + } + + private ResettableThrowingStreamObserver newStreamObserver( + Supplier> delegate) { + return new ResettableThrowingStreamObserver<>(delegate, LoggerFactory.getLogger(getClass())); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java new file mode 100644 index 000000000000..564b2e664505 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.function.Supplier; +import org.joda.time.DateTime; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StreamDebugMetricsTest { + + @Test + public void testSummaryMetrics_noRestarts() throws InterruptedException { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.recordStart(); + streamDebugMetrics.recordSend(); + streamDebugMetrics.recordResponse(); + Thread.sleep(1000); + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertFalse(metricsSnapshot.shutdownTime().isPresent()); + assertTrue(metricsSnapshot.sleepLeft() <= 0); + assertThat(metricsSnapshot.streamAge()).isGreaterThan(0); + assertThat(metricsSnapshot.timeSinceLastSend()).isGreaterThan(0); + assertThat(metricsSnapshot.timeSinceLastResponse()).isGreaterThan(0); + assertFalse(metricsSnapshot.restartMetrics().isPresent()); + + streamDebugMetrics.recordShutdown(); + StreamDebugMetrics.Snapshot metricsSnapshotAfterShutdown = + streamDebugMetrics.getSummaryMetrics(); + assertTrue(metricsSnapshotAfterShutdown.shutdownTime().isPresent()); + } + + @Test + public void testSummaryMetrics_sleep() { + long sleepMs = 100; + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + streamDebugMetrics.recordSleep(sleepMs); + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertEquals(sleepMs, metricsSnapshot.sleepLeft()); + } + + @Test + public void 
testSummaryMetrics_withRestarts() { + String restartReason = "something bad happened"; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.incrementAndGetErrors(); + streamDebugMetrics.incrementAndGetRestarts(); + streamDebugMetrics.recordRestartReason(restartReason); + + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertTrue(metricsSnapshot.restartMetrics().isPresent()); + StreamDebugMetrics.RestartMetrics restartMetrics = metricsSnapshot.restartMetrics().get(); + assertThat(restartMetrics.lastRestartReason()).isEqualTo(restartReason); + assertThat(restartMetrics.restartCount()).isEqualTo(1); + assertThat(restartMetrics.errorCount()).isEqualTo(1); + assertTrue(restartMetrics.lastRestartTime().isPresent()); + assertThat(restartMetrics.lastRestartTime().get()).isLessThan(DateTime.now()); + assertThat(restartMetrics.lastRestartTime().get().toInstant()).isGreaterThan(Instant.EPOCH); + } + + @Test + public void testResponseDebugString_neverReceivedResponse() { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + assertFalse(streamDebugMetrics.responseDebugString(Instant.now().getMillis()).isPresent()); + } + + @Test + public void testResponseDebugString_withResponse() { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.recordResponse(); + assertTrue(streamDebugMetrics.responseDebugString(Instant.now().getMillis()).isPresent()); + } + + @Test + public void testGetStartTime() { + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + assertEquals(0, streamDebugMetrics.getStartTimeMs()); + streamDebugMetrics.recordStart(); + assertThat(streamDebugMetrics.getStartTimeMs()).isEqualTo(aLongTimeAgo.getMillis()); + } + + @Test + public void 
testGetLastSendTime() { + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + assertEquals(0, streamDebugMetrics.getLastSendTimeMs()); + streamDebugMetrics.recordSend(); + assertThat(streamDebugMetrics.getLastSendTimeMs()).isEqualTo(aLongTimeAgo.getMillis()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java index bdad382c9af2..fdd213223987 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java @@ -260,5 +260,8 @@ public String backendWorkerToken() { public void shutdown() { halfClose(); } + + @Override + public void start() {} } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 546a2883e3b2..8bb057156d20 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; @@ -60,13 +61,15 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class StreamingEngineWorkCommitterTest { - + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private WorkCommitter workCommitter; private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; @@ -261,6 +264,10 @@ public void testStop_drainsCommitQueue() { Supplier fakeCommitWorkStream = () -> new CommitWorkStream() { + + @Override + public void start() {} + @Override public RequestBatcher batcher() { return new RequestBatcher() { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java new file mode 100644 index 000000000000..7de824b86fd2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.spy; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; + +@RunWith(JUnit4.class) +public class GrpcCommitWorkStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String COMPUTATION_ID = "computationId"; + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + + private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(value) + .setWorkToken(value) + .setCacheToken(value) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + 
grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) { + serviceRegistry.addService(testStub); + GrpcCommitWorkStream commitWorkStream = + (GrpcCommitWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createCommitWorkStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer()); + commitWorkStream.start(); + return commitWorkStream; + } + + @Test + public void testShutdown_abortsQueuedCommits() throws InterruptedException { + int numCommits = 5; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + Set onDone = new HashSet<>(); + + TestCommitWorkStreamRequestObserver requestObserver = + spy(new TestCommitWorkStreamRequestObserver()); + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + InOrder requestObserverVerifier = inOrder(requestObserver); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> { + onDone.add(commitStatus); + commitProcessed.countDown(); + }); + } + } catch (StreamObserverCancelledException ignored) { + } + + // Verify that we sent the commits above in a request + the initial header. + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> !request.getCommitChunkList().isEmpty())); + requestObserverVerifier.verifyNoMoreInteractions(); + + // We won't get responses so we will have some pending requests. 
+ assertTrue(commitWorkStream.hasPendingRequests()); + commitWorkStream.shutdown(); + commitProcessed.await(); + + assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED); + } + + @Test + public void testCommitWorkItem_afterShutdown() { + int numCommits = 5; + + CommitWorkStreamTestStub testStub = + new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver()); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {})); + } + } + commitWorkStream.shutdown(); + + AtomicReference commitStatus = new AtomicReference<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue( + batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatus::set)); + } + } + + assertThat(commitStatus.get()).isEqualTo(Windmill.CommitStatus.ABORTED); + } + + @Test + public void testSend_notCalledAfterShutdown() { + int numCommits = 5; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + + TestCommitWorkStreamRequestObserver requestObserver = + spy(new TestCommitWorkStreamRequestObserver()); + InOrder requestObserverVerifier = inOrder(requestObserver); + + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> commitProcessed.countDown())); + } + // Shutdown the stream before we exit the try-with-resources block which will try to send() + // the batched request. 
+ commitWorkStream.shutdown(); + } + + // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends + // the header, which happens before we shutdown. + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); + requestObserverVerifier.verify(requestObserver).onError(any()); + requestObserverVerifier.verifyNoMoreInteractions(); + } + + private static class TestCommitWorkStreamRequestObserver + implements StreamObserver { + private @Nullable StreamObserver responseObserver; + + @Override + public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } + } + + private static class CommitWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestCommitWorkStreamRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private CommitWorkStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ((ServerCallStreamObserver) responseObserver) + .setOnCancelHandler(() -> {}); + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java index fd2b30238836..6584ed1c5ae6 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -117,8 +117,8 @@ public void setUp() throws IOException { @After public void cleanUp() { - inProcessChannel.shutdownNow(); checkNotNull(stream).shutdown(); + inProcessChannel.shutdownNow(); } private GrpcDirectGetWorkStream createGetWorkStream( @@ -127,26 +127,29 @@ private GrpcDirectGetWorkStream createGetWorkStream( ThrottleTimer throttleTimer, WorkItemScheduler workItemScheduler) { serviceRegistry.addService(testStub); - return (GrpcDirectGetWorkStream) - GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) - .build() - .createDirectGetWorkStream( - WindmillConnection.builder() - .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) - .build(), - Windmill.GetWorkRequest.newBuilder() - .setClientId(TEST_JOB_HEADER.getClientId()) - .setJobId(TEST_JOB_HEADER.getJobId()) - .setProjectId(TEST_JOB_HEADER.getProjectId()) - .setWorkerId(TEST_JOB_HEADER.getWorkerId()) - .setMaxItems(initialGetWorkBudget.items()) - .setMaxBytes(initialGetWorkBudget.bytes()) - .build(), - throttleTimer, - mock(HeartbeatSender.class), - mock(GetDataClient.class), - mock(WorkCommitter.class), - workItemScheduler); + GrpcDirectGetWorkStream getWorkStream = + (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + 
.setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + getWorkStream.start(); + return getWorkStream; } private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java new file mode 100644 index 000000000000..dc2dce7807a9 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcGetDataStreamRequestsTest { + + @Test + public void testQueuedRequest_globalRequestsFirstComparator() { + List requests = new ArrayList<>(); + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(1L) + .setShardingKey(1L) + .setWorkToken(1L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 1, "computation1", keyedGetDataRequest1)); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(2L) + .setShardingKey(2L) + .setWorkToken(2L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 2, "computation2", keyedGetDataRequest2)); + + Windmill.GlobalDataRequest globalDataRequest = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag("globalData") + .setVersion(ByteString.EMPTY) + .build()) + .setComputationId("computation1") + .build(); + 
requests.add(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + + requests.sort(GrpcGetDataStreamRequests.QueuedRequest.globalRequestsFirst()); + + // First one should be the global request. + assertTrue(requests.get(0).getDataRequest().isGlobal()); + } + + @Test + public void testQueuedBatch_asGetDataRequest() { + GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new GrpcGetDataStreamRequests.QueuedBatch(); + + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(1L) + .setShardingKey(1L) + .setWorkToken(1L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 1, "computation1", keyedGetDataRequest1)); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(2L) + .setShardingKey(2L) + .setWorkToken(2L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 2, "computation2", keyedGetDataRequest2)); + + Windmill.GlobalDataRequest globalDataRequest = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag("globalData") + .setVersion(ByteString.EMPTY) + .build()) + .setComputationId("computation1") + .build(); + queuedBatch.addRequest(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + + Windmill.StreamingGetDataRequest getDataRequest = queuedBatch.asGetDataRequest(); + + assertThat(getDataRequest.getRequestIdCount()).isEqualTo(3); + assertThat(getDataRequest.getGlobalDataRequestList()).containsExactly(globalDataRequest); + assertThat(getDataRequest.getStateRequestList()) + .containsExactly( + Windmill.ComputationGetDataRequest.newBuilder() + .setComputationId("computation1") + .addRequests(keyedGetDataRequest1) + .build(), + 
Windmill.ComputationGetDataRequest.newBuilder() + .setComputationId("computation2") + .addRequests(keyedGetDataRequest2) + .build()); + } + + @Test + public void testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOnWaiters() { + GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new GrpcGetDataStreamRequests.QueuedBatch(); + CompletableFuture waitFuture = + CompletableFuture.supplyAsync( + () -> + assertThrows( + WindmillStreamShutdownException.class, + queuedBatch::waitForSendOrFailNotification)); + // Wait a few seconds for the above future to get scheduled and run. + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + queuedBatch.notifyFailed(); + waitFuture.join(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java new file mode 100644 index 000000000000..3125def64b32 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcGetDataStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetDataStreamTest"; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { + serviceRegistry.addService(testStub); + GrpcGetDataStream getDataStream = + (GrpcGetDataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setSendKeyedGetDataRequests(false) + .build() + .createGetDataStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer()); + getDataStream.start(); + return getDataStream; + } + + @Test + public void testRequestKeyedData() { + GetDataStreamTestStub testStub = + new GetDataStreamTestStub(new 
TestGetDataStreamRequestObserver()); + GrpcGetDataStream getDataStream = createGetDataStream(testStub); + // These will block until they are successfully sent. + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setCacheToken(1) + .setWorkToken(1) + .build()); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + // Sleep a bit to allow future to run. + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + + Windmill.KeyedGetDataResponse response = + Windmill.KeyedGetDataResponse.newBuilder() + .setShardingKey(1) + .setKey(ByteString.EMPTY) + .build(); + + testStub.injectResponse( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(response.toByteString()) + .setRemainingBytesForResponse(0) + .build()); + + assertThat(sendFuture.join()).isEqualTo(response); + } + + @Test + public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException() { + GetDataStreamTestStub testStub = + new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); + GrpcGetDataStream getDataStream = createGetDataStream(testStub); + int numSendThreads = 5; + ExecutorService getDataStreamSenders = Executors.newFixedThreadPool(numSendThreads); + CountDownLatch waitForSendAttempt = new CountDownLatch(1); + // These will block until they are successfully sent. + List> sendFutures = + IntStream.range(0, 5) + .sequential() + .mapToObj( + i -> + (Runnable) + () -> { + // Prevent some threads from sending until we close the stream. 
+ if (i % 2 == 0) { + try { + waitForSendAttempt.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + try { + getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(i) + .setCacheToken(i) + .setWorkToken(i) + .build()); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }) + // Run the code above on multiple threads. + .map(runnable -> CompletableFuture.runAsync(runnable, getDataStreamSenders)) + .collect(Collectors.toList()); + + getDataStream.shutdown(); + + // Free up waiting threads so that they can try to send on a closed stream. + waitForSendAttempt.countDown(); + + for (int i = 0; i < numSendThreads; i++) { + CompletableFuture sendFuture = sendFutures.get(i); + try { + // Wait for future to complete. + sendFuture.join(); + } catch (Exception ignored) { + } + if (i % 2 == 0) { + assertTrue(sendFuture.isCompletedExceptionally()); + ExecutionException e = assertThrows(ExecutionException.class, sendFuture::get); + assertThat(e) + .hasCauseThat() + .hasCauseThat() + .isInstanceOf(WindmillStreamShutdownException.class); + } + } + } + + private static class TestGetDataStreamRequestObserver + implements StreamObserver { + private @Nullable StreamObserver responseObserver; + + @Override + public void onNext(Windmill.StreamingGetDataRequest streamingGetDataRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } + } + + private static class GetDataStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetDataStreamRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetDataStreamTestStub(TestGetDataStreamRequestObserver requestObserver) { + this.requestObserver = 
requestObserver; + } + + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ((ServerCallStreamObserver) responseObserver) + .setOnCancelHandler(() -> {}); + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetDataResponse getDataResponse) { + checkNotNull(responseObserver).onNext(getDataResponse); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 4439c409b32f..d74735ee3052 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; @@ -38,10 +37,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; @@ -82,28 +79,24 @@ public class GrpcGetWorkerMetadataStreamTest { private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final Set> streamRegistry = new HashSet<>(); + private final GrpcWindmillStreamFactory streamFactory = + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER).build(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private GrpcGetWorkerMetadataStream stream; private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( GetWorkerMetadataTestStub getWorkerMetadataTestStub, - int metadataVersion, Consumer endpointsConsumer) { serviceRegistry.addService(getWorkerMetadataTestStub); - return GrpcGetWorkerMetadataStream.create( - responseObserver -> - CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel) - .getWorkerMetadata(responseObserver), - FluentBackoff.DEFAULT.backoff(), - StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), - streamRegistry, - 1, // logEveryNStreamFailures - TEST_JOB_HEADER, - metadataVersion, - new ThrottleTimer(), - endpointsConsumer); + GrpcGetWorkerMetadataStream getWorkerMetadataStream = + (GrpcGetWorkerMetadataStream) + streamFactory.createGetWorkerMetadataStream( + () -> 
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer(), + endpointsConsumer); + getWorkerMetadataStream.start(); + return getWorkerMetadataStream; } @Before @@ -146,8 +139,7 @@ public void testGetWorkerMetadata() { new TestWindmillEndpointsConsumer(); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = -1; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(mockResponse); assertThat(testWindmillEndpointsConsumer.globalDataEndpoints.keySet()) @@ -175,8 +167,7 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = 0; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(initialResponse); List newDirectPathEndpoints = @@ -222,8 +213,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { Mockito.spy(new TestWindmillEndpointsConsumer()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = 0; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(freshEndpoints); List staleDirectPathEndpoints = @@ -252,7 +242,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new 
TestGetWorkMetadataRequestObserver()); - stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer()); + stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); testStub.injectWorkerMetadata( WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) @@ -260,17 +250,17 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS) .build()); - assertTrue(streamRegistry.contains(stream)); + assertTrue(streamFactory.streamRegistry().contains(stream)); stream.halfClose(); - assertFalse(streamRegistry.contains(stream)); + assertFalse(streamFactory.streamRegistry().contains(stream)); } @Test - public void testSendHealthCheck() { + public void testSendHealthCheck() throws WindmillStreamShutdownException { TestGetWorkMetadataRequestObserver requestObserver = Mockito.spy(new TestGetWorkMetadataRequestObserver()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver); - stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer()); + stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); stream.sendHealthCheck(); verify(requestObserver).onNext(WorkerMetadataRequest.getDefaultInstance()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 239e3979a3b7..a595524ca582 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -22,12 +22,14 @@ import 
java.io.InputStream; import java.io.SequenceInputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -71,6 +73,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactoryFactory; @@ -115,6 +118,7 @@ public class GrpcWindmillServerTest { private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); private static final int STREAM_CHUNK_SIZE = 2 << 20; private final long clientId = 10L; + private final Set openedChannels = new HashSet<>(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -131,16 +135,18 @@ public void setUp() throws Exception { @After public void tearDown() throws Exception { server.shutdownNow(); + openedChannels.forEach(ManagedChannel::shutdownNow); } private void startServerAndClient(List experiments) throws Exception { String name = "Fake server for " + getClass(); this.server = - InProcessServerBuilder.forName(name) - 
.fallbackHandlerRegistry(serviceRegistry) - .executor(Executors.newFixedThreadPool(1)) - .build() - .start(); + grpcCleanup.register( + InProcessServerBuilder.forName(name) + .fallbackHandlerRegistry(serviceRegistry) + .executor(Executors.newFixedThreadPool(1)) + .build() + .start()); this.client = GrpcWindmillServer.newTestInstance( @@ -149,7 +155,12 @@ private void startServerAndClient(List experiments) throws Exception { clientId, new FakeWindmillStubFactoryFactory( new FakeWindmillStubFactory( - () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name))))); + () -> { + ManagedChannel channel = + grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name)); + openedChannels.add(channel); + return channel; + }))); } private void maybeInjectError(Stream stream) { @@ -460,8 +471,9 @@ private void flushResponse() { "Sending batched response of {} ids", responseBuilder.getRequestIdCount()); try { responseObserver.onNext(responseBuilder.build()); - } catch (IllegalStateException e) { + } catch (Exception e) { // Stream is already closed. + LOG.warn(Arrays.toString(e.getStackTrace())); } responseBuilder.clear(); } @@ -480,16 +492,24 @@ private void flushResponse() { final String s = i % 5 == 0 ? 
largeString(i) : "tag"; executor.submit( () -> { - errorCollector.checkThat( - stream.requestKeyedData("computation", makeGetDataRequest(key, s)), - Matchers.equalTo(makeGetDataResponse(s))); + try { + errorCollector.checkThat( + stream.requestKeyedData("computation", makeGetDataRequest(key, s)), + Matchers.equalTo(makeGetDataResponse(s))); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } done.countDown(); }); executor.execute( () -> { - errorCollector.checkThat( - stream.requestGlobalData(makeGlobalDataRequest(key)), - Matchers.equalTo(makeGlobalDataResponse(key))); + try { + errorCollector.checkThat( + stream.requestGlobalData(makeGlobalDataRequest(key)), + Matchers.equalTo(makeGlobalDataResponse(key))); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } done.countDown(); }); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java new file mode 100644 index 000000000000..6bc713aa7747 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import org.apache.beam.sdk.fn.stream.AdvancingPhaser;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Unit tests for {@link DirectStreamObserver}.
 *
 * <p>Each test drives the observer from a dedicated single-thread executor so that the main test
 * thread can observe (and, for some tests, trigger) the blocking behavior of {@code onNext} while
 * the delegate stream reports it is not ready.
 */
@RunWith(JUnit4.class)
public class DirectStreamObserverTest {

  /** Messages pass straight through to the delegate, and onCompleted is forwarded exactly once. */
  @Test
  public void testOnNext_onCompleted() throws ExecutionException, InterruptedException {
    TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE));
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(
            new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE);
    ExecutorService onNextExecutor = Executors.newSingleThreadExecutor();
    Future<?> onNextFuture =
        onNextExecutor.submit(
            () -> {
              streamObserver.onNext(1);
              streamObserver.onNext(1);
              streamObserver.onNext(1);
            });

    // Wait for all of the onNext's to run.
    onNextFuture.get();

    verify(delegate, times(3)).onNext(eq(1));

    streamObserver.onCompleted();
    verify(delegate, times(1)).onCompleted();
    onNextExecutor.shutdownNow();
  }

  /** Messages pass straight through to the delegate, and onError forwards the same throwable. */
  @Test
  public void testOnNext_onError() throws ExecutionException, InterruptedException {
    TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE));
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(
            new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE);
    ExecutorService onNextExecutor = Executors.newSingleThreadExecutor();
    Future<?> onNextFuture =
        onNextExecutor.submit(
            () -> {
              streamObserver.onNext(1);
              streamObserver.onNext(1);
              streamObserver.onNext(1);
            });

    // Wait for all of the onNext's to run.
    onNextFuture.get();

    verify(delegate, times(3)).onNext(eq(1));

    RuntimeException error = new RuntimeException();
    streamObserver.onError(error);
    verify(delegate, times(1)).onError(same(error));
    onNextExecutor.shutdownNow();
  }

  /** A second onCompleted is a programming error and must throw rather than double-terminate. */
  @Test
  public void testOnCompleted_executedOnce() {
    TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE));
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1);

    streamObserver.onCompleted();
    assertThrows(IllegalStateException.class, streamObserver::onCompleted);
  }

  /** A second onError must throw, and the delegate only ever sees the first error. */
  @Test
  public void testOnError_executedOnce() {
    TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE));
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1);

    RuntimeException error = new RuntimeException();
    streamObserver.onError(error);
    assertThrows(IllegalStateException.class, () -> streamObserver.onError(error));
    verify(delegate, times(1)).onError(same(error));
  }

  /**
   * With isReady checked on every message (interval of 1), onNext blocks while the delegate is not
   * ready and resumes once it becomes ready.
   */
  @Test
  public void testOnNext_waitForReady() throws InterruptedException, ExecutionException {
    TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE));
    delegate.setIsReady(false);
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1);
    ExecutorService onNextExecutor = Executors.newSingleThreadExecutor();
    CountDownLatch blockLatch = new CountDownLatch(1);
    Future<@Nullable Object> onNextFuture =
        onNextExecutor.submit(
            () -> {
              // Won't block on the first one.
              streamObserver.onNext(1);
              try {
                // We will check isReady on the next message, will block here.
                streamObserver.onNext(1);
                streamObserver.onNext(1);
                blockLatch.countDown();
                return null;
              } catch (Throwable e) {
                return e;
              }
            });

    while (delegate.getNumIsReadyChecks() <= 1) {
      // Wait for isReady check to block.
      Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
    }

    delegate.setIsReady(true);
    // Bounded await so a failure in the sender thread fails the test instead of hanging it.
    assertTrue(blockLatch.await(10, TimeUnit.SECONDS));
    verify(delegate, times(3)).onNext(eq(1));
    assertNull(onNextFuture.get());

    streamObserver.onCompleted();
    verify(delegate, times(1)).onCompleted();
    onNextExecutor.shutdownNow();
  }

  /**
   * terminate() while a sender is blocked waiting for readiness cancels the blocked onNext with
   * {@link StreamObserverCancelledException} and forwards the termination cause to the delegate.
   */
  @Test
  public void testTerminate_waitingForReady() throws ExecutionException, InterruptedException {
    TestStreamObserver delegate = spy(new TestStreamObserver(2));
    delegate.setIsReady(false);
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1);
    ExecutorService onNextExecutor = Executors.newSingleThreadExecutor();
    CountDownLatch blockLatch = new CountDownLatch(1);
    Future<Throwable> onNextFuture =
        onNextExecutor.submit(
            () -> {
              // Won't block on the first one.
              streamObserver.onNext(1);
              blockLatch.countDown();
              try {
                // We will check isReady on the next message, will block here.
                streamObserver.onNext(1);
              } catch (Throwable e) {
                return e;
              }

              // Unreachable marker: the blocked onNext above is expected to throw.
              return new VerifyException();
            });
    RuntimeException terminationException = new RuntimeException("terminated");

    assertTrue(blockLatch.await(5, TimeUnit.SECONDS));
    streamObserver.terminate(terminationException);
    assertThat(onNextFuture.get()).isInstanceOf(StreamObserverCancelledException.class);
    verify(delegate).onError(same(terminationException));
    // onNext should only have been called once.
    verify(delegate, times(1)).onNext(any());
    onNextExecutor.shutdownNow();
  }

  /**
   * Interrupting a sender blocked waiting for readiness surfaces as a
   * {@link StreamObserverCancelledException} caused by {@link InterruptedException}.
   */
  @Test
  public void testOnNext_interruption() throws ExecutionException, InterruptedException {
    TestStreamObserver delegate = spy(new TestStreamObserver(2));
    delegate.setIsReady(false);
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1);
    ExecutorService onNextExecutor = Executors.newSingleThreadExecutor();
    CountDownLatch streamObserverExitLatch = new CountDownLatch(1);
    Future<StreamObserverCancelledException> onNextFuture =
        onNextExecutor.submit(
            () -> {
              // Won't block on the first one.
              streamObserver.onNext(1);
              // We will check isReady on the next message, will block here.
              StreamObserverCancelledException e =
                  assertThrows(
                      StreamObserverCancelledException.class, () -> streamObserver.onNext(1));
              streamObserverExitLatch.countDown();
              return e;
            });

    // Assert that onNextFuture is blocked.
    assertFalse(onNextFuture.isDone());
    assertThat(streamObserverExitLatch.getCount()).isEqualTo(1);

    // shutdownNow() interrupts the worker thread, which is the interruption under test.
    onNextExecutor.shutdownNow();
    assertTrue(streamObserverExitLatch.await(5, TimeUnit.SECONDS));
    assertThat(onNextFuture.get()).hasCauseThat().isInstanceOf(InterruptedException.class);

    // onNext should only have been called once.
    verify(delegate, times(1)).onNext(any());
  }

  /**
   * With a 1 second deadline, a sender blocked waiting for readiness times out with a
   * WindmillRpcException caused by {@link TimeoutException}.
   */
  @Test
  public void testOnNext_timeOut() throws ExecutionException, InterruptedException {
    TestStreamObserver delegate = spy(new TestStreamObserver(2));
    delegate.setIsReady(false);
    DirectStreamObserver<Integer> streamObserver =
        new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, 1, 1);
    ExecutorService onNextExecutor = Executors.newSingleThreadExecutor();
    CountDownLatch streamObserverExitLatch = new CountDownLatch(1);
    Future<WindmillServerStub.WindmillRpcException> onNextFuture =
        onNextExecutor.submit(
            () -> {
              // Won't block on the first one.
              streamObserver.onNext(1);
              // We will check isReady on the next message, will block here.
              WindmillServerStub.WindmillRpcException e =
                  assertThrows(
                      WindmillServerStub.WindmillRpcException.class,
                      () -> streamObserver.onNext(1));
              streamObserverExitLatch.countDown();
              return e;
            });

    // Assert that onNextFuture is blocked.
    assertFalse(onNextFuture.isDone());
    assertThat(streamObserverExitLatch.getCount()).isEqualTo(1);

    assertTrue(streamObserverExitLatch.await(10, TimeUnit.SECONDS));
    assertThat(onNextFuture.get()).hasCauseThat().isInstanceOf(TimeoutException.class);

    // onNext should only have been called once.
    verify(delegate, times(1)).onNext(any());
    onNextExecutor.shutdownNow();
  }

  /**
   * Controllable delegate: {@code isReady()} returns a settable flag (and counts calls), and
   * {@code onNext} parks on a latch once {@code blockAfter} messages have been seen. The latch is
   * never released, so tests that must not block pass {@code Integer.MAX_VALUE}.
   */
  private static class TestStreamObserver extends CallStreamObserver<Integer> {
    private final CountDownLatch sendBlocker;
    private final int blockAfter;
    private final AtomicInteger seen = new AtomicInteger(0);
    private final AtomicInteger numIsReadyChecks = new AtomicInteger(0);
    private volatile boolean isReady = false;

    private TestStreamObserver(int blockAfter) {
      this.blockAfter = blockAfter;
      this.sendBlocker = new CountDownLatch(1);
    }

    @Override
    public void onNext(Integer integer) {
      try {
        if (seen.incrementAndGet() == blockAfter) {
          sendBlocker.await();
        }
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public void onError(Throwable throwable) {}

    @Override
    public void onCompleted() {}

    @Override
    public boolean isReady() {
      numIsReadyChecks.incrementAndGet();
      return isReady;
    }

    public int getNumIsReadyChecks() {
      return numIsReadyChecks.get();
    }

    private void setIsReady(boolean isReadyOverride) {
      isReady = isReadyOverride;
    }

    @Override
    public void setOnReadyHandler(Runnable runnable) {}

    @Override
    public void disableAutoInboundFlowControl() {}

    @Override
    public void request(int i) {}

    @Override
    public void setMessageCompression(boolean b) {}
  }
}