Skip to content

Commit

Permalink
address pr comments
Browse files: browse the repository at this point in the history
  • Loading branch information
m-trieu committed Oct 21, 2024
1 parent 3f6888d commit e62c55a
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 149 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -307,23 +307,17 @@ private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
connectionAndStream ->
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
.forEach(
entry -> {
CompletableFuture<Void> ignored =
CompletableFuture.runAsync(
() -> closeStreamSender(entry.getKey(), entry.getValue()),
windmillStreamManager);
});
entry ->
windmillStreamManager.execute(
() -> closeStreamSender(entry.getKey(), entry.getValue())));

Set<Endpoint> newGlobalDataEndpoints =
new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
currentBackends.globalDataStreams().values().stream()
.filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
.forEach(
sender -> {
CompletableFuture<Void> ignored =
CompletableFuture.runAsync(
() -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager);
});
sender ->
windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender)));
}

private void closeStreamSender(Endpoint endpoint, Closeable sender) {
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,15 +17,14 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;

import com.google.auto.value.AutoValue;
import java.io.PrintWriter;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import javax.annotation.concurrent.GuardedBy;
import net.jcip.annotations.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
Expand Down Expand Up @@ -122,7 +121,7 @@ private GrpcDirectGetWorkStream(
this.getDataClient = getDataClient;
this.lastRequest = new AtomicReference<>();
this.budgetTracker =
GetWorkBudgetTracker.create(
new GetWorkBudgetTracker(
GetWorkBudget.builder()
.setItems(requestHeader.getMaxItems())
.setBytes(requestHeader.getMaxBytes())
Expand Down Expand Up @@ -229,18 +228,9 @@ protected boolean hasPendingRequests() {
public void appendSpecificHtml(PrintWriter writer) {
// Number of buffers is same as distinct workers that sent work on this stream.
writer.format(
"GetWorkStream: %d buffers, "
+ "max budget: %s, "
+ "in-flight budget: %s, "
+ "total budget requested: %s, "
+ "total budget received: %s,"
+ "last sent request: %s. ",
workItemAssemblers.size(),
budgetTracker.maxGetWorkBudget().get(),
budgetTracker.inFlightBudget(),
budgetTracker.totalRequestedBudget(),
budgetTracker.totalReceivedBudget(),
lastRequest.get());
"GetWorkStream: %d buffers, " + "last sent request: %s; ",
workItemAssemblers.size(), lastRequest.get());
writer.print(budgetTracker.debugString());
}

@Override
Expand Down Expand Up @@ -300,49 +290,57 @@ private void executeSafely(Runnable runnable) {
* extensions.
*/
@ThreadSafe
@AutoValue
abstract static class GetWorkBudgetTracker {

private static GetWorkBudgetTracker create(GetWorkBudget initialMaxGetWorkBudget) {
return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker(
new AtomicReference<>(initialMaxGetWorkBudget),
new AtomicLong(),
new AtomicLong(),
new AtomicLong(),
new AtomicLong());
}
private static final class GetWorkBudgetTracker {

@GuardedBy("GetWorkBudgetTracker.this")
private GetWorkBudget maxGetWorkBudget;

abstract AtomicReference<GetWorkBudget> maxGetWorkBudget();
@GuardedBy("GetWorkBudgetTracker.this")
private long itemsRequested = 0;

abstract AtomicLong itemsRequested();
@GuardedBy("GetWorkBudgetTracker.this")
private long bytesRequested = 0;

abstract AtomicLong bytesRequested();
@GuardedBy("GetWorkBudgetTracker.this")
private long itemsReceived = 0;

abstract AtomicLong itemsReceived();
@GuardedBy("GetWorkBudgetTracker.this")
private long bytesReceived = 0;

abstract AtomicLong bytesReceived();
private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) {
this.maxGetWorkBudget = maxGetWorkBudget;
}

private synchronized void reset() {
itemsRequested().set(0);
bytesRequested().set(0);
itemsReceived().set(0);
bytesReceived().set(0);
itemsRequested = 0;
bytesRequested = 0;
itemsReceived = 0;
bytesReceived = 0;
}

private synchronized String debugString() {
return String.format(
"max budget: %s; "
+ "in-flight budget: %s; "
+ "total budget requested: %s; "
+ "total budget received: %s.",
maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget());
}

/** Consumes the new budget and computes an extension based on the new budget. */
private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) {
maxGetWorkBudget().set(newBudget);
maxGetWorkBudget = newBudget;
return computeBudgetExtension();
}

private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) {
itemsRequested().addAndGet(budgetRequested.items());
bytesRequested().addAndGet(budgetRequested.bytes());
itemsRequested += budgetRequested.items();
bytesRequested += budgetRequested.bytes();
}

private synchronized void recordBudgetReceived(long bytesReceived) {
itemsReceived().incrementAndGet();
bytesReceived().addAndGet(bytesReceived);
private synchronized void recordBudgetReceived(long returnedBudget) {
itemsReceived++;
bytesReceived += returnedBudget;
}

/**
Expand All @@ -351,11 +349,10 @@ private synchronized void recordBudgetReceived(long bytesReceived) {
* without sending too many extension requests.
*/
private synchronized GetWorkBudget computeBudgetExtension() {
GetWorkBudget maxGetWorkBudget = maxGetWorkBudget().get();
// Expected items and bytes can go negative here, since WorkItems returned might be larger
// than the initially requested budget.
long inFlightItems = itemsRequested().get() - itemsReceived().get();
long inFlightBytes = bytesRequested().get() - bytesReceived().get();
long inFlightItems = itemsRequested - itemsReceived;
long inFlightBytes = bytesRequested - bytesReceived;

// Don't send negative budget extensions.
long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes);
Expand All @@ -366,25 +363,19 @@ private synchronized GetWorkBudget computeBudgetExtension() {
: GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build();
}

private GetWorkBudget inFlightBudget() {
private synchronized GetWorkBudget inFlightBudget() {
return GetWorkBudget.builder()
.setItems(itemsRequested().get() - itemsReceived().get())
.setBytes(bytesRequested().get() - bytesReceived().get())
.setItems(itemsRequested - itemsReceived)
.setBytes(bytesRequested - bytesReceived)
.build();
}

private GetWorkBudget totalRequestedBudget() {
return GetWorkBudget.builder()
.setItems(itemsRequested().get())
.setBytes(bytesRequested().get())
.build();
private synchronized GetWorkBudget totalRequestedBudget() {
return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build();
}

private GetWorkBudget totalReceivedBudget() {
return GetWorkBudget.builder()
.setItems(itemsReceived().get())
.setBytes(bytesReceived().get())
.build();
private synchronized GetWorkBudget totalReceivedBudget() {
return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build();
}
}
}
Original file line number | Diff line number | Diff line change
Expand Up @@ -43,11 +43,11 @@ public <T extends GetWorkBudgetSpender> void distributeBudget(
return;
}

GetWorkBudget budgetPerStream = computeDesiredBudgets(budgetSpenders, getWorkBudget);
GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget);
budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream));
}

private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredBudgets(
private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredPerStreamBudget(
ImmutableCollection<T> streams, GetWorkBudget totalGetWorkBudget) {
return GetWorkBudget.builder()
.setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING))
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -214,7 +214,7 @@ public void testStreamsStartCorrectly() throws InterruptedException {

getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());

StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends();

Expand Down Expand Up @@ -286,7 +286,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends();
assertEquals(1, currentBackends.windmillStreams().size());
Set<String> workerTokens =
Expand Down Expand Up @@ -334,19 +334,14 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce
getWorkerMetadataReady.await();

fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
getWorkBudgetDistributor.expectNumDistributions(1);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());

verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any());
}

private void waitForWorkerMetadataToBeConsumed(
TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException {
getWorkBudgetDistributor.waitForBudgetDistribution();
}

private static class GetWorkerMetadataTestStub
extends CloudWindmillMetadataServiceV1Alpha1Grpc
.CloudWindmillMetadataServiceV1Alpha1ImplBase {
Expand Down Expand Up @@ -395,9 +390,8 @@ private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) {
this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected);
}

@SuppressWarnings("ReturnValueIgnored")
private void waitForBudgetDistribution() throws InterruptedException {
getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
private boolean waitForBudgetDistribution() throws InterruptedException {
return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
}

private void expectNumDistributions(int numBudgetDistributionsExpected) {
Expand Down
Loading

0 comments on commit e62c55a

Please sign in to comment.