Skip to content

Commit

Permalink
enable WindmillStreams to be started by an external caller and passed…
Browse files Browse the repository at this point in the history
… around in an unstarted state. Start the streams in WindmillStreamSender in parallel
  • Loading branch information
m-trieu committed Oct 12, 2024
1 parent 4f5b381 commit b6045f1
Show file tree
Hide file tree
Showing 19 changed files with 234 additions and 216 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;

Expand All @@ -35,7 +34,6 @@
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
Expand Down Expand Up @@ -98,6 +96,7 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
private final ExecutorService windmillStreamManager;
private final ExecutorService workerMetadataConsumer;
private final Object metadataLock;
private final GetWorkerMetadataStream getWorkerMetadataStream;

/** Writes are guarded by synchronization, reads are lock free. */
private final AtomicReference<StreamingEngineBackends> backends;
Expand All @@ -111,10 +110,6 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
@GuardedBy("this")
private boolean started;

/** Set once when {@link #start()} is called. */
@GuardedBy("this")
private @Nullable GetWorkerMetadataStream getWorkerMetadataStream = null;

private FanOutStreamingEngineWorkerHarness(
JobHeader jobHeader,
GetWorkBudget totalGetWorkBudget,
Expand Down Expand Up @@ -145,6 +140,13 @@ private FanOutStreamingEngineWorkerHarness(
this.activeMetadataVersion = Long.MIN_VALUE;
this.workCommitterFactory = workCommitterFactory;
this.metadataLock = new Object();
@SuppressWarnings("methodref.receiver.bound")
GetWorkerMetadataStream getWorkerMetadataStream =
streamFactory.createGetWorkerMetadataStream(
dispatcherClient::getWindmillMetadataServiceStubBlocking,
getWorkerMetadataThrottleTimer,
this::consumeWorkerMetadata);
this.getWorkerMetadataStream = getWorkerMetadataStream;
}

/**
Expand Down Expand Up @@ -203,12 +205,8 @@ static FanOutStreamingEngineWorkerHarness forTesting(
@SuppressWarnings("ReturnValueIgnored")
@Override
public synchronized void start() {
Preconditions.checkState(!started, "StreamingEngineClient cannot start twice.");
getWorkerMetadataStream =
streamFactory.createGetWorkerMetadataStream(
dispatcherClient.getWindmillMetadataServiceStubBlocking(),
getWorkerMetadataThrottleTimer,
this::consumeWorkerMetadata);
Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice.");
getWorkerMetadataStream.start();
started = true;
}

Expand All @@ -235,7 +233,7 @@ private GetDataStream getGlobalDataStream(String globalDataKey) {
@VisibleForTesting
@Override
public synchronized void shutdown() {
Preconditions.checkState(started, "StreamingEngineClient never started.");
Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started.");
Preconditions.checkNotNull(getWorkerMetadataStream).halfClose();
workerMetadataConsumer.shutdownNow();
channelCachingStubFactory.shutdown();
Expand Down Expand Up @@ -371,20 +369,19 @@ private ImmutableMap<String, GlobalDataStreamSender> createNewGlobalDataStreams(
toImmutableMap(
Entry::getKey,
keyedEndpoint ->
existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams)));
getOrCreateGlobalDataSteam(keyedEndpoint, currentGlobalDataStreams)));
}

private GlobalDataStreamSender existingOrNewGetDataStreamFor(
private GlobalDataStreamSender getOrCreateGlobalDataSteam(
Entry<String, Endpoint> keyedEndpoint,
ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams) {
return checkNotNull(
currentGlobalDataStreams.getOrDefault(
keyedEndpoint.getKey(),
new GlobalDataStreamSender(
() ->
return Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey()))
.orElseGet(
() ->
new GlobalDataStreamSender(
streamFactory.createGetDataStream(
createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()),
keyedEndpoint.getValue())));
keyedEndpoint.getValue()));
}

private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,16 @@
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;

@Internal
@ThreadSafe
final class GlobalDataStreamSender implements Closeable, Supplier<GetDataStream> {
private final Endpoint endpoint;
private final Supplier<GetDataStream> delegate;
private final GetDataStream delegate;
private volatile boolean started;

GlobalDataStreamSender(Supplier<GetDataStream> delegate, Endpoint endpoint) {
// Make sure the call to get() is memoized.
this.delegate = Suppliers.memoize(delegate::get);
GlobalDataStreamSender(GetDataStream delegate, Endpoint endpoint) {
this.delegate = delegate;
this.started = false;
this.endpoint = endpoint;
}
Expand All @@ -43,17 +41,14 @@ final class GlobalDataStreamSender implements Closeable, Supplier<GetDataStream>
public GetDataStream get() {
if (!started) {
started = true;
delegate.start();
}
return delegate.get();
return delegate;
}

@Override
public void close() {
if (started) {
// get() may start the stream which is expensive, don't call it if the stream was never
// started.
delegate.get().shutdown();
}
delegate.shutdown();
}

Endpoint endpoint() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.Closeable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
Expand All @@ -37,7 +41,6 @@
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetSpender;
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.FixedStreamHeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;

/**
* Owns and maintains a set of streams used to communicate with a specific Windmill worker.
Expand All @@ -61,13 +64,15 @@
@Internal
@ThreadSafe
final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable {
private static final String STREAM_STARTER_THREAD_NAME = "StartWindmillStreamThread-%d";
private final AtomicBoolean started;
private final AtomicReference<GetWorkBudget> getWorkBudget;
private final Supplier<GetWorkStream> getWorkStream;
private final Supplier<GetDataStream> getDataStream;
private final Supplier<CommitWorkStream> commitWorkStream;
private final Supplier<WorkCommitter> workCommitter;
private final GetWorkStream getWorkStream;
private final GetDataStream getDataStream;
private final CommitWorkStream commitWorkStream;
private final WorkCommitter workCommitter;
private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
private final ExecutorService streamStarter;

private WindmillStreamSender(
WindmillConnection connection,
Expand All @@ -81,33 +86,28 @@ private WindmillStreamSender(
this.getWorkBudget = getWorkBudget;
this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create();

// All streams are memoized/cached since they are expensive to create and some implementations
// perform side effects on construction (i.e. sending initial requests to the stream server to
// initiate the streaming RPC connection). Stream instances connect/reconnect internally, so we
// can reuse the same instance through the entire lifecycle of WindmillStreamSender.
// Stream instances connect/reconnect internally, so we can reuse the same instance through the
// entire lifecycle of WindmillStreamSender.
this.getDataStream =
Suppliers.memoize(
() ->
streamingEngineStreamFactory.createDirectGetDataStream(
connection, streamingEngineThrottleTimers.getDataThrottleTimer()));
streamingEngineStreamFactory.createDirectGetDataStream(
connection, streamingEngineThrottleTimers.getDataThrottleTimer());
this.commitWorkStream =
Suppliers.memoize(
() ->
streamingEngineStreamFactory.createDirectCommitWorkStream(
connection, streamingEngineThrottleTimers.commitWorkThrottleTimer()));
this.workCommitter =
Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get()));
streamingEngineStreamFactory.createDirectCommitWorkStream(
connection, streamingEngineThrottleTimers.commitWorkThrottleTimer());
this.workCommitter = workCommitterFactory.apply(commitWorkStream);
this.getWorkStream =
Suppliers.memoize(
() ->
streamingEngineStreamFactory.createDirectGetWorkStream(
connection,
withRequestBudget(getWorkRequest, getWorkBudget.get()),
streamingEngineThrottleTimers.getWorkThrottleTimer(),
() -> FixedStreamHeartbeatSender.create(getDataStream.get()),
() -> getDataClientFactory.apply(getDataStream.get()),
workCommitter,
workItemScheduler));
streamingEngineStreamFactory.createDirectGetWorkStream(
connection,
withRequestBudget(getWorkRequest, getWorkBudget.get()),
streamingEngineThrottleTimers.getWorkThrottleTimer(),
FixedStreamHeartbeatSender.create(getDataStream),
getDataClientFactory.apply(getDataStream),
workCommitter,
workItemScheduler);
// 3 threads, 1 for each stream type (GetWork, GetData, CommitWork).
this.streamStarter =
Executors.newFixedThreadPool(
3, new ThreadFactoryBuilder().setNameFormat(STREAM_STARTER_THREAD_NAME).build());
}

static WindmillStreamSender create(
Expand All @@ -133,13 +133,18 @@ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkB
}

@SuppressWarnings("ReturnValueIgnored")
void start() {
getWorkStream.get();
getDataStream.get();
commitWorkStream.get();
workCommitter.get().start();
// *stream.get() is all memoized in a threadsafe manner.
started.set(true);
synchronized void start() {
if (!started.get()) {
// Start these 3 streams in parallel since they each may perform blocking IO.
CompletableFuture.allOf(
CompletableFuture.runAsync(getWorkStream::start, streamStarter),
CompletableFuture.runAsync(getDataStream::start, streamStarter),
CompletableFuture.runAsync(commitWorkStream::start, streamStarter))
.join();
workCommitter.start();
// start() is idempotent in a threadsafe manner.
started.set(true);
}
}

@Override
Expand All @@ -148,10 +153,10 @@ public void close() {
// streaming RPCs by possibly making calls over the network. Do not close the streams unless
// they have already been started.
if (started.get()) {
getWorkStream.get().shutdown();
getDataStream.get().shutdown();
workCommitter.get().stop();
commitWorkStream.get().shutdown();
getWorkStream.shutdown();
getDataStream.shutdown();
workCommitter.stop();
commitWorkStream.shutdown();
}
}

Expand All @@ -160,7 +165,7 @@ public void setBudget(long items, long bytes) {
GetWorkBudget adjustment = GetWorkBudget.builder().setItems(items).setBytes(bytes).build();
getWorkBudget.set(adjustment);
if (started.get()) {
getWorkStream.get().setBudget(adjustment);
getWorkStream.setBudget(adjustment);
}
}

Expand All @@ -169,6 +174,6 @@ long getAndResetThrottleTime() {
}

long getCurrentActiveCommitBytes() {
return started.get() ? workCommitter.get().currentActiveCommitBytes() : 0;
return workCommitter.currentActiveCommitBytes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
private final String backendWorkerToken;
private final ResettableRequestObserver<RequestT> requestObserver;
private final AtomicBoolean isShutdown;
private final AtomicBoolean started;
private final AtomicReference<DateTime> shutdownTime;

/**
Expand Down Expand Up @@ -122,6 +123,7 @@ protected AbstractWindmillStream(
this.logEveryNStreamFailures = logEveryNStreamFailures;
this.clientClosed = new AtomicBoolean();
this.isShutdown = new AtomicBoolean(false);
this.started = new AtomicBoolean(false);
this.streamClosed = new AtomicBoolean(false);
this.startTimeMs = new AtomicLong();
this.lastSendTimeMs = new AtomicLong();
Expand Down Expand Up @@ -177,7 +179,7 @@ protected boolean isShutdown() {
private StreamObserver<RequestT> requestObserver() {
if (requestObserver == null) {
throw new NullPointerException(
"requestObserver cannot be null. Missing a call to startStream() to initialize.");
"requestObserver cannot be null. Missing a call to start() to initialize stream.");
}

return requestObserver;
Expand Down Expand Up @@ -208,8 +210,17 @@ protected final void send(RequestT request) {
}
}

/**
 * Starts the underlying stream the first time it is called. Calls made after shutdown, or after
 * the stream has already been started, return without doing anything.
 */
@Override
public final void start() {
  // Never start a stream once shutdown has been requested.
  if (isShutdown.get()) {
    return;
  }

  // Only the caller that flips the flag from false to true goes on to start the stream, so
  // startStream() is executed at most once.
  boolean firstStart = started.compareAndSet(false, true);
  if (firstStart) {
    startStream();
  }
}

/** Starts the underlying stream. */
protected final void startStream() {
private void startStream() {
// Add the stream to the registry after it has been fully constructed.
streamRegistry.add(this);
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
@ThreadSafe
public interface WindmillStream {

/**
 * Starts the stream, opening a connection to the backend server. A call to {@code start()} is
 * required before any further interaction with the stream.
 *
 * <p>NOTE(review): the {@code AbstractWindmillStream} implementation of this method starts the
 * underlying stream at most once and skips starting it after shutdown; confirm whether every
 * implementation is expected to behave the same way.
 */
void start();

/** An identifier for the backend worker where the stream is sending/receiving RPCs. */
String backendWorkerToken();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,17 @@ static GrpcCommitWorkStream create(
JobHeader jobHeader,
AtomicLong idGenerator,
int streamingRpcBatchLimit) {
GrpcCommitWorkStream commitWorkStream =
new GrpcCommitWorkStream(
backendWorkerToken,
startCommitWorkRpcFn,
backoff,
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
commitWorkThrottleTimer,
jobHeader,
idGenerator,
streamingRpcBatchLimit);
commitWorkStream.startStream();
return commitWorkStream;
return new GrpcCommitWorkStream(
backendWorkerToken,
startCommitWorkRpcFn,
backoff,
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
commitWorkThrottleTimer,
jobHeader,
idGenerator,
streamingRpcBatchLimit);
}

@Override
Expand Down
Loading

0 comments on commit b6045f1

Please sign in to comment.