Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Oct 10, 2024
1 parent 3db3b7f commit 06836da
Show file tree
Hide file tree
Showing 16 changed files with 102 additions and 339 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
/**
* Represents the state of an attempt to process a {@link WorkItem} by executing user code.
*
* @implNote Not thread safe, should not be executed or accessed by more than 1 thread at a time.
* @implNote Not thread safe, should not be modified by more than 1 thread at a time.
*/
@NotThreadSafe
@Internal
Expand All @@ -70,7 +70,7 @@ public final class Work implements RefreshableWork {
private final Map<LatencyAttribution.State, Duration> totalDurationPerState;
private final WorkId id;
private final String latencyTrackingId;
private TimedState currentState;
private volatile TimedState currentState;
private volatile boolean isFailed;
private volatile String processingThreadName = "";

Expand Down Expand Up @@ -112,7 +112,7 @@ public static ProcessingContext createProcessingContext(
Consumer<Commit> workCommitter,
HeartbeatSender heartbeatSender) {
return ProcessingContext.create(
computationId, getDataClient, workCommitter, heartbeatSender, "");
computationId, getDataClient, workCommitter, heartbeatSender, /* backendWorkerToken= */ "");
}

public static ProcessingContext createProcessingContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
Expand Down Expand Up @@ -61,12 +62,10 @@
import org.apache.beam.sdk.util.MoreFutures;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.checkerframework.checker.initialization.qual.UnderInitialization;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -81,7 +80,8 @@
public final class FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness {
private static final Logger LOG =
LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class);
private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = "WorkerMetadataConsumerThread";
private static final String WORKER_METADATA_CONSUMER_THREAD_NAME =
"WindmillWorkerMetadataConsumerThread";
private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d";

private final JobHeader jobHeader;
Expand All @@ -92,7 +92,6 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
private final GetWorkBudgetDistributor getWorkBudgetDistributor;
private final GetWorkBudget totalGetWorkBudget;
private final ThrottleTimer getWorkerMetadataThrottleTimer;
private final Supplier<GetWorkerMetadataStream> getWorkerMetadataStream;
private final Function<WindmillStream.CommitWorkStream, WorkCommitter> workCommitterFactory;
private final ThrottlingGetDataMetricTracker getDataMetricTracker;
private final ExecutorService windmillStreamManager;
Expand All @@ -111,6 +110,10 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
@GuardedBy("this")
private boolean started;

/** Set once when {@link #start()} is called. */
@GuardedBy("this")
private @Nullable GetWorkerMetadataStream getWorkerMetadataStream = null;

private FanOutStreamingEngineWorkerHarness(
JobHeader jobHeader,
GetWorkBudget totalGetWorkBudget,
Expand Down Expand Up @@ -139,7 +142,6 @@ private FanOutStreamingEngineWorkerHarness(
this.getWorkBudgetDistributor = getWorkBudgetDistributor;
this.totalGetWorkBudget = totalGetWorkBudget;
this.activeMetadataVersion = Long.MIN_VALUE;
this.getWorkerMetadataStream = Suppliers.memoize(createGetWorkerMetadataStream()::get);
this.workCommitterFactory = workCommitterFactory;
this.metadataLock = new Object();
}
Expand Down Expand Up @@ -201,13 +203,16 @@ static FanOutStreamingEngineWorkerHarness forTesting(
@Override
public synchronized void start() {
Preconditions.checkState(!started, "StreamingEngineClient cannot start twice.");
// Starts the stream, this value is memoized.
getWorkerMetadataStream.get();
getWorkerMetadataStream =
streamFactory.createGetWorkerMetadataStream(
dispatcherClient.getWindmillMetadataServiceStubBlocking(),
getWorkerMetadataThrottleTimer,
this::consumeWorkerMetadata);
started = true;
}

public ImmutableSet<HostAndPort> currentWindmillEndpoints() {
return connections.get().windmillConnections().keySet().stream()
return connections.get().windmillStreams().keySet().stream()
.map(Endpoint::directEndpoint)
.filter(Optional::isPresent)
.map(Optional::get)
Expand Down Expand Up @@ -239,26 +244,11 @@ private GetDataStream getGlobalDataStream(String globalDataKey) {
@Override
public synchronized void shutdown() {
Preconditions.checkState(started, "StreamingEngineClient never started.");
getWorkerMetadataStream.get().halfClose();
Preconditions.checkNotNull(getWorkerMetadataStream).halfClose();
workerMetadataConsumer.shutdownNow();
channelCachingStubFactory.shutdown();
}

@SuppressWarnings("methodref.receiver.bound")
private Supplier<GetWorkerMetadataStream> createGetWorkerMetadataStream(
@UnderInitialization FanOutStreamingEngineWorkerHarness this) {
// Checker Framework complains about reference to "this" in the constructor since the instance
// is "UnderInitialization" here, which we pass as a lambda to GetWorkerMetadataStream for
// processing new worker metadata. Supplier.get() is only called in start(), after we have
// constructed the FanOutStreamingEngineWorkerHarness.
return () ->
checkNotNull(streamFactory)
.createGetWorkerMetadataStream(
checkNotNull(dispatcherClient).getWindmillMetadataServiceStubBlocking(),
checkNotNull(getWorkerMetadataThrottleTimer),
this::consumeWorkerMetadata);
}

private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) {
synchronized (metadataLock) {
// Only process versions greater than what we currently have to prevent double processing of
Expand All @@ -281,37 +271,31 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
}

long previousMetadataVersion = activeMetadataVersion;
LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints());
closeStaleStreams(newWindmillConnections.values(), connections.get().windmillStreams());
ImmutableMap<WindmillConnection, WindmillStreamSender> newStreams =
createAndStartNewStreams(newWindmillConnections.values()).join();
LOG.debug(
"Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}",
newWindmillEndpoints,
previousMetadataVersion,
activeMetadataVersion);
closeStaleStreams(
newWindmillEndpoints.windmillEndpoints(), connections.get().windmillStreams());
ImmutableMap<Endpoint, WindmillStreamSender> newStreams =
createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join();
StreamingEngineConnectionState newConnectionsState =
StreamingEngineConnectionState.builder()
.setWindmillConnections(newWindmillConnections)
.setWindmillStreams(newStreams)
.setGlobalDataStreams(
createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints()))
.build();
LOG.info(
"Setting new connections: {}. Previous connections: {}.",
newConnectionsState,
connections.get());
connections.set(newConnectionsState);
getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget);
activeMetadataVersion = newWindmillEndpoints.version();
LOG.info(
"Consumed new endpoints. previous metadata version: {}, current metadata version: {}",
previousMetadataVersion,
activeMetadataVersion);
}

