diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java index 5eb691cbf55a..9130327a4969 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -125,6 +125,8 @@ public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownExce // If the delegate above was already terminated via onError or onComplete from another // thread. logger.warn("StreamObserver was previously cancelled.", e); + } catch (RuntimeException ignored) { + logger.warn("StreamObserver was unexpectedly cancelled.", e); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 19eb6dd4915a..b5b49c8ee976 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -187,9 +187,14 @@ protected void onResponse(StreamingGetDataResponse chunk) { onHeartbeatResponse(chunk.getComputationHeartbeatResponseList()); for (int i = 0; i < chunk.getRequestIdCount(); ++i) { - AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - synchronized (this) { - verify(responseStream != null || isShutdown, "No pending response stream"); + @Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); + if (responseStream == null) { + synchronized (this) { + // shutdown()/shutdownInternal() cleans up pending, else we expect a pending + // responseStream for every response. + verify(isShutdown, "No pending response stream"); + } + continue; } responseStream.append(chunk.getSerializedResponse(i).newInput()); if (chunk.getRemainingBytesForResponse() == 0) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 3a289e4dd48b..13a26959d2e1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; +import static org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.Verify.verify; + import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -39,27 +41,28 @@ * becomes ready. */ @ThreadSafe -public final class DirectStreamObserver implements TerminatingStreamObserver { +final class DirectStreamObserver implements TerminatingStreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; private final Phaser isReadyNotifier; - + private final long deadlineSeconds; + private final int messagesBetweenIsReadyChecks; private final Object lock = new Object(); @GuardedBy("lock") private final CallStreamObserver outboundObserver; - private final long deadlineSeconds; - private final int messagesBetweenIsReadyChecks; - @GuardedBy("lock") private boolean isClosed = false; + @GuardedBy("lock") + private boolean isUserClosed = false; + @GuardedBy("lock") private int messagesSinceReady = 0; - public DirectStreamObserver( + DirectStreamObserver( Phaser isReadyNotifier, CallStreamObserver outboundObserver, long deadlineSeconds, @@ -89,6 +92,9 @@ public void onNext(T value) throws StreamObserverCancelledException { throw new StreamObserverCancelledException("StreamObserver was terminated."); } + // We close under "lock", so this should never happen. + verify(!isClosed); + // If we awaited previously and timed out, wait for the same phase. Otherwise we're // careful to observe the phase before observing isReady. if (awaitPhase < 0) { @@ -131,6 +137,10 @@ public void onNext(T value) throws StreamObserverCancelledException { if (currentPhase < 0) { throw new StreamObserverCancelledException("StreamObserver was terminated."); } + + // We close under "lock", so this should never happen. + verify(!isClosed); + messagesSinceReady = 0; outboundObserver.onNext(value); return; @@ -162,8 +172,11 @@ public void onNext(T value) throws StreamObserverCancelledException { public void onError(Throwable t) { isReadyNotifier.forceTermination(); synchronized (lock) { - markClosedOrThrow(); - outboundObserver.onError(t); + if (!isClosed) { + Preconditions.checkState(!isUserClosed); + outboundObserver.onError(t); + isClosed = true; + } } } @@ -171,15 +184,11 @@ public void onError(Throwable t) { public void onCompleted() { isReadyNotifier.forceTermination(); synchronized (lock) { - markClosedOrThrow(); - outboundObserver.onCompleted(); - } - } - - private void markClosedOrThrow() { - synchronized (lock) { - Preconditions.checkState(!isClosed); - isClosed = true; + if (!isClosed) { + Preconditions.checkState(!isUserClosed); + outboundObserver.onCompleted(); + isClosed = true; + } } } @@ -188,8 +197,9 @@ public void terminate(Throwable terminationException) { // Free the blocked threads in onNext(). isReadyNotifier.forceTermination(); synchronized (lock) { - if (!isClosed) { + if (!isUserClosed) { onError(terminationException); + isUserClosed = true; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java index 036158f5289e..05fbc6f969df 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -108,7 +108,7 @@ public void setMessageCompression(boolean b) {} testStream.shutdown(); // Sleep a bit to give sendExecutor time to execute the send(). - Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); sendBlocker.countDown(); assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 316ff76eb929..b7e8f50f9249 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -23,8 +23,6 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import java.io.IOException; import java.util.HashSet; @@ -53,12 +51,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.InOrder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; @RunWith(JUnit4.class) public class GrpcCommitWorkStreamTest { - private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStreamTest.class); private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; private static final Windmill.JobHeader TEST_JOB_HEADER = Windmill.JobHeader.newBuilder() @@ -126,6 +121,7 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { spy(new TestCommitWorkStreamRequestObserver()); CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + InOrder requestObserverVerifier = inOrder(requestObserver); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { batcher.commitWorkItem( @@ -140,21 +136,14 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { } // Verify that we sent the commits above in a request + the initial header. - verify(requestObserver, times(2)) - .onNext( - argThat( - request -> { - if (request.getHeader().equals(TEST_JOB_HEADER)) { - LOG.info("Header received."); - return true; - } else if (!request.getCommitChunkList().isEmpty()) { - LOG.info("Chunk received."); - return true; - } else { - LOG.error("Incorrect request."); - return false; - } - })); + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> !request.getCommitChunkList().isEmpty())); + requestObserverVerifier.verifyNoMoreInteractions(); + // We won't get responses so we will have some pending requests. assertTrue(commitWorkStream.hasPendingRequests()); commitWorkStream.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java index a6120f4052b3..dc2dce7807a9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -145,7 +145,7 @@ public void testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOn WindmillStreamShutdownException.class, queuedBatch::waitForSendOrFailNotification)); // Wait a few seconds for the above future to get scheduled and run. - Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); queuedBatch.notifyFailed(); waitFuture.join(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index 252a73c92319..e5e77e16abef 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -129,7 +129,7 @@ public void testRequestKeyedData() { }); // Sleep a bit to allow future to run. - Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); Windmill.KeyedGetDataResponse response = Windmill.KeyedGetDataResponse.newBuilder() diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 4f0552959ee1..a595524ca582 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -513,7 +513,7 @@ private void flushResponse() { done.countDown(); }); } - while (done.await(5, TimeUnit.SECONDS)) {} + done.await(); stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); executor.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java index 374c5aec3b5b..6a51ddc07d1a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -19,9 +19,11 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -39,6 +41,8 @@ import org.apache.beam.sdk.fn.stream.AdvancingPhaser; import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,10 +50,123 @@ @RunWith(JUnit4.class) public class DirectStreamObserverTest { + @Test + public void testOnNext_onCompleted() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>( + new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + Future onNextFuture = + onNextExecutor.submit( + () -> { + streamObserver.onNext(1); + streamObserver.onNext(1); + streamObserver.onNext(1); + }); + + // Wait for all of the onNext's to run. + onNextFuture.get(); + + verify(delegate, times(3)).onNext(eq(1)); + + streamObserver.onCompleted(); + verify(delegate, times(1)).onCompleted(); + } + + @Test + public void testOnNext_onError() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>( + new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + Future onNextFuture = + onNextExecutor.submit( + () -> { + streamObserver.onNext(1); + streamObserver.onNext(1); + streamObserver.onNext(1); + }); + + // Wait for all of the onNext's to run. + onNextFuture.get(); + + verify(delegate, times(3)).onNext(eq(1)); + + RuntimeException error = new RuntimeException(); + streamObserver.onError(error); + verify(delegate, times(1)).onError(same(error)); + } + + @Test + public void testOnCompleted_executedOnce() { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + + streamObserver.onCompleted(); + streamObserver.onCompleted(); + streamObserver.onCompleted(); + + verify(delegate, times(1)).onCompleted(); + } + + @Test + public void testOnError_executedOnce() { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + + RuntimeException error = new RuntimeException(); + streamObserver.onError(error); + streamObserver.onError(error); + streamObserver.onError(error); + + verify(delegate, times(1)).onError(same(error)); + } + + @Test + public void testOnNext_waitForReady() throws InterruptedException, ExecutionException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch blockLatch = new CountDownLatch(1); + Future<@Nullable Object> onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + try { + // We will check isReady on the next message, will block here. + streamObserver.onNext(1); + streamObserver.onNext(1); + blockLatch.countDown(); + return null; + } catch (Throwable e) { + return e; + } + }); + + while (delegate.getNumIsReadyChecks() <= 1) { + // Wait for isReady check to block. + Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS); + } + + delegate.setIsReady(true); + blockLatch.await(); + verify(delegate, times(3)).onNext(eq(1)); + assertNull(onNextFuture.get()); + + streamObserver.onCompleted(); + verify(delegate, times(1)).onCompleted(); + } + @Test public void testTerminate_waitingForReady() throws ExecutionException, InterruptedException { - CountDownLatch sendBlocker = new CountDownLatch(1); - TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + TestStreamObserver delegate = spy(new TestStreamObserver(2)); delegate.setIsReady(false); DirectStreamObserver streamObserver = new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); @@ -82,8 +199,7 @@ public void testTerminate_waitingForReady() throws ExecutionException, Interrupt @Test public void testOnNext_interruption() throws ExecutionException, InterruptedException { - CountDownLatch sendBlocker = new CountDownLatch(1); - TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + TestStreamObserver delegate = spy(new TestStreamObserver(2)); delegate.setIsReady(false); DirectStreamObserver streamObserver = new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); @@ -116,8 +232,7 @@ public void testOnNext_interruption() throws ExecutionException, InterruptedExce @Test public void testOnNext_timeOut() throws ExecutionException, InterruptedException { - CountDownLatch sendBlocker = new CountDownLatch(1); - TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + TestStreamObserver delegate = spy(new TestStreamObserver(2)); delegate.setIsReady(false); DirectStreamObserver streamObserver = new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, 1, 1); @@ -152,11 +267,12 @@ private static class TestStreamObserver extends CallStreamObserver { private final CountDownLatch sendBlocker; private final int blockAfter; private final AtomicInteger seen = new AtomicInteger(0); + private final AtomicInteger numIsReadyChecks = new AtomicInteger(0); private volatile boolean isReady = false; - private TestStreamObserver(CountDownLatch sendBlocker, int blockAfter) { + private TestStreamObserver(int blockAfter) { this.blockAfter = blockAfter; - this.sendBlocker = sendBlocker; + this.sendBlocker = new CountDownLatch(1); } @Override @@ -178,9 +294,14 @@ public void onCompleted() {} @Override public boolean isReady() { + numIsReadyChecks.incrementAndGet(); return isReady; } + public int getNumIsReadyChecks() { + return numIsReadyChecks.get(); + } + private void setIsReady(boolean isReadyOverride) { isReady = isReadyOverride; }