Skip to content

Commit

Permalink
address PR comments around deadlocking, move WindmillStreamShutdownEx…
Browse files Browse the repository at this point in the history
…ception to its own top level class
  • Loading branch information
m-trieu committed Oct 22, 2024
1 parent 2426a6b commit 9679382
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
/**
* Used to guard {@link #start()} and {@link #shutdown()} behavior.
*
* @implNote Should not be held when performing IO.
* @implNote Do not hold when performing IO. If also locking on {@code this} in the same context,
* acquire shutdownLock first to prevent deadlocks.
*/
protected final Object shutdownLock = new Object();

Expand Down Expand Up @@ -184,15 +185,6 @@ protected boolean isShutdown() {
return isShutdown;
}

private StreamObserver<RequestT> requestObserver() {
if (requestObserver == null) {
throw new NullPointerException(
"requestObserver cannot be null. Missing a call to start() to initialize stream.");
}

return requestObserver;
}

/** Send a request to the server. */
protected final void send(RequestT request) {
synchronized (this) {
Expand Down Expand Up @@ -221,14 +213,17 @@ protected final void send(RequestT request) {

@Override
public final void start() {
boolean shouldStartStream = false;
synchronized (shutdownLock) {
if (!isShutdown && !started) {
// start() should only be executed once during the lifetime of the stream for idempotency
// and when shutdown() has not been called.
startStream();
started = true;
shouldStartStream = true;
}
}

if (shouldStartStream) {
startStream();
}
}

/** Starts the underlying stream. */
Expand Down Expand Up @@ -366,8 +361,8 @@ public final void shutdown() {
if (!isShutdown) {
isShutdown = true;
shutdownTime.set(DateTime.now());
requestObserver()
.onError(new WindmillStreamShutdownException("Explicit call to shutdown stream."));
requestObserver.onError(
new WindmillStreamShutdownException("Explicit call to shutdown stream."));
shutdownInternal();
}
}
Expand All @@ -380,12 +375,6 @@ private void recordRestartReason(String error) {

protected abstract void shutdownInternal();

public static class WindmillStreamShutdownException extends RuntimeException {
public WindmillStreamShutdownException(String message) {
super(message);
}
}

/**
* Request observer that allows resetting its internal delegate using the given {@link
* #requestObserverSupplier}.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.windmill.client;

/**
 * Thrown when operations are requested on a {@link WindmillStream} that has been
 * shutdown/closed.
 */
public final class WindmillStreamShutdownException extends RuntimeException {
/** Creates the exception with the given detail {@code message}. */
public WindmillStreamShutdownException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import java.util.function.Function;
import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
import org.apache.beam.sdk.annotations.Internal;

/** {@link GetDataClient} that fetches data directly from a specific {@link GetDataStream}. */
Expand Down Expand Up @@ -61,7 +61,7 @@ public Windmill.KeyedGetDataResponse getStateData(
String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException {
try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) {
return getDataStream.requestKeyedData(computationId, request);
} catch (AbstractWindmillStream.WindmillStreamShutdownException e) {
} catch (WindmillStreamShutdownException e) {
throw new WorkItemCancelledException(request.getShardingKey());
} catch (Exception e) {
throw new GetDataException(
Expand All @@ -86,7 +86,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request)
sideInputGetDataStreamFactory.apply(request.getDataId().getTag());
try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) {
return sideInputGetDataStream.requestGlobalData(request);
} catch (AbstractWindmillStream.WindmillStreamShutdownException e) {
} catch (WindmillStreamShutdownException e) {
throw new WorkItemCancelledException(e);
} catch (Exception e) {
throw new GetDataException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,19 +233,16 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) {
.setShardingKey(pendingRequest.shardingKey())
.setSerializedWorkItemCommit(pendingRequest.serializedCommit());
StreamingCommitWorkRequest chunk = requestBuilder.build();
synchronized (this) {
synchronized (shutdownLock) {
if (!isShutdown()) {
pending.put(id, pendingRequest);
} else {
return;
}
}
try {
send(chunk);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}
if (shouldCancelRequest(id, pendingRequest)) {
pendingRequest.abort();
return;
}

try {
send(chunk);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.

}
}

Expand All @@ -265,33 +262,28 @@ private void issueBatchedRequest(Map<Long, PendingRequest> requests) {
.setSerializedWorkItemCommit(request.serializedCommit());
}
StreamingCommitWorkRequest request = requestBuilder.build();
synchronized (this) {
synchronized (shutdownLock) {
if (!isShutdown()) {
pending.putAll(requests);
} else {
return;
}
}
try {
send(request);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}

if (shouldCancelRequest(requests)) {
requests.forEach((ignored, pendingRequest) -> pendingRequest.abort());
return;
}

try {
send(request);
} catch (IllegalStateException e) {
// Stream was broken, request will be retried when stream is reopened.
}
}

private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) {
private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) {
checkNotNull(pendingRequest.computationId());
final ByteString serializedCommit = pendingRequest.serializedCommit();
ByteString serializedCommit = pendingRequest.serializedCommit();
if (shouldCancelRequest(id, pendingRequest)) {
pendingRequest.abort();
return;
}

synchronized (this) {
synchronized (shutdownLock) {
if (!isShutdown()) {
pending.put(id, pendingRequest);
} else {
return;
}
}
for (int i = 0;
i < serializedCommit.size();
i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
Expand Down Expand Up @@ -321,6 +313,32 @@ private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest
}
}

/**
 * Registers {@code request} under {@code id} in {@code pending} unless the stream has been shut
 * down.
 *
 * <p>Acquires {@code shutdownLock} first and then the monitor on {@code this}, so the shutdown
 * check and the insert into {@code pending} happen while both are held.
 *
 * @param id key under which the request is stored in {@code pending}
 * @param request the request to register
 * @return {@code true} if the stream is shut down and nothing was added to {@code pending};
 *     {@code false} if the request was added
 */
private boolean shouldCancelRequest(long id, PendingRequest request) {
  synchronized (shutdownLock) {
    synchronized (this) {
      boolean streamClosed = isShutdown();
      if (!streamClosed) {
        // Only track the request while the stream can still accept it.
        pending.put(id, request);
      }
      return streamClosed;
    }
  }
}

/**
 * Registers every entry of {@code requests} in {@code pending} unless the stream has been shut
 * down.
 *
 * <p>Acquires {@code shutdownLock} first and then the monitor on {@code this}, so the shutdown
 * check and the bulk insert into {@code pending} happen while both are held.
 *
 * @param requests the requests to register, keyed by request id
 * @return {@code true} if the stream is shut down and nothing was added to {@code pending};
 *     {@code false} if all requests were added
 */
private boolean shouldCancelRequest(Map<Long, PendingRequest> requests) {
  synchronized (shutdownLock) {
    synchronized (this) {
      boolean streamClosed = isShutdown();
      if (!streamClosed) {
        // Only track the batch while the stream can still accept it.
        pending.putAll(requests);
      }
      return streamClosed;
    }
  }
}

@AutoValue
abstract static class PendingRequest {

Expand Down Expand Up @@ -402,6 +420,11 @@ private Batcher() {
@Override
public boolean commitWorkItem(
String computation, WorkItemCommitRequest commitRequest, Consumer<CommitStatus> onDone) {
if (isShutdown()) {
onDone.accept(CommitStatus.ABORTED);
return false;
}

if (!canAccept(commitRequest.getSerializedSize() + computation.length()) || isShutdown()) {
return false;
}
Expand All @@ -418,7 +441,7 @@ public void flush() {
if (!isShutdown()) {
flushInternal(queue);
} else {
queue.forEach((ignored, request) -> request.onDone().accept(CommitStatus.ABORTED));
queue.forEach((ignored, request) -> request.abort());
}
} finally {
queuedBytes = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedBatch;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedRequest;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
Expand Down Expand Up @@ -198,7 +199,8 @@ public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataReq
return issueRequest(
QueuedRequest.forComputation(uniqueId(), computation, request),
KeyedGetDataResponse::parseFrom);
} catch (WindmillStreamShutdownException e) {
} catch (
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException e) {
throw new WorkItemCancelledException(request.getShardingKey());
}
}
Expand All @@ -207,7 +209,8 @@ public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataReq
public GlobalData requestGlobalData(GlobalDataRequest request) {
try {
return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom);
} catch (WindmillStreamShutdownException e) {
} catch (
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException e) {
throw new WorkItemCancelledException(
"SideInput fetch failed for request due to stream shutdown: " + request, e);
}
Expand All @@ -216,7 +219,8 @@ public GlobalData requestGlobalData(GlobalDataRequest request) {
@Override
public void refreshActiveWork(Map<String, Collection<HeartbeatRequest>> heartbeats) {
if (isShutdown()) {
throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream.");
throw new org.apache.beam.runners.dataflow.worker.windmill.client
.WindmillStreamShutdownException("Unable to refresh work for shutdown stream.");
}

StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
Expand Down Expand Up @@ -354,18 +358,24 @@ private <ResponseT> ResponseT issueRequest(QueuedRequest request, ParseFn<Respon
}
}

throw new WindmillStreamShutdownException(
"Cannot send request=[" + request + "] on closed stream.");
throw new org.apache.beam.runners.dataflow.worker.windmill.client
.WindmillStreamShutdownException("Cannot send request=[" + request + "] on closed stream.");
}

private void handleShutdown(QueuedRequest request, Throwable cause) {
if (cause instanceof WindmillStreamShutdownException) {
throw (WindmillStreamShutdownException) cause;
if (cause
instanceof
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException) {
throw (org.apache.beam.runners.dataflow.worker.windmill.client
.WindmillStreamShutdownException)
cause;
}
if (isShutdown()) {
WindmillStreamShutdownException shutdownException =
new WindmillStreamShutdownException(
"Cannot send request=[" + request + "] on closed stream.");
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException
shutdownException =
new org.apache.beam.runners.dataflow.worker.windmill.client
.WindmillStreamShutdownException(
"Cannot send request=[" + request + "] on closed stream.");
shutdownException.addSuppressed(cause);
throw shutdownException;
}
Expand Down Expand Up @@ -431,11 +441,15 @@ void trySendBatch(QueuedBatch batch) {

@SuppressWarnings("NullableProblems")
private void sendBatch(List<QueuedRequest> requests) {
if (requests.isEmpty()) {
return;
}

StreamingGetDataRequest batchedRequest = flushToBatch(requests);
synchronized (this) {
synchronized (shutdownLock) {
// Synchronization of pending inserts is necessary with send to ensure duplicates are not
// sent on stream reconnect.
synchronized (shutdownLock) {
synchronized (this) {
// shutdown() clears pending, once the stream is shutdown, prevent values from being added
// to it.
if (isShutdown()) {
Expand All @@ -448,12 +462,13 @@ private void sendBatch(List<QueuedRequest> requests) {
verify(pending.put(request.id(), request.getResponseStream()) == null);
}
}
try {
send(batchedRequest);
} catch (IllegalStateException e) {
// The stream broke before this call went through; onNewStream will retry the fetch.
LOG.warn("GetData stream broke before call started.", e);
}
}

try {
send(batchedRequest);
} catch (IllegalStateException e) {
// The stream broke before this call went through; onNewStream will retry the fetch.
LOG.warn("GetData stream broke before call started.", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import java.util.Objects;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
import org.apache.beam.sdk.annotations.Internal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -61,7 +61,7 @@ public void sendHeartbeats(Heartbeats heartbeats) {
Thread.currentThread().setName(originalThreadName + "-" + backendWorkerToken);
}
getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap());
} catch (AbstractWindmillStream.WindmillStreamShutdownException e) {
} catch (WindmillStreamShutdownException e) {
LOG.warn(
"Trying to refresh work w/ {} heartbeats on stream={} after work has moved off of worker."
+ " heartbeats",
Expand Down
Loading

0 comments on commit 9679382

Please sign in to comment.