/** Close the streams that are no longer valid asynchronously. */
@SuppressWarnings("FutureReturnValueIgnored")
private void closeStaleStreams(
Collection<WindmillConnection> newWindmillConnections,
ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams) {
Collection<Endpoint> newWindmillConnections,
ImmutableMap<Endpoint, WindmillStreamSender> currentStreams) {
currentStreams.entrySet().stream()
.filter(
connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey()))
Expand All @@ -334,27 +318,24 @@ private void closeStaleStreams(
windmillStreamManager));
}

private synchronized CompletableFuture<ImmutableMap<WindmillConnection, WindmillStreamSender>>
createAndStartNewStreams(Collection<WindmillConnection> newWindmillConnections) {
ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams =
private synchronized CompletableFuture<ImmutableMap<Endpoint, WindmillStreamSender>>
createAndStartNewStreams(Collection<Endpoint> newWindmillConnections) {
ImmutableMap<Endpoint, WindmillStreamSender> currentStreams =
connections.get().windmillStreams();
CompletionStage<List<Pair<WindmillConnection, WindmillStreamSender>>>
connectionAndSenderFuture =
MoreFutures.allAsList(
newWindmillConnections.stream()
.map(
connection ->
MoreFutures.supplyAsync(
() ->
Pair.of(
connection,
Optional.ofNullable(currentStreams.get(connection))
.orElseGet(
() ->
createAndStartWindmillStreamSender(
connection))),
windmillStreamManager))
.collect(Collectors.toList()));
CompletionStage<List<Pair<Endpoint, WindmillStreamSender>>> connectionAndSenderFuture =
MoreFutures.allAsList(
newWindmillConnections.stream()
.map(
connection ->
MoreFutures.supplyAsync(
() ->
Pair.of(
connection,
Optional.ofNullable(currentStreams.get(connection))
.orElseGet(
() -> createAndStartWindmillStreamSender(connection))),
windmillStreamManager))
.collect(Collectors.toList()));

return connectionAndSenderFuture
.thenApply(
Expand Down Expand Up @@ -384,23 +365,6 @@ StreamingEngineConnectionState getCurrentConnections() {
return connections.get();
}

private synchronized ImmutableMap<Endpoint, WindmillConnection> createNewWindmillConnections(
List<Endpoint> newWindmillEndpoints) {
ImmutableMap<Endpoint, WindmillConnection> currentConnections =
connections.get().windmillConnections();
return newWindmillEndpoints.stream()
.collect(
toImmutableMap(
Function.identity(),
endpoint ->
// Reuse existing stubs if they exist. Optional.orElseGet only calls the
// supplier if the value is not present, preventing constructing expensive
// objects.
Optional.ofNullable(currentConnections.get(endpoint))
.orElseGet(
() -> WindmillConnection.from(endpoint, this::createWindmillStub))));
}

private ImmutableMap<String, Supplier<GetDataStream>> createNewGlobalDataStreams(
ImmutableMap<String, Endpoint> newGlobalDataEndpoints) {
ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams =
Expand All @@ -421,19 +385,13 @@ private Supplier<GetDataStream> existingOrNewGetDataStreamFor(
keyedEndpoint.getKey(),
() ->
streamFactory.createGetDataStream(
newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer())));
}

private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) {
return Optional.ofNullable(connections.get().windmillConnections().get(endpoint))
.map(WindmillConnection::stub)
.orElseGet(() -> createWindmillStub(endpoint));
createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer())));
}

private WindmillStreamSender createAndStartWindmillStreamSender(WindmillConnection connection) {
private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint connection) {
WindmillStreamSender windmillStreamSender =
WindmillStreamSender.create(
connection,
WindmillConnection.from(connection, this::createWindmillStub),
GetWorkRequest.newBuilder()
.setClientId(jobHeader.getClientId())
.setJobId(jobHeader.getJobId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import com.google.auto.value.AutoValue;
import java.util.function.Supplier;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
Expand All @@ -36,25 +35,18 @@ abstract class StreamingEngineConnectionState {

static Builder builder() {
return new AutoValue_StreamingEngineConnectionState.Builder()
.setWindmillConnections(ImmutableMap.of())
.setWindmillStreams(ImmutableMap.of())
.setGlobalDataStreams(ImmutableMap.of());
}

abstract ImmutableMap<Endpoint, WindmillConnection> windmillConnections();

abstract ImmutableMap<WindmillConnection, WindmillStreamSender> windmillStreams();
abstract ImmutableMap<Endpoint, WindmillStreamSender> windmillStreams();

/** Mapping of GlobalDataIds and the direct GetDataStreams used to fetch them. */
abstract ImmutableMap<String, Supplier<GetDataStream>> globalDataStreams();

@AutoValue.Builder
abstract static class Builder {
public abstract Builder setWindmillConnections(
ImmutableMap<Endpoint, WindmillConnection> value);

public abstract Builder setWindmillStreams(
ImmutableMap<WindmillConnection, WindmillStreamSender> value);
public abstract Builder setWindmillStreams(ImmutableMap<Endpoint, WindmillStreamSender> value);

public abstract Builder setGlobalDataStreams(
ImmutableMap<String, Supplier<GetDataStream>> value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,6 @@ public void setBudget(long items, long bytes) {
}
}

@Override
public GetWorkBudget remainingBudget() {
return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get();
}

/**
 * Reads and resets the accumulated throttle time by delegating to {@code
 * streamingEngineThrottleTimers.getAndResetThrottleTime()}.
 *
 * @return the value reported by the underlying throttle timers before the reset.
 */
long getAndResetThrottleTime() {
  final long throttleTime = streamingEngineThrottleTimers.getAndResetThrottleTime();
  return throttleTime;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;

import com.google.auto.value.AutoValue;
import java.net.Inet6Address;
Expand All @@ -27,8 +27,8 @@
import java.util.Map;
import java.util.Optional;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -53,12 +53,12 @@ public static WindmillEndpoints from(
endpoint.getValue(),
workerMetadataResponseProto.getExternalEndpoint())));

ImmutableList<WindmillEndpoints.Endpoint> windmillServers =
ImmutableSet<WindmillEndpoints.Endpoint> windmillServers =
workerMetadataResponseProto.getWorkEndpointsList().stream()
.map(
endpointProto ->
Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint()))
.collect(toImmutableList());
.collect(toImmutableSet());

return WindmillEndpoints.builder()
.setVersion(workerMetadataResponseProto.getMetadataVersion())
Expand Down Expand Up @@ -142,7 +142,7 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
* Windmill servers. Returns the set of endpoints used to communicate with the corresponding
* Windmill servers.
*/
public abstract ImmutableList<Endpoint> windmillEndpoints();
public abstract ImmutableSet<Endpoint> windmillEndpoints();

/**
* Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with
Expand Down Expand Up @@ -214,9 +214,9 @@ public abstract Builder setGlobalDataEndpoints(
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers);

public abstract Builder setWindmillEndpoints(
ImmutableList<WindmillEndpoints.Endpoint> windmillServers);
ImmutableSet<WindmillEndpoints.Endpoint> windmillServers);

abstract ImmutableList.Builder<WindmillEndpoints.Endpoint> windmillEndpointsBuilder();
abstract ImmutableSet.Builder<WindmillEndpoints.Endpoint> windmillEndpointsBuilder();

public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) {
windmillEndpointsBuilder().add(endpoint);
Expand Down
Loading

0 comments on commit 06836da

Please sign in to comment.