Simplify budget distribution logic and new worker metadata consumption #32775

Merged: 4 commits, Oct 21, 2024
Changes from 1 commit
@@ -30,6 +30,7 @@
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -198,7 +199,6 @@ static FanOutStreamingEngineWorkerHarness forTesting(
return fanOutStreamingEngineWorkProvider;
}

@SuppressWarnings("ReturnValueIgnored")
@Override
public synchronized void start() {
Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice.");
@@ -234,9 +234,29 @@ private GetDataStream getGlobalDataStream(String globalDataKey) {
@Override
public synchronized void shutdown() {
Contributor: should we also shutdown the windmillStreamManager? (after possibly closing streams below)

Contributor Author: Done

Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started.");
Preconditions.checkNotNull(getWorkerMetadataStream).halfClose();
Preconditions.checkNotNull(getWorkerMetadataStream).shutdown();
workerMetadataConsumer.shutdownNow();
Contributor: should this shut down all the stream senders (perhaps could call closeStaleStreams(emptyBackends))?

Contributor Author: Done

closeStreamsNotIn(WindmillEndpoints.none());
channelCachingStubFactory.shutdown();

try {
Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e);
}

windmillStreamManager.shutdown();
boolean isStreamManagerShutdown = false;
try {
isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e);
}
if (!isStreamManagerShutdown) {
windmillStreamManager.shutdownNow();
}
}
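
The teardown added here follows the standard graceful executor shutdown idiom: shutdown(), a bounded awaitTermination(), then shutdownNow() as a last resort. A minimal standalone sketch of that idiom, with illustrative names not taken from this PR:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

final class ExecutorShutdownSketch {
  // Stops accepting new tasks, waits briefly for queued work, then forces termination.
  static void shutdownGracefully(ExecutorService executor) {
    executor.shutdown();
    boolean terminated = false;
    try {
      terminated = executor.awaitTermination(30, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt(); // preserve the interrupt flag, as the PR does
    }
    if (!terminated) {
      executor.shutdownNow(); // cancel anything still running
    }
  }

  public static void main(String[] args) {
    shutdownGracefully(Executors.newCachedThreadPool());
  }
}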

private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) {
@@ -265,7 +285,7 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
newWindmillEndpoints,
activeMetadataVersion,
newWindmillEndpoints.version());
closeStaleStreams(newWindmillEndpoints);
closeStreamsNotIn(newWindmillEndpoints);
ImmutableMap<Endpoint, WindmillStreamSender> newStreams =
createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join();
StreamingEngineBackends newBackends =
@@ -280,29 +300,30 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
}

/** Close the streams that are no longer valid asynchronously. */
@SuppressWarnings("FutureReturnValueIgnored")
private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) {
private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
StreamingEngineBackends currentBackends = backends.get();
ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams =
currentBackends.windmillStreams();
currentWindmillStreams.entrySet().stream()
currentBackends.windmillStreams().entrySet().stream()
.filter(
connectionAndStream ->
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
.forEach(
entry ->
CompletableFuture.runAsync(
() -> closeStreamSender(entry.getKey(), entry.getValue()),
windmillStreamManager));
entry -> {
CompletableFuture<Void> ignored =
Contributor: is this any different than just executing directly? If not, it seems simpler to avoid the future:

    windmillStreamManager.execute(
        () -> closeStreamSender(entry.getKey(), entry.getValue()))

Contributor Author: done

CompletableFuture.runAsync(
() -> closeStreamSender(entry.getKey(), entry.getValue()),
windmillStreamManager);
});

Set<Endpoint> newGlobalDataEndpoints =
new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
currentBackends.globalDataStreams().values().stream()
.filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
.forEach(
sender ->
CompletableFuture.runAsync(
() -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager));
sender -> {
CompletableFuture<Void> ignored =
Contributor: ditto

Contributor Author: done

CompletableFuture.runAsync(
() -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager);
});
}
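
To make the review exchange above concrete: for fire-and-forget work, handing the Runnable to the executor directly and wrapping it in CompletableFuture.runAsync behave the same, but the direct call avoids an unused future (and the ReturnValueIgnored suppression). A small sketch with illustrative names:

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

final class FireAndForgetSketch {
  public static void main(String[] args) {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    Runnable closeSender = () -> System.out.println("closing stream sender");

    // Option suggested in review: execute directly; nothing to track afterwards.
    executor.execute(closeSender);

    // Option in this commit: runAsync returns a future that the caller ignores.
    CompletableFuture<Void> ignored = CompletableFuture.runAsync(closeSender, executor);

    executor.shutdown();
  }
}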

private void closeStreamSender(Endpoint endpoint, Closeable sender) {
@@ -41,6 +41,14 @@
public abstract class WindmillEndpoints {
private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class);

public static WindmillEndpoints none() {
return WindmillEndpoints.builder()
.setVersion(Long.MAX_VALUE)
Contributor: min seems safer. Otherwise, if somehow none() was observed, the logic to ensure the version is increasing means we'd never process another endpoint set.

Contributor Author: done

.setWindmillEndpoints(ImmutableSet.of())
.setGlobalDataEndpoints(ImmutableMap.of())
.build();
}
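
The version concern in the thread above assumes the metadata consumer only accepts endpoints whose version is strictly greater than the active one; that guard is not shown in this hunk, so the sketch below is an assumption about its shape, not code from the PR:

final class VersionGuardSketch {
  private long activeMetadataVersion = Long.MIN_VALUE;

  // Returns true if the new endpoints should be consumed, assuming versions must increase.
  boolean shouldConsume(long newVersion) {
    if (newVersion <= activeMetadataVersion) {
      return false; // stale or duplicate metadata
    }
    activeMetadataVersion = newVersion;
    return true;
  }

  public static void main(String[] args) {
    VersionGuardSketch guard = new VersionGuardSketch();
    // If none() carried Long.MAX_VALUE and were ever consumed, every later real
    // version would be rejected; a Long.MIN_VALUE sentinel has no such effect.
    System.out.println(guard.shouldConsume(Long.MAX_VALUE)); // true
    System.out.println(guard.shouldConsume(42));             // false afterwards
  }
}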

public static WindmillEndpoints from(
Windmill.WorkerMetadataResponse workerMetadataResponseProto) {
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers =
@@ -56,10 +56,10 @@ public interface WindmillStream {
@ThreadSafe
interface GetWorkStream extends WindmillStream {
/** Adjusts the {@link GetWorkBudget} for the stream. */
void setBudget(long newItems, long newBytes);
void setBudget(GetWorkBudget newBudget);

default void setBudget(GetWorkBudget newBudget) {
setBudget(newBudget.items(), newBudget.bytes());
default void setBudget(long newItems, long newBytes) {
setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
}
}
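
The interface change above makes the GetWorkBudget overload the one implementations override, with the (items, bytes) form reduced to a delegating default. The same default-method flip in a self-contained form, using hypothetical names rather than the Beam types:

final class DefaultOverloadFlipSketch {
  // Stand-in for GetWorkBudget: an immutable value holding items and bytes.
  record Budget(long items, long bytes) {}

  interface BudgetSpender {
    // Implementations now override the value-object form...
    void setBudget(Budget newBudget);

    // ...while the primitive form is just a convenience that delegates to it.
    default void setBudget(long newItems, long newBytes) {
      setBudget(new Budget(newItems, newBytes));
    }
  }

  public static void main(String[] args) {
    BudgetSpender spender = budget ->
        System.out.printf("items=%d bytes=%d%n", budget.items(), budget.bytes());
    spender.setBudget(100, 10_000); // routed through the primary overload
  }
}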

@@ -26,6 +26,7 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import net.jcip.annotations.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
@@ -70,7 +71,6 @@ final class GrpcDirectGetWorkStream
.build())
.build();

private final AtomicReference<GetWorkBudget> maxGetWorkBudget;
private final GetWorkBudgetTracker budgetTracker;
private final GetWorkRequest requestHeader;
private final WorkItemScheduler workItemScheduler;
@@ -120,14 +120,13 @@ private GrpcDirectGetWorkStream(
this.heartbeatSender = heartbeatSender;
this.workCommitter = workCommitter;
this.getDataClient = getDataClient;
this.maxGetWorkBudget =
new AtomicReference<>(
this.lastRequest = new AtomicReference<>();
this.budgetTracker =
GetWorkBudgetTracker.create(
GetWorkBudget.builder()
.setItems(requestHeader.getMaxItems())
.setBytes(requestHeader.getMaxBytes())
.build());
this.lastRequest = new AtomicReference<>();
this.budgetTracker = GetWorkBudgetTracker.create();
}

static GrpcDirectGetWorkStream create(
@@ -146,19 +145,22 @@ static GrpcDirectGetWorkStream create(
GetDataClient getDataClient,
WorkCommitter workCommitter,
WorkItemScheduler workItemScheduler) {
return new GrpcDirectGetWorkStream(
backendWorkerToken,
startGetWorkRpcFn,
request,
backoff,
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
workCommitter,
workItemScheduler);
GrpcDirectGetWorkStream getWorkStream =
new GrpcDirectGetWorkStream(
backendWorkerToken,
startGetWorkRpcFn,
request,
backoff,
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
workCommitter,
workItemScheduler);
getWorkStream.startStream();
return getWorkStream;
}

private static Watermarks createWatermarks(
@@ -188,7 +190,11 @@ private void maybeSendRequestExtension(GetWorkBudget extension) {
.build();
lastRequest.set(request);
budgetTracker.recordBudgetRequested(extension);
send(request);
try {
send(request);
} catch (IllegalStateException e) {
// Stream was closed.
}
});
}
}
@@ -198,8 +204,7 @@ protected synchronized void onNewStream() {
workItemAssemblers.clear();
if (!isShutdown()) {
budgetTracker.reset();
GetWorkBudget initialGetWorkBudget =
budgetTracker.computeBudgetExtension(maxGetWorkBudget.get());
GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension();
StreamingGetWorkRequest request =
StreamingGetWorkRequest.newBuilder()
.setRequest(
@@ -231,7 +236,7 @@ public void appendSpecificHtml(PrintWriter writer) {
+ "total budget received: %s,"
+ "last sent request: %s. ",
workItemAssemblers.size(),
maxGetWorkBudget.get(),
budgetTracker.maxGetWorkBudget().get(),
Contributor: could move HTML generation into the budget tracker and not need all the accessors. If we change how the tracker works in the future we might want to show more too.

Contributor Author: done

budgetTracker.inFlightBudget(),
budgetTracker.totalRequestedBudget(),
budgetTracker.totalReceivedBudget(),
@@ -262,7 +267,7 @@ private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
createProcessingContext(metadata.computationId()),
assembledWorkItem.latencyAttributions());
budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize());
GetWorkBudget extension = budgetTracker.computeBudgetExtension(maxGetWorkBudget.get());
GetWorkBudget extension = budgetTracker.computeBudgetExtension();
maybeSendRequestExtension(extension);
}

@@ -277,26 +282,38 @@ protected void startThrottleTimer() {
}

@Override
public void setBudget(long newItems, long newBytes) {
GetWorkBudget currentMaxGetWorkBudget =
maxGetWorkBudget.updateAndGet(
ignored -> GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
GetWorkBudget extension = budgetTracker.computeBudgetExtension(currentMaxGetWorkBudget);
public void setBudget(GetWorkBudget newBudget) {
GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget);
maybeSendRequestExtension(extension);
}

private void executeSafely(Runnable runnable) {
try {
executor().execute(runnable);
} catch (RejectedExecutionException e) {
LOG.debug("{} has been shutdown.", getClass());
}
}
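
For context on the catch above: ExecutorService.execute() throws RejectedExecutionException once the executor has been shut down, which is the case executeSafely() downgrades to a debug log. A standalone illustration:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;

final class RejectedExecutionSketch {
  public static void main(String[] args) {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    executor.shutdown();
    try {
      executor.execute(() -> System.out.println("never runs"));
    } catch (RejectedExecutionException e) {
      System.out.println("rejected: executor already shut down");
    }
  }
}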

/**
* Tracks sent and received GetWorkBudget and uses this information to generate request
* Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request
* extensions.
*/
@ThreadSafe
@AutoValue
abstract static class GetWorkBudgetTracker {

private static GetWorkBudgetTracker create() {
private static GetWorkBudgetTracker create(GetWorkBudget initialMaxGetWorkBudget) {
return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker(
new AtomicLong(), new AtomicLong(), new AtomicLong(), new AtomicLong());
new AtomicReference<>(initialMaxGetWorkBudget),
new AtomicLong(),
new AtomicLong(),
new AtomicLong(),
new AtomicLong());
}

abstract AtomicReference<GetWorkBudget> maxGetWorkBudget();

abstract AtomicLong itemsRequested();
Contributor: how about just using synchronized instead of lots of separate atomics? Multiple atomic ops might be worse for performance anyway, and it means we might have weird races where they are inconsistently updated.

Contributor Author: Done, also added tests for this class.

Contributor: can the members be changed to just raw longs/objects? The accessors just need to be synchronized as well. Seems like this could be easier without AutoValue since we don't need the accessors either.

Contributor Author: done


abstract AtomicLong bytesRequested();
Expand All @@ -305,19 +322,25 @@ private static GetWorkBudgetTracker create() {

abstract AtomicLong bytesReceived();

private void reset() {
private synchronized void reset() {
itemsRequested().set(0);
bytesRequested().set(0);
itemsReceived().set(0);
bytesReceived().set(0);
}

private void recordBudgetRequested(GetWorkBudget budgetRequested) {
/** Consumes the new budget and computes an extension based on the new budget. */
private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) {
maxGetWorkBudget().set(newBudget);
return computeBudgetExtension();
}

private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) {
itemsRequested().addAndGet(budgetRequested.items());
bytesRequested().addAndGet(budgetRequested.bytes());
}

private void recordBudgetReceived(long bytesReceived) {
private synchronized void recordBudgetReceived(long bytesReceived) {
itemsReceived().incrementAndGet();
bytesReceived().addAndGet(bytesReceived);
}
@@ -327,7 +350,8 @@ private void recordBudgetReceived(long bytesReceived) {
* GetWorkExtension. The goal is to keep the limits relatively close to their maximum values
* without sending too many extension requests.
*/
private GetWorkBudget computeBudgetExtension(GetWorkBudget maxGetWorkBudget) {
private synchronized GetWorkBudget computeBudgetExtension() {
GetWorkBudget maxGetWorkBudget = maxGetWorkBudget().get();
// Expected items and bytes can go negative here, since WorkItems returned might be larger
// than the initially requested budget.
long inFlightItems = itemsRequested().get() - itemsReceived().get();
@@ -363,14 +387,4 @@ private GetWorkBudget totalReceivedBudget() {
.build();
}
}

private void executeSafely(Runnable runnable) {
try {
executor().execute(runnable);
} catch (RejectedExecutionException e) {
LOG.debug("{} has been shutdown.", getClass());
} catch (IllegalStateException e) {
// Stream was closed.
}
}
}
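
Where the tracker discussion lands (plain fields guarded by synchronized methods, no AutoValue, no per-field atomics) can be sketched as follows; the names and the extension policy are illustrative, not the PR's final code:

final class SynchronizedBudgetTrackerSketch {
  // Local stand-in for GetWorkBudget so the sketch compiles on its own.
  record Budget(long items, long bytes) {}

  private Budget maxBudget;
  private long itemsRequested;
  private long bytesRequested;
  private long itemsReceived;
  private long bytesReceived;

  SynchronizedBudgetTrackerSketch(Budget initialMaxBudget) {
    this.maxBudget = initialMaxBudget;
  }

  synchronized void reset() {
    itemsRequested = 0;
    bytesRequested = 0;
    itemsReceived = 0;
    bytesReceived = 0;
  }

  synchronized void recordBudgetRequested(Budget requested) {
    itemsRequested += requested.items();
    bytesRequested += requested.bytes();
  }

  synchronized void recordBudgetReceived(long size) {
    itemsReceived++;
    bytesReceived += size;
  }

  // Updates the max budget and computes the next extension under a single lock,
  // so the two can never be observed in an inconsistent state.
  synchronized Budget consumeAndComputeBudgetUpdate(Budget newMax) {
    maxBudget = newMax;
    return computeBudgetExtension();
  }

  synchronized Budget computeBudgetExtension() {
    // In-flight totals can go negative when returned work exceeds what was asked for.
    long inFlightItems = itemsRequested - itemsReceived;
    long inFlightBytes = bytesRequested - bytesReceived;
    // Simplified policy: top back up to the maximum once half of it has been consumed.
    long items = inFlightItems <= maxBudget.items() / 2 ? maxBudget.items() - inFlightItems : 0;
    long bytes = inFlightBytes <= maxBudget.bytes() / 2 ? maxBudget.bytes() - inFlightBytes : 0;
    return new Budget(Math.max(0, items), Math.max(0, bytes));
  }
}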
@@ -59,7 +59,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class GrpcGetDataStream
final class GrpcGetDataStream
extends AbstractWindmillStream<StreamingGetDataRequest, StreamingGetDataResponse>
implements GetDataStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class);
@@ -33,6 +33,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;

@@ -193,7 +194,7 @@ protected void startThrottleTimer() {
}

@Override
public void setBudget(long newItems, long newBytes) {
public void setBudget(GetWorkBudget newBudget) {
// no-op
}
}
@@ -64,7 +64,7 @@ public static ManagedChannel remoteChannel(
windmillServiceRpcChannelTimeoutSec);
default:
throw new UnsupportedOperationException(
"Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported"
"Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported"
+ " WindmillServiceAddresses.");
}
}
@@ -49,8 +49,6 @@ public <T extends GetWorkBudgetSpender> void distributeBudget(

private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredBudgets(
Contributor: nit: maybe name it computeDesiredPerStreamBudget? Or just inline? "Budgets" makes it sound like it is computing multiple.

Contributor Author: done

ImmutableCollection<T> streams, GetWorkBudget totalGetWorkBudget) {
// TODO: Fix possibly non-deterministic handing out of budgets.
// Rounding up here will drift upwards over the lifetime of the streams.
return GetWorkBudget.builder()
.setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING))
.setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING))
Expand Down