Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Nov 15, 2024
1 parent f75df0f commit 74e503d
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Optional;
Expand All @@ -35,6 +34,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
Expand Down Expand Up @@ -64,9 +64,9 @@
import org.apache.beam.sdk.util.MoreFutures;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
Expand Down Expand Up @@ -239,7 +239,7 @@ public synchronized void shutdown() {
Preconditions.checkNotNull(getWorkerMetadataStream).shutdown();
workerMetadataConsumer.shutdownNow();
// Close all the streams, blocking until this completes, so that resources are not leaked.
closeStreamsNotIn(WindmillEndpoints.none()).forEach(CompletableFuture::join);
closeStreamsNotIn(WindmillEndpoints.none()).join();
channelCachingStubFactory.shutdown();

try {
Expand Down Expand Up @@ -304,10 +304,9 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi

/** Close the streams that are no longer valid asynchronously. */
@CanIgnoreReturnValue
private ImmutableList<CompletableFuture<Void>> closeStreamsNotIn(
WindmillEndpoints newWindmillEndpoints) {
private CompletableFuture<Void> closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
StreamingEngineBackends currentBackends = backends.get();
List<CompletableFuture<Void>> closeStreamFutures =
Stream<CompletableFuture<Void>> closeStreamFutures =
currentBackends.windmillStreams().entrySet().stream()
.filter(
connectionAndStream ->
Expand All @@ -318,24 +317,21 @@ private ImmutableList<CompletableFuture<Void>> closeStreamsNotIn(
entry ->
CompletableFuture.runAsync(
() -> closeStreamSender(entry.getKey(), entry.getValue()),
windmillStreamManager))
.collect(Collectors.toList());
windmillStreamManager));

Set<Endpoint> newGlobalDataEndpoints =
new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
List<CompletableFuture<Void>> closeGlobalDataStreamFutures =
Stream<CompletableFuture<Void>> closeGlobalDataStreamFutures =
currentBackends.globalDataStreams().values().stream()
.filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
.map(
sender ->
CompletableFuture.runAsync(
() -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager))
.collect(Collectors.toList());
() -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager));

return ImmutableList.<CompletableFuture<Void>>builder()
.addAll(closeStreamFutures)
.addAll(closeGlobalDataStreamFutures)
.build();
return CompletableFuture.allOf(
Streams.concat(closeStreamFutures, closeGlobalDataStreamFutures)
.toArray(CompletableFuture[]::new));
}

private void closeStreamSender(Endpoint endpoint, StreamSender sender) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@
@AutoValue
public abstract class WindmillEndpoints {
private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class);
private static final WindmillEndpoints NO_ENDPOINTS =
WindmillEndpoints.builder()
.setVersion(Long.MAX_VALUE)
.setWindmillEndpoints(ImmutableSet.of())
.setGlobalDataEndpoints(ImmutableMap.of())
.build();

public static WindmillEndpoints none() {
return WindmillEndpoints.builder()
.setVersion(Long.MAX_VALUE)
.setWindmillEndpoints(ImmutableSet.of())
.setGlobalDataEndpoints(ImmutableMap.of())
.build();
return NO_ENDPOINTS;
}

public static WindmillEndpoints from(
Expand Down Expand Up @@ -132,10 +134,6 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
directEndpointAddress.getHostAddress(), (int) endpointProto.getPort()));
}

public final boolean isEmpty() {
return equals(none());
}

/** Version of the endpoints which increases with every modification. */
public abstract long version();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.runners.dataflow.worker.windmill.client.ResettableThrowingStreamObserver.StreamClosedException;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.sdk.util.BackOff;
Expand Down Expand Up @@ -159,7 +158,7 @@ protected final synchronized boolean trySend(RequestT request)
try {
requestObserver.onNext(request);
return true;
} catch (StreamClosedException e) {
} catch (ResettableThrowingStreamObserver.StreamClosedException e) {
// Stream was broken, requests may be retried when stream is reopened.
}

Expand Down Expand Up @@ -199,6 +198,7 @@ private void startStream() {
} catch (WindmillStreamShutdownException e) {
// shutdown() is responsible for cleaning up pending requests.
logger.debug("Stream was shutdown while creating new stream.", e);
break;
} catch (Exception e) {
logger.error("Failed to create new stream, retrying: ", e);
try {
Expand Down Expand Up @@ -298,10 +298,12 @@ public final synchronized void halfClose() {
clientClosed = true;
try {
requestObserver.onCompleted();
} catch (StreamClosedException e) {
} catch (ResettableThrowingStreamObserver.StreamClosedException e) {
logger.warn("Stream was previously closed.");
} catch (WindmillStreamShutdownException e) {
logger.warn("Stream was previously shutdown.");
} catch (IllegalStateException e) {
logger.warn("Unexpected error when trying to close stream", e);
}
}

Expand All @@ -320,7 +322,6 @@ public String backendWorkerToken() {
return backendWorkerToken;
}

@SuppressWarnings("GuardedBy")
@Override
public final void shutdown() {
// Don't lock on "this" before poisoning the request observer since otherwise the observer may
Expand All @@ -339,9 +340,7 @@ public final void shutdown() {

/** Returns true if the stream was torn down and should not be restarted internally. */
private synchronized boolean maybeTearDownStream() {
if (requestObserver.hasReceivedPoisonPill()
|| isShutdown
|| (clientClosed && !hasPendingRequests())) {
if (isShutdown || (clientClosed && !hasPendingRequests())) {
streamRegistry.remove(AbstractWindmillStream.this);
finishLatch.countDown();
executor.shutdownNow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,33 +104,35 @@ synchronized void poison() {
}
}

synchronized boolean hasReceivedPoisonPill() {
return isPoisoned;
}

public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownException {
// Make sure onNext and onError below are called on the same StreamObserver instance.
StreamObserver<T> delegate = delegate();
try {
// Do NOT lock while sending message over the stream as this will block other StreamObserver
// operations.
delegate.onNext(t);
} catch (StreamObserverCancelledException e) {
} catch (StreamObserverCancelledException cancellationException) {
synchronized (this) {
if (isPoisoned) {
logger.debug("Stream was shutdown during send.", e);
logger.debug("Stream was shutdown during send.", cancellationException);
return;
}
}

try {
delegate.onError(e);
} catch (IllegalStateException ignored) {
delegate.onError(cancellationException);
} catch (IllegalStateException onErrorException) {
// Thrown if the delegate above was already terminated via onError or onCompleted from another
// thread.
logger.warn("StreamObserver was previously cancelled.", e);
} catch (RuntimeException ignored) {
logger.warn("StreamObserver was unexpectedly cancelled.", e);
logger.warn(
"StreamObserver was already cancelled {} due to error.",
onErrorException,
cancellationException);
} catch (RuntimeException onErrorException) {
logger.warn(
"Encountered unexpected error {} when cancelling due to error.",
onErrorException,
cancellationException);
}
}
}
Expand All @@ -156,7 +158,7 @@ synchronized boolean isClosed() {
* {@link StreamObserver#onCompleted()}. The stream may perform
*/
static final class StreamClosedException extends Exception {
private StreamClosedException(String s) {
StreamClosedException(String s) {
super(s);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,4 @@ public final class WindmillStreamShutdownException extends Exception {
public WindmillStreamShutdownException(String message) {
super(message);
}

/** Returns whether an exception was caused by a {@link WindmillStreamShutdownException}. */
public static boolean isCauseOf(Throwable t) {
while (t != null) {
if (t instanceof WindmillStreamShutdownException) {
return true;
}
t = t.getCause();
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ final class DirectStreamObserver<T> implements TerminatingStreamObserver<T> {
private final CallStreamObserver<T> outboundObserver;

@GuardedBy("lock")
private boolean isClosed = false;
private boolean isOutboundObserverClosed = false;

@GuardedBy("lock")
private boolean isUserClosed = false;
Expand All @@ -74,8 +74,14 @@ final class DirectStreamObserver<T> implements TerminatingStreamObserver<T> {
this.messagesBetweenIsReadyChecks = Math.max(1, messagesBetweenIsReadyChecks);
}

/**
* @throws StreamObserverCancelledException if the StreamObserver was closed via {@link
* #onError(Throwable)}, {@link #onCompleted()}, or {@link #terminate(Throwable)} while
* waiting for {@code outboundObserver#isReady}.
* @throws WindmillRpcException if we time out waiting for {@code outboundObserver#isReady}.
*/
@Override
public void onNext(T value) throws StreamObserverCancelledException {
public void onNext(T value) {
int awaitPhase = -1;
long totalSecondsWaited = 0;
long waitSeconds = 1;
Expand All @@ -90,8 +96,9 @@ public void onNext(T value) throws StreamObserverCancelledException {
throw new StreamObserverCancelledException("StreamObserver was terminated.");
}

// We close under "lock", so this should never happen.
assert !isClosed;
// Closing is performed under "lock" after terminating, so if termination was not observed
// above, the observer should not be closed.
assert !isOutboundObserverClosed;

// If we awaited previously and timed out, wait for the same phase. Otherwise we're
// careful to observe the phase before observing isReady.
Expand Down Expand Up @@ -136,8 +143,9 @@ public void onNext(T value) throws StreamObserverCancelledException {
throw new StreamObserverCancelledException("StreamObserver was terminated.");
}

// We close under "lock", so this should never happen.
assert !isClosed;
// Closing is performed under "lock" after terminating, so if termination was not observed
// above, the observer should not be closed.
assert !isOutboundObserverClosed;

messagesSinceReady = 0;
outboundObserver.onNext(value);
Expand Down Expand Up @@ -166,26 +174,32 @@ public void onNext(T value) throws StreamObserverCancelledException {
}
}

/** @throws IllegalStateException if called multiple times or after {@link #onCompleted()}. */
@Override
public void onError(Throwable t) {
isReadyNotifier.forceTermination();
synchronized (lock) {
if (!isClosed) {
Preconditions.checkState(!isUserClosed);
Preconditions.checkState(!isUserClosed);
isUserClosed = true;
if (!isOutboundObserverClosed) {
outboundObserver.onError(t);
isClosed = true;
isOutboundObserverClosed = true;
}
}
}

/**
* @throws IllegalStateException if called multiple times or after {@link #onError(Throwable)}.
*/
@Override
public void onCompleted() {
isReadyNotifier.forceTermination();
synchronized (lock) {
if (!isClosed) {
Preconditions.checkState(!isUserClosed);
Preconditions.checkState(!isUserClosed);
isUserClosed = true;
if (!isOutboundObserverClosed) {
outboundObserver.onCompleted();
isClosed = true;
isOutboundObserverClosed = true;
}
}
}
Expand All @@ -195,9 +209,9 @@ public void terminate(Throwable terminationException) {
// Free the blocked threads in onNext().
isReadyNotifier.forceTermination();
synchronized (lock) {
if (!isUserClosed) {
onError(terminationException);
isUserClosed = true;
if (!isOutboundObserverClosed) {
outboundObserver.onError(terminationException);
isOutboundObserverClosed = true;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
Expand Down Expand Up @@ -167,15 +168,15 @@ public void testCommitWorkItem_afterShutdown() {
}
commitWorkStream.shutdown();

Set<Windmill.CommitStatus> commitStatuses = new HashSet<>();
AtomicReference<Windmill.CommitStatus> commitStatus = new AtomicReference<>();
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
for (int i = 0; i < numCommits; i++) {
assertTrue(
batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add));
batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatus::set));
}
}

assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED);
assertThat(commitStatus.get()).isEqualTo(Windmill.CommitStatus.ABORTED);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow
}
}
try {

getDataStream.requestKeyedData(
"computationId",
Windmill.KeyedGetDataRequest.newBuilder()
Expand Down
Loading

0 comments on commit 74e503d

Please sign in to comment.