Skip to content

Commit

Permalink
Tune maximum thread count for streaming dataflow worker executor dyna…
Browse files Browse the repository at this point in the history
…mically. (apache#30439)

Workers will read the StreamingScalingReportResponse from worker messages and
configure the executor pool size based on the specified value.
  • Loading branch information
MelodyShen authored Apr 5, 2024
1 parent 6c280c6 commit f437a78
Show file tree
Hide file tree
Showing 8 changed files with 503 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ class BeamModulePlugin implements Plugin<Project> {
google_api_common : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version
google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20240124-2.0.0", // [bomupgrader] sets version
google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240128-2.0.0", // [bomupgrader] sets version
google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20240113-$google_clients_version",
google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20240218-$google_clients_version",
google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version",
google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version",
google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20240205-2.0.0", // [bomupgrader] sets version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import com.google.api.services.dataflow.model.WorkItemServiceState;
import com.google.api.services.dataflow.model.WorkItemStatus;
import com.google.api.services.dataflow.model.WorkerMessage;
import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -312,7 +313,8 @@ public WorkerMessage createWorkerMessageFromPerWorkerMetrics(PerWorkerMetrics re
* perworkermetrics with this path.
*/
@Override
public void reportWorkerMessage(List<WorkerMessage> messages) throws IOException {
public List<WorkerMessageResponse> reportWorkerMessage(List<WorkerMessage> messages)
throws IOException {
SendWorkerMessagesRequest request =
new SendWorkerMessagesRequest()
.setLocation(options.getRegion())
Expand All @@ -327,6 +329,10 @@ public void reportWorkerMessage(List<WorkerMessage> messages) throws IOException
logger.warn("Worker Message response is null");
throw new IOException("Got null Worker Message response");
}
// Currently no response is expected
if (result.getWorkerMessageResponses() == null) {
logger.debug("Worker Message response is empty.");
return Collections.emptyList();
}
return result.getWorkerMessageResponses();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.api.services.dataflow.model.WorkItemServiceState;
import com.google.api.services.dataflow.model.WorkItemStatus;
import com.google.api.services.dataflow.model.WorkerMessage;
import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -75,6 +76,7 @@ public interface WorkUnitClient {
* perworkermetrics with this path.
*
* @param messages the WorkerMessages to report
* @return a list of {@link WorkerMessageResponse}
*/
void reportWorkerMessage(List<WorkerMessage> messages) throws IOException;
List<WorkerMessageResponse> reportWorkerMessage(List<WorkerMessage> messages) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import com.google.api.services.dataflow.model.PerStepNamespaceMetrics;
import com.google.api.services.dataflow.model.PerWorkerMetrics;
import com.google.api.services.dataflow.model.StreamingScalingReport;
import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
import com.google.api.services.dataflow.model.WorkItemStatus;
import com.google.api.services.dataflow.model.WorkerMessage;
import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -34,6 +36,7 @@
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -70,6 +73,8 @@ public final class StreamingWorkerStatusReporter {
private static final String GLOBAL_WORKER_UPDATE_REPORTER_THREAD = "GlobalWorkerUpdates";

private final boolean publishCounters;
private final int initialMaxThreadCount;
private final int initialMaxBundlesOutstanding;
private final WorkUnitClient dataflowServiceClient;
private final Supplier<Long> windmillQuotaThrottleTime;
private final Supplier<Collection<StageInfo>> allStageInfo;
Expand All @@ -78,6 +83,7 @@ public final class StreamingWorkerStatusReporter {
private final MemoryMonitor memoryMonitor;
private final BoundedQueueExecutor workExecutor;
private final AtomicLong previousTimeAtMaxThreads;
private final AtomicInteger maxThreadCountOverride;
private final ScheduledExecutorService globalWorkerUpdateReporter;
private final ScheduledExecutorService workerMessageReporter;

Expand All @@ -99,7 +105,10 @@ private StreamingWorkerStatusReporter(
this.streamingCounters = streamingCounters;
this.memoryMonitor = memoryMonitor;
this.workExecutor = workExecutor;
this.initialMaxThreadCount = workExecutor.getMaximumPoolSize();
this.initialMaxBundlesOutstanding = workExecutor.maximumElementsOutstanding();
this.previousTimeAtMaxThreads = new AtomicLong();
this.maxThreadCountOverride = new AtomicInteger();
this.globalWorkerUpdateReporter = executorFactory.apply(GLOBAL_WORKER_UPDATE_REPORTER_THREAD);
this.workerMessageReporter = executorFactory.apply(WORKER_MESSAGE_REPORTER_THREAD);
}
Expand Down Expand Up @@ -299,9 +308,12 @@ private void sendWorkerUpdatesToDataflowService(
}
}

private void reportPeriodicWorkerMessage() {
@VisibleForTesting
public void reportPeriodicWorkerMessage() {
try {
dataflowServiceClient.reportWorkerMessage(createWorkerMessage());
List<WorkerMessageResponse> workerMessageResponses =
dataflowServiceClient.reportWorkerMessage(createWorkerMessage());
readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(workerMessageResponses);
} catch (IOException e) {
LOG.warn("Failed to send worker messages", e);
} catch (Exception e) {
Expand Down Expand Up @@ -346,6 +358,47 @@ private Optional<WorkerMessage> createWorkerMessageForPerWorkerMetrics() {
dataflowServiceClient.createWorkerMessageFromPerWorkerMetrics(perWorkerMetrics));
}

/**
 * Applies the maximum-thread-count override carried by a batch of worker message responses.
 *
 * <p>Scans {@code responses} for entries that carry a {@link StreamingScalingReportResponse};
 * when several are present, the last one wins. A valid maximum thread count is stored in
 * {@code maxThreadCountOverride}, and if the effective maximum reported by
 * {@code getMaxThreads()} changed as a result, the work executor is resized to match.
 *
 * <p>A response whose maximum thread count is missing or negative is ignored, so the previously
 * stored override stays in effect.
 *
 * @param responses the responses returned for the reported worker messages
 */
private void readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(
    List<WorkerMessageResponse> responses) {
  Optional<StreamingScalingReportResponse> streamingScalingReportResponse = Optional.empty();
  for (WorkerMessageResponse response : responses) {
    if (response.getStreamingScalingReportResponse() != null) {
      streamingScalingReportResponse = Optional.of(response.getStreamingScalingReportResponse());
    }
  }
  if (!streamingScalingReportResponse.isPresent()) {
    return;
  }

  // The generated API model exposes this field as a boxed Integer, so it may be null when the
  // service omits it. Unboxing it directly would throw a NullPointerException.
  // NOTE(review): if an omitted field is meant to clear the override, reset to 0 here instead
  // of keeping the previous value - confirm against the service contract.
  Integer requestedMaximumThreadCount =
      streamingScalingReportResponse.get().getMaximumThreadCount();
  if (requestedMaximumThreadCount == null) {
    return;
  }
  if (requestedMaximumThreadCount < 0) {
    // Storing a negative override would make getMaxThreads() return a value the executor
    // rejects, and the bad override would stay recorded afterwards.
    LOG.warn(
        "Ignoring invalid maximum thread count {} in StreamingScalingReportResponse",
        requestedMaximumThreadCount);
    return;
  }

  int oldMaximumThreadCount = getMaxThreads();
  maxThreadCountOverride.set(requestedMaximumThreadCount);
  int newMaximumThreadCount = getMaxThreads();
  if (newMaximumThreadCount != oldMaximumThreadCount) {
    LOG.info(
        "Setting maximum thread count to {}, old value is {}",
        newMaximumThreadCount,
        oldMaximumThreadCount);
    workExecutor.setMaximumPoolSize(newMaximumThreadCount, getMaxBundlesOutstanding());
  }
}

/**
 * Returns the effective maximum thread count: the value stored in
 * {@code maxThreadCountOverride} when it is non-zero, otherwise {@code initialMaxThreadCount}.
 */
private int getMaxThreads() {
  final int override = maxThreadCountOverride.get();
  return override == 0 ? initialMaxThreadCount : override;
}

/**
 * Returns the bundle limit to pair with the current maximum thread count.
 *
 * <p>With a non-zero override in {@code maxThreadCountOverride}, the limit is that override plus
 * 100. Without one, a positive {@code initialMaxBundlesOutstanding} is used as-is; otherwise the
 * limit is {@code getMaxThreads()} plus 100.
 */
private int getMaxBundlesOutstanding() {
  final int override = maxThreadCountOverride.get();
  if (override == 0) {
    return initialMaxBundlesOutstanding > 0
        ? initialMaxBundlesOutstanding
        : getMaxThreads() + 100;
  }
  return override + 100;
}

@VisibleForTesting
public void reportPeriodicWorkerUpdates() {
updateVMMetrics();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard;
Expand All @@ -32,15 +32,26 @@
})
public class BoundedQueueExecutor {
private final ThreadPoolExecutor executor;
private final int maximumElementsOutstanding;
private final long maximumBytesOutstanding;
private final int maximumPoolSize;

// Used to guard elementsOutstanding and bytesOutstanding.
private final Monitor monitor = new Monitor();
private int elementsOutstanding = 0;
private long bytesOutstanding = 0;
private final AtomicInteger activeCount = new AtomicInteger();

@GuardedBy("this")
private int maximumElementsOutstanding;

@GuardedBy("this")
private int activeCount;

@GuardedBy("this")
private int maximumPoolSize;

@GuardedBy("this")
private long startTimeMaxActiveThreadsUsed;

@GuardedBy("this")
private long totalTimeMaxActiveThreadsUsed;

public BoundedQueueExecutor(
Expand All @@ -62,8 +73,8 @@ public BoundedQueueExecutor(
@Override
protected void beforeExecute(Thread t, Runnable r) {
super.beforeExecute(t, r);
synchronized (this) {
if (activeCount.getAndIncrement() >= maximumPoolSize - 1) {
synchronized (BoundedQueueExecutor.this) {
if (++activeCount >= maximumPoolSize && startTimeMaxActiveThreadsUsed == 0) {
startTimeMaxActiveThreadsUsed = System.currentTimeMillis();
}
}
Expand All @@ -72,8 +83,8 @@ protected void beforeExecute(Thread t, Runnable r) {
@Override
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
synchronized (this) {
if (activeCount.getAndDecrement() == maximumPoolSize) {
synchronized (BoundedQueueExecutor.this) {
if (--activeCount < maximumPoolSize && startTimeMaxActiveThreadsUsed > 0) {
totalTimeMaxActiveThreadsUsed +=
(System.currentTimeMillis() - startTimeMaxActiveThreadsUsed);
startTimeMaxActiveThreadsUsed = 0;
Expand All @@ -95,16 +106,31 @@ public void execute(Runnable work, long workBytes) {
public boolean isSatisfied() {
return elementsOutstanding == 0
|| (bytesAvailable() >= workBytes
&& elementsOutstanding < maximumElementsOutstanding);
&& elementsOutstanding < maximumElementsOutstanding());
}
});
executeLockHeld(work, workBytes);
executeMonitorHeld(work, workBytes);
}

// Forcibly add something to the queue, ignoring the length limit.
public void forceExecute(Runnable work, long workBytes) {
monitor.enter();
executeLockHeld(work, workBytes);
executeMonitorHeld(work, workBytes);
}

/**
 * Sets both the core and the maximum pool size of the underlying executor to
 * {@code maximumPoolSize}, and records the new pool size and element limit on this object.
 *
 * @param maximumPoolSize the new core and maximum pool size; must be positive
 * @param maximumElementsOutstanding the new value reported by
 *     {@code maximumElementsOutstanding()}
 * @throws IllegalArgumentException if {@code maximumPoolSize} is not positive
 */
public synchronized void setMaximumPoolSize(int maximumPoolSize, int maximumElementsOutstanding) {
  // Reject invalid sizes before touching the executor. Otherwise a size of 0 would be accepted
  // as the core pool size and then rejected as the maximum, leaving the executor half-updated.
  if (maximumPoolSize <= 0) {
    throw new IllegalArgumentException(
        "maximumPoolSize must be positive, got " + maximumPoolSize);
  }
  // For ThreadPoolExecutor, the maximum pool size should always be greater than or equal to the
  // core pool size, so grow the maximum first when increasing and shrink the core first when
  // decreasing.
  if (maximumPoolSize > executor.getCorePoolSize()) {
    executor.setMaximumPoolSize(maximumPoolSize);
    executor.setCorePoolSize(maximumPoolSize);
  } else {
    executor.setCorePoolSize(maximumPoolSize);
    executor.setMaximumPoolSize(maximumPoolSize);
  }
  this.maximumPoolSize = maximumPoolSize;
  this.maximumElementsOutstanding = maximumElementsOutstanding;
}

public void shutdown() throws InterruptedException {
Expand All @@ -118,31 +144,41 @@ public boolean executorQueueIsEmpty() {
return executor.getQueue().isEmpty();
}

public long allThreadsActiveTime() {
public synchronized long allThreadsActiveTime() {
return totalTimeMaxActiveThreadsUsed;
}

public int activeCount() {
return activeCount.intValue();
public synchronized int activeCount() {
return activeCount;
}

public long bytesOutstanding() {
return bytesOutstanding;
monitor.enter();
try {
return bytesOutstanding;
} finally {
monitor.leave();
}
}

public int elementsOutstanding() {
return elementsOutstanding;
monitor.enter();
try {
return elementsOutstanding;
} finally {
monitor.leave();
}
}

/** Returns the value of the {@code maximumBytesOutstanding} field. */
public long maximumBytesOutstanding() {
  return this.maximumBytesOutstanding;
}

public int maximumElementsOutstanding() {
public synchronized int maximumElementsOutstanding() {
return maximumElementsOutstanding;
}

public final int getMaximumPoolSize() {
public synchronized int getMaximumPoolSize() {
return maximumPoolSize;
}

Expand All @@ -163,7 +199,7 @@ public String summaryHtml() {
builder.append("Work Queue Size: ");
builder.append(elementsOutstanding);
builder.append("/");
builder.append(maximumElementsOutstanding);
builder.append(maximumElementsOutstanding());
builder.append("<br>/n");

builder.append("Work Queue Bytes: ");
Expand All @@ -178,7 +214,7 @@ public String summaryHtml() {
}
}

private void executeLockHeld(Runnable work, long workBytes) {
private void executeMonitorHeld(Runnable work, long workBytes) {
bytesOutstanding += workBytes;
++elementsOutstanding;
monitor.leave();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@
import com.google.api.services.dataflow.model.SendWorkerMessagesResponse;
import com.google.api.services.dataflow.model.SeqMapTask;
import com.google.api.services.dataflow.model.StreamingScalingReport;
import com.google.api.services.dataflow.model.StreamingScalingReportResponse;
import com.google.api.services.dataflow.model.WorkItem;
import com.google.api.services.dataflow.model.WorkerMessage;
import com.google.api.services.dataflow.model.WorkerMessageResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC;
Expand Down Expand Up @@ -253,6 +256,12 @@ public void testReportWorkerMessage_streamingScalingReport() throws Exception {
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.setContentType(Json.MEDIA_TYPE);
SendWorkerMessagesResponse workerMessage = new SendWorkerMessagesResponse();
StreamingScalingReportResponse streamingScalingReportResponse =
new StreamingScalingReportResponse().setMaximumThreadCount(10);
WorkerMessageResponse workerMessageResponse =
new WorkerMessageResponse()
.setStreamingScalingReportResponse(streamingScalingReportResponse);
workerMessage.setWorkerMessageResponses(Collections.singletonList(workerMessageResponse));
workerMessage.setFactory(Transport.getJsonFactory());
response.setContent(workerMessage.toPrettyString());

Expand All @@ -271,12 +280,14 @@ public void testReportWorkerMessage_streamingScalingReport() throws Exception {
.setMaximumBundleCount(5)
.setMaximumBytes(6L);
WorkerMessage msg = client.createWorkerMessageFromStreamingScalingReport(activeThreadsReport);
client.reportWorkerMessage(Collections.singletonList(msg));
List<WorkerMessageResponse> responses =
client.reportWorkerMessage(Collections.singletonList(msg));

SendWorkerMessagesRequest actualRequest =
Transport.getJsonFactory()
.fromString(request.getContentAsString(), SendWorkerMessagesRequest.class);
assertEquals(ImmutableList.of(msg), actualRequest.getWorkerMessages());
assertEquals(ImmutableList.of(workerMessageResponse), responses);
}

@Test
Expand Down
Loading

0 comments on commit f437a78

Please sign in to comment.