diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 8be8d73fbcb8..d4a163256413 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -732,7 +732,7 @@ class BeamModulePlugin implements Plugin { google_api_common : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20240124-2.0.0", // [bomupgrader] sets version google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240128-2.0.0", // [bomupgrader] sets version - google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20240113-$google_clients_version", + google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20240218-$google_clients_version", google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version", google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version", google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20240205-2.0.0", // [bomupgrader] sets version diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java index f3caa8d0f3ac..af8e7dd50c95 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClient.java @@ -39,6 +39,7 @@ import 
com.google.api.services.dataflow.model.WorkItemServiceState; import com.google.api.services.dataflow.model.WorkItemStatus; import com.google.api.services.dataflow.model.WorkerMessage; +import com.google.api.services.dataflow.model.WorkerMessageResponse; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -312,7 +313,8 @@ public WorkerMessage createWorkerMessageFromPerWorkerMetrics(PerWorkerMetrics re * perworkermetrics with this path. */ @Override - public void reportWorkerMessage(List messages) throws IOException { + public List reportWorkerMessage(List messages) + throws IOException { SendWorkerMessagesRequest request = new SendWorkerMessagesRequest() .setLocation(options.getRegion()) @@ -327,6 +329,10 @@ public void reportWorkerMessage(List messages) throws IOException logger.warn("Worker Message response is null"); throw new IOException("Got null Worker Message response"); } - // Currently no response is expected + if (result.getWorkerMessageResponses() == null) { + logger.debug("Worker Message response is empty."); + return Collections.emptyList(); + } + return result.getWorkerMessageResponses(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java index d75d91d00885..26b1dc55ead9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkUnitClient.java @@ -23,6 +23,7 @@ import com.google.api.services.dataflow.model.WorkItemServiceState; import com.google.api.services.dataflow.model.WorkItemStatus; import com.google.api.services.dataflow.model.WorkerMessage; +import com.google.api.services.dataflow.model.WorkerMessageResponse; import java.io.IOException; import 
java.util.List; import java.util.Optional; @@ -75,6 +76,7 @@ public interface WorkUnitClient { * perworkermetrics with this path. * * @param msg the WorkerMessages to report + * @return a list of {@link WorkerMessageResponse} */ - void reportWorkerMessage(List messages) throws IOException; + List reportWorkerMessage(List messages) throws IOException; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java index 409f0337eebd..8e950546ae68 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java @@ -23,8 +23,10 @@ import com.google.api.services.dataflow.model.PerStepNamespaceMetrics; import com.google.api.services.dataflow.model.PerWorkerMetrics; import com.google.api.services.dataflow.model.StreamingScalingReport; +import com.google.api.services.dataflow.model.StreamingScalingReportResponse; import com.google.api.services.dataflow.model.WorkItemStatus; import com.google.api.services.dataflow.model.WorkerMessage; +import com.google.api.services.dataflow.model.WorkerMessageResponse; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -34,6 +36,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.function.Supplier; @@ -70,6 +73,8 @@ public final class StreamingWorkerStatusReporter { private 
static final String GLOBAL_WORKER_UPDATE_REPORTER_THREAD = "GlobalWorkerUpdates"; private final boolean publishCounters; + private final int initialMaxThreadCount; + private final int initialMaxBundlesOutstanding; private final WorkUnitClient dataflowServiceClient; private final Supplier windmillQuotaThrottleTime; private final Supplier> allStageInfo; @@ -78,6 +83,7 @@ public final class StreamingWorkerStatusReporter { private final MemoryMonitor memoryMonitor; private final BoundedQueueExecutor workExecutor; private final AtomicLong previousTimeAtMaxThreads; + private final AtomicInteger maxThreadCountOverride; private final ScheduledExecutorService globalWorkerUpdateReporter; private final ScheduledExecutorService workerMessageReporter; @@ -99,7 +105,10 @@ private StreamingWorkerStatusReporter( this.streamingCounters = streamingCounters; this.memoryMonitor = memoryMonitor; this.workExecutor = workExecutor; + this.initialMaxThreadCount = workExecutor.getMaximumPoolSize(); + this.initialMaxBundlesOutstanding = workExecutor.maximumElementsOutstanding(); this.previousTimeAtMaxThreads = new AtomicLong(); + this.maxThreadCountOverride = new AtomicInteger(); this.globalWorkerUpdateReporter = executorFactory.apply(GLOBAL_WORKER_UPDATE_REPORTER_THREAD); this.workerMessageReporter = executorFactory.apply(WORKER_MESSAGE_REPORTER_THREAD); } @@ -299,9 +308,12 @@ private void sendWorkerUpdatesToDataflowService( } } - private void reportPeriodicWorkerMessage() { + @VisibleForTesting + public void reportPeriodicWorkerMessage() { try { - dataflowServiceClient.reportWorkerMessage(createWorkerMessage()); + List workerMessageResponses = + dataflowServiceClient.reportWorkerMessage(createWorkerMessage()); + readAndSaveWorkerMessageResponseForStreamingScalingReportResponse(workerMessageResponses); } catch (IOException e) { LOG.warn("Failed to send worker messages", e); } catch (Exception e) { @@ -346,6 +358,47 @@ private Optional createWorkerMessageForPerWorkerMetrics() { 
dataflowServiceClient.createWorkerMessageFromPerWorkerMetrics(perWorkerMetrics)); } + private void readAndSaveWorkerMessageResponseForStreamingScalingReportResponse( + List responses) { + Optional streamingScalingReportResponse = Optional.empty(); + for (WorkerMessageResponse response : responses) { + if (response.getStreamingScalingReportResponse() != null) { + streamingScalingReportResponse = Optional.of(response.getStreamingScalingReportResponse()); + } + } + if (streamingScalingReportResponse.isPresent()) { + int oldMaximumThreadCount = getMaxThreads(); + maxThreadCountOverride.set(streamingScalingReportResponse.get().getMaximumThreadCount()); + int newMaximumThreadCount = getMaxThreads(); + if (newMaximumThreadCount != oldMaximumThreadCount) { + LOG.info( + "Setting maximum thread count to {}, old value is {}", + newMaximumThreadCount, + oldMaximumThreadCount); + workExecutor.setMaximumPoolSize(newMaximumThreadCount, getMaxBundlesOutstanding()); + } + } + } + + private int getMaxThreads() { + int currentMaxThreadCountOverride = maxThreadCountOverride.get(); + if (currentMaxThreadCountOverride != 0) { + return currentMaxThreadCountOverride; + } + return initialMaxThreadCount; + } + + private int getMaxBundlesOutstanding() { + int currentMaxThreadCountOverride = maxThreadCountOverride.get(); + if (currentMaxThreadCountOverride != 0) { + return currentMaxThreadCountOverride + 100; + } + if (initialMaxBundlesOutstanding > 0) { + return initialMaxBundlesOutstanding; + } + return getMaxThreads() + 100; + } + @VisibleForTesting public void reportPeriodicWorkerUpdates() { updateVMMetrics(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index f7f6fd91a8c8..9a4811693500 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -21,7 +21,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard; @@ -32,15 +32,26 @@ }) public class BoundedQueueExecutor { private final ThreadPoolExecutor executor; - private final int maximumElementsOutstanding; private final long maximumBytesOutstanding; - private final int maximumPoolSize; + // Used to guard elementsOutstanding and bytesOutstanding. 
private final Monitor monitor = new Monitor(); private int elementsOutstanding = 0; private long bytesOutstanding = 0; - private final AtomicInteger activeCount = new AtomicInteger(); + + @GuardedBy("this") + private int maximumElementsOutstanding; + + @GuardedBy("this") + private int activeCount; + + @GuardedBy("this") + private int maximumPoolSize; + + @GuardedBy("this") private long startTimeMaxActiveThreadsUsed; + + @GuardedBy("this") private long totalTimeMaxActiveThreadsUsed; public BoundedQueueExecutor( @@ -62,8 +73,8 @@ public BoundedQueueExecutor( @Override protected void beforeExecute(Thread t, Runnable r) { super.beforeExecute(t, r); - synchronized (this) { - if (activeCount.getAndIncrement() >= maximumPoolSize - 1) { + synchronized (BoundedQueueExecutor.this) { + if (++activeCount >= maximumPoolSize && startTimeMaxActiveThreadsUsed == 0) { startTimeMaxActiveThreadsUsed = System.currentTimeMillis(); } } @@ -72,8 +83,8 @@ protected void beforeExecute(Thread t, Runnable r) { @Override protected void afterExecute(Runnable r, Throwable t) { super.afterExecute(r, t); - synchronized (this) { - if (activeCount.getAndDecrement() == maximumPoolSize) { + synchronized (BoundedQueueExecutor.this) { + if (--activeCount < maximumPoolSize && startTimeMaxActiveThreadsUsed > 0) { totalTimeMaxActiveThreadsUsed += (System.currentTimeMillis() - startTimeMaxActiveThreadsUsed); startTimeMaxActiveThreadsUsed = 0; @@ -95,16 +106,31 @@ public void execute(Runnable work, long workBytes) { public boolean isSatisfied() { return elementsOutstanding == 0 || (bytesAvailable() >= workBytes - && elementsOutstanding < maximumElementsOutstanding); + && elementsOutstanding < maximumElementsOutstanding()); } }); - executeLockHeld(work, workBytes); + executeMonitorHeld(work, workBytes); } // Forcibly add something to the queue, ignoring the length limit. 
public void forceExecute(Runnable work, long workBytes) { monitor.enter(); - executeLockHeld(work, workBytes); + executeMonitorHeld(work, workBytes); + } + + // Set the maximum/core pool size of the executor. + public synchronized void setMaximumPoolSize(int maximumPoolSize, int maximumElementsOutstanding) { + // For ThreadPoolExecutor, the maximum pool size should always greater than or equal to core + // pool size. + if (maximumPoolSize > executor.getCorePoolSize()) { + executor.setMaximumPoolSize(maximumPoolSize); + executor.setCorePoolSize(maximumPoolSize); + } else { + executor.setCorePoolSize(maximumPoolSize); + executor.setMaximumPoolSize(maximumPoolSize); + } + this.maximumPoolSize = maximumPoolSize; + this.maximumElementsOutstanding = maximumElementsOutstanding; } public void shutdown() throws InterruptedException { @@ -118,31 +144,41 @@ public boolean executorQueueIsEmpty() { return executor.getQueue().isEmpty(); } - public long allThreadsActiveTime() { + public synchronized long allThreadsActiveTime() { return totalTimeMaxActiveThreadsUsed; } - public int activeCount() { - return activeCount.intValue(); + public synchronized int activeCount() { + return activeCount; } public long bytesOutstanding() { - return bytesOutstanding; + monitor.enter(); + try { + return bytesOutstanding; + } finally { + monitor.leave(); + } } public int elementsOutstanding() { - return elementsOutstanding; + monitor.enter(); + try { + return elementsOutstanding; + } finally { + monitor.leave(); + } } public long maximumBytesOutstanding() { return maximumBytesOutstanding; } - public int maximumElementsOutstanding() { + public synchronized int maximumElementsOutstanding() { return maximumElementsOutstanding; } - public final int getMaximumPoolSize() { + public synchronized int getMaximumPoolSize() { return maximumPoolSize; } @@ -163,7 +199,7 @@ public String summaryHtml() { builder.append("Work Queue Size: "); builder.append(elementsOutstanding); builder.append("/"); - 
builder.append(maximumElementsOutstanding); + builder.append(maximumElementsOutstanding()); builder.append("<br>
/n"); builder.append("Work Queue Bytes: "); @@ -178,7 +214,7 @@ public String summaryHtml() { } } - private void executeLockHeld(Runnable work, long workBytes) { + private void executeMonitorHeld(Runnable work, long workBytes) { bytesOutstanding += workBytes; ++elementsOutstanding; monitor.leave(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java index fac56890f498..85d79e6be3c1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DataflowWorkUnitClientTest.java @@ -34,10 +34,13 @@ import com.google.api.services.dataflow.model.SendWorkerMessagesResponse; import com.google.api.services.dataflow.model.SeqMapTask; import com.google.api.services.dataflow.model.StreamingScalingReport; +import com.google.api.services.dataflow.model.StreamingScalingReportResponse; import com.google.api.services.dataflow.model.WorkItem; import com.google.api.services.dataflow.model.WorkerMessage; +import com.google.api.services.dataflow.model.WorkerMessageResponse; import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; @@ -253,6 +256,12 @@ public void testReportWorkerMessage_streamingScalingReport() throws Exception { MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); response.setContentType(Json.MEDIA_TYPE); SendWorkerMessagesResponse workerMessage = new SendWorkerMessagesResponse(); + StreamingScalingReportResponse streamingScalingReportResponse = + new 
StreamingScalingReportResponse().setMaximumThreadCount(10); + WorkerMessageResponse workerMessageResponse = + new WorkerMessageResponse() + .setStreamingScalingReportResponse(streamingScalingReportResponse); + workerMessage.setWorkerMessageResponses(Collections.singletonList(workerMessageResponse)); workerMessage.setFactory(Transport.getJsonFactory()); response.setContent(workerMessage.toPrettyString()); @@ -271,12 +280,14 @@ public void testReportWorkerMessage_streamingScalingReport() throws Exception { .setMaximumBundleCount(5) .setMaximumBytes(6L); WorkerMessage msg = client.createWorkerMessageFromStreamingScalingReport(activeThreadsReport); - client.reportWorkerMessage(Collections.singletonList(msg)); + List responses = + client.reportWorkerMessage(Collections.singletonList(msg)); SendWorkerMessagesRequest actualRequest = Transport.getJsonFactory() .fromString(request.getContentAsString(), SendWorkerMessagesRequest.class); assertEquals(ImmutableList.of(msg), actualRequest.getWorkerMessages()); + assertEquals(ImmutableList.of(workerMessageResponse), responses); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java new file mode 100644 index 000000000000..bdf0f0031d69 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.services.dataflow.model.StreamingScalingReportResponse; +import com.google.api.services.dataflow.model.WorkerMessageResponse; +import java.util.Collections; +import java.util.concurrent.Executors; +import org.apache.beam.runners.dataflow.worker.WorkUnitClient; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class StreamingWorkerStatusReporterTest { + private final long DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000; + + private BoundedQueueExecutor mockExecutor; + private WorkUnitClient mockWorkUnitClient; + private FailureTracker mockFailureTracker; + private MemoryMonitor mockMemoryMonitor; + + @Before + public void setUp() { + this.mockExecutor = mock(BoundedQueueExecutor.class); + this.mockWorkUnitClient = 
mock(WorkUnitClient.class); + this.mockFailureTracker = mock(FailureTracker.class); + this.mockMemoryMonitor = mock(MemoryMonitor.class); + } + + @Test + public void testOverrideMaximumThreadCount() throws Exception { + StreamingWorkerStatusReporter reporter = + StreamingWorkerStatusReporter.forTesting( + true, + mockWorkUnitClient, + () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME, + () -> Collections.emptyList(), + mockFailureTracker, + StreamingCounters.create(), + mockMemoryMonitor, + mockExecutor, + (threadName) -> Executors.newSingleThreadScheduledExecutor()); + StreamingScalingReportResponse streamingScalingReportResponse = + new StreamingScalingReportResponse().setMaximumThreadCount(10); + WorkerMessageResponse workerMessageResponse = + new WorkerMessageResponse() + .setStreamingScalingReportResponse(streamingScalingReportResponse); + when(mockWorkUnitClient.reportWorkerMessage(any())) + .thenReturn(Collections.singletonList(workerMessageResponse)); + reporter.reportPeriodicWorkerMessage(); + verify(mockExecutor).setMaximumPoolSize(10, 110); + } + + @Test + public void testHandleEmptyWorkerMessageResponse() throws Exception { + StreamingWorkerStatusReporter reporter = + StreamingWorkerStatusReporter.forTesting( + true, + mockWorkUnitClient, + () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME, + () -> Collections.emptyList(), + mockFailureTracker, + StreamingCounters.create(), + mockMemoryMonitor, + mockExecutor, + (threadName) -> Executors.newSingleThreadScheduledExecutor()); + WorkerMessageResponse workerMessageResponse = new WorkerMessageResponse(); + when(mockWorkUnitClient.reportWorkerMessage(any())) + .thenReturn(Collections.singletonList(workerMessageResponse)); + reporter.reportPeriodicWorkerMessage(); + verify(mockExecutor, Mockito.times(0)).setMaximumPoolSize(anyInt(), anyInt()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java new file mode 100644 index 000000000000..c0620952ef9e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.util; + +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor}. 
*/ +@RunWith(JUnit4.class) +// TODO(https://github.com/apache/beam/issues/21230): Remove when new version of errorprone is +// released (2.11.0) +@SuppressWarnings("unused") +public class BoundedQueueExecutorTest { + @Rule public transient Timeout globalTimeout = Timeout.seconds(300); + private static final long MAXIMUM_BYTES_OUTSTANDING = 10000000; + private static final int DEFAULT_MAX_THREADS = 2; + private static final int DEFAULT_THREAD_EXPIRATION_SEC = 60; + + private BoundedQueueExecutor executor; + + private Runnable createSleepProcessWorkFn(CountDownLatch start, CountDownLatch stop) { + Runnable runnable = + () -> { + start.countDown(); + try { + stop.await(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }; + return runnable; + } + + @Before + public void setUp() { + this.executor = + new BoundedQueueExecutor( + DEFAULT_MAX_THREADS, + DEFAULT_THREAD_EXPIRATION_SEC, + TimeUnit.SECONDS, + DEFAULT_MAX_THREADS + 100, + MAXIMUM_BYTES_OUTSTANDING, + new ThreadFactoryBuilder() + .setNameFormat("DataflowWorkUnits-%d") + .setDaemon(true) + .build()); + } + + @Test + public void testScheduleWorkWhenExceedMaximumPoolSize() throws Exception { + CountDownLatch processStart1 = new CountDownLatch(1); + CountDownLatch processStop1 = new CountDownLatch(1); + CountDownLatch processStart2 = new CountDownLatch(1); + CountDownLatch processStop2 = new CountDownLatch(1); + CountDownLatch processStart3 = new CountDownLatch(1); + CountDownLatch processStop3 = new CountDownLatch(1); + Runnable m1 = createSleepProcessWorkFn(processStart1, processStop1); + Runnable m2 = createSleepProcessWorkFn(processStart2, processStop2); + Runnable m3 = createSleepProcessWorkFn(processStart3, processStop3); + + executor.execute(m1, 1); + processStart1.await(); + executor.execute(m2, 1); + processStart2.await(); + // m1 and m2 have started and all threads are occupied so m3 will be queued and not executed. 
+ executor.execute(m3, 1); + assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS)); + assertFalse(executor.executorQueueIsEmpty()); + + // Stop m1 so there is an available thread for m3 to run. + processStop1.countDown(); + processStart3.await(); + // m3 started. + assertTrue(executor.executorQueueIsEmpty()); + processStop2.countDown(); + processStop3.countDown(); + executor.shutdown(); + } + + @Test + public void testScheduleWorkWhenExceedMaximumBytesOutstanding() throws Exception { + CountDownLatch processStart1 = new CountDownLatch(1); + CountDownLatch processStop1 = new CountDownLatch(1); + CountDownLatch processStart2 = new CountDownLatch(1); + CountDownLatch processStop2 = new CountDownLatch(1); + Runnable m1 = createSleepProcessWorkFn(processStart1, processStop1); + Runnable m2 = createSleepProcessWorkFn(processStart2, processStop2); + + executor.execute(m1, 10000000); + processStart1.await(); + // m1 has started and reached the maximumBytesOutstanding. Though the executor has available + // threads, the new task will be blocked until the bytes are available. + // Start a new thread since executor.execute() is a blocking function. + Thread m2Runner = + new Thread( + () -> { + executor.execute(m2, 1000); + }); + m2Runner.start(); + assertFalse(processStart2.await(1000, TimeUnit.MILLISECONDS)); + // m2 will wait for monitor instead of being queued. + assertEquals(Thread.State.WAITING, m2Runner.getState()); + assertTrue(executor.executorQueueIsEmpty()); + + // Stop m1 so there are available bytes for m2 to run. + processStop1.countDown(); + processStart2.await(); + // m2 started. 
+ assertEquals(Thread.State.TERMINATED, m2Runner.getState()); + processStop2.countDown(); + executor.shutdown(); + } + + @Test + public void testOverrideMaximumPoolSize() throws Exception { + CountDownLatch processStart1 = new CountDownLatch(1); + CountDownLatch processStart2 = new CountDownLatch(1); + CountDownLatch processStart3 = new CountDownLatch(1); + CountDownLatch stop = new CountDownLatch(1); + Runnable m1 = createSleepProcessWorkFn(processStart1, stop); + Runnable m2 = createSleepProcessWorkFn(processStart2, stop); + Runnable m3 = createSleepProcessWorkFn(processStart3, stop); + + // Initial state. + assertEquals(0, executor.activeCount()); + assertEquals(2, executor.getMaximumPoolSize()); + + // m1 and m2 are accepted. + executor.execute(m1, 1); + processStart1.await(); + assertEquals(1, executor.activeCount()); + executor.execute(m2, 1); + processStart2.await(); + assertEquals(2, executor.activeCount()); + + // Max pool size was reached so new work is queued. + executor.execute(m3, 1); + assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS)); + + // Increase the max thread count + executor.setMaximumPoolSize(3, 103); + assertEquals(3, executor.getMaximumPoolSize()); + + // m3 is accepted + processStart3.await(); + assertEquals(3, executor.activeCount()); + + stop.countDown(); + executor.shutdown(); + } + + @Test + public void testRecordTotalTimeMaxActiveThreadsUsed() throws Exception { + CountDownLatch processStart1 = new CountDownLatch(1); + CountDownLatch processStart2 = new CountDownLatch(1); + CountDownLatch processStart3 = new CountDownLatch(1); + CountDownLatch stop = new CountDownLatch(1); + Runnable m1 = createSleepProcessWorkFn(processStart1, stop); + Runnable m2 = createSleepProcessWorkFn(processStart2, stop); + Runnable m3 = createSleepProcessWorkFn(processStart3, stop); + + // Initial state. + assertEquals(0, executor.activeCount()); + assertEquals(2, executor.getMaximumPoolSize()); + + // m1 and m2 are accepted. 
+ executor.execute(m1, 1); + processStart1.await(); + assertEquals(1, executor.activeCount()); + executor.execute(m2, 1); + processStart2.await(); + assertEquals(2, executor.activeCount()); + + // Max pool size was reached so no new work is accepted. + executor.execute(m3, 1); + assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS)); + + assertEquals(0l, executor.allThreadsActiveTime()); + stop.countDown(); + while (executor.activeCount() != 0) { + // Waiting for all threads to be ended. + Thread.sleep(200); + } + // Max pool size was reached so the allThreadsActiveTime() was updated. + assertThat(executor.allThreadsActiveTime(), greaterThan(0l)); + + executor.shutdown(); + } + + @Test + public void testRecordTotalTimeMaxActiveThreadsUsedWhenMaximumPoolSizeUpdated() throws Exception { + CountDownLatch processStart1 = new CountDownLatch(1); + CountDownLatch processStart2 = new CountDownLatch(1); + CountDownLatch processStart3 = new CountDownLatch(1); + CountDownLatch stop = new CountDownLatch(1); + Runnable m1 = createSleepProcessWorkFn(processStart1, stop); + Runnable m2 = createSleepProcessWorkFn(processStart2, stop); + Runnable m3 = createSleepProcessWorkFn(processStart3, stop); + + // Initial state. + assertEquals(0, executor.activeCount()); + assertEquals(2, executor.getMaximumPoolSize()); + + // m1 and m2 are accepted. + executor.execute(m1, 1); + processStart1.await(); + assertEquals(1, executor.activeCount()); + executor.execute(m2, 1); + processStart2.await(); + assertEquals(2, executor.activeCount()); + + // Max pool size was reached so no new work is accepted. + executor.execute(m3, 1); + assertFalse(processStart3.await(1000, TimeUnit.MILLISECONDS)); + + assertEquals(0l, executor.allThreadsActiveTime()); + // Increase the max thread count + executor.setMaximumPoolSize(5, 105); + stop.countDown(); + while (executor.activeCount() != 0) { + // Waiting for all threads to be ended. 
+ Thread.sleep(200); + } + // Max pool size was updated during execution but allThreadsActiveTime() was still recorded + // for the thread which reached the old max pool size. + assertThat(executor.allThreadsActiveTime(), greaterThan(0l)); + + executor.shutdown(); + } + + @Test + public void testRenderSummaryHtml() throws Exception { + String expectedSummaryHtml = + "Worker Threads: 0/2
<br>/n" + + "Active Threads: 0
<br>/n" + + "Work Queue Size: 0/102
<br>/n" + + "Work Queue Bytes: 0/10000000
/n"; + assertEquals(expectedSummaryHtml, executor.summaryHtml()); + } +}