diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ff72add83e4d..6ce60283735f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -65,6 +65,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; @@ -199,6 +200,7 @@ private StreamingDataflowWorker( this.workCommitter = windmillServiceEnabled ? StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory( WindmillStreamPool.create( numCommitThreads, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java index f2893f3e7191..5f039be7b00f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java @@ -18,33 +18,24 @@ package org.apache.beam.runners.dataflow.worker.streaming; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.function.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; -/** Bounded set of queues, with a maximum total weight. */ +/** Queue bounded by a {@link WeightedSemaphore}. */ public final class WeightedBoundedQueue { private final LinkedBlockingQueue queue; - private final int maxWeight; - private final Semaphore limit; - private final Function weigher; + private final WeightedSemaphore weightedSemaphore; private WeightedBoundedQueue( - LinkedBlockingQueue linkedBlockingQueue, - int maxWeight, - Semaphore limit, - Function weigher) { + LinkedBlockingQueue linkedBlockingQueue, WeightedSemaphore weightedSemaphore) { this.queue = linkedBlockingQueue; - this.maxWeight = maxWeight; - this.limit = limit; - this.weigher = weigher; + this.weightedSemaphore = weightedSemaphore; } - public static WeightedBoundedQueue create(int maxWeight, Function weigherFn) { - return new WeightedBoundedQueue<>( - new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn); + public static WeightedBoundedQueue create(WeightedSemaphore weightedSemaphore) { + return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weightedSemaphore); } /** @@ -52,15 +43,15 @@ public static WeightedBoundedQueue create(int maxWeight, Function { + private final int maxWeight; + private final Semaphore limit; + private final Function weigher; + + private WeightedSemaphore(int maxWeight, Semaphore limit, Function weigher) { + this.maxWeight = maxWeight; + this.limit = limit; + this.weigher = weigher; + } + + public static WeightedSemaphore create(int maxWeight, Function weigherFn) { + return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn); + } + + public void acquireUninterruptibly(V value) { + limit.acquireUninterruptibly(computePermits(value)); + } + + public void release(V value) { + limit.release(computePermits(value)); + } + + private int computePermits(V value) { + return Math.min(weigher.apply(value), maxWeight); + } + + public int currentWeight() { + return maxWeight - limit.availablePermits(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java new file mode 100644 index 000000000000..498e90f78e29 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** Utility class for commits. */ +@Internal +public final class Commits { + + /** Max bytes of commits queued on the user worker. */ + @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB + + private Commits() {} + + public static WeightedSemaphore maxCommitByteSemaphore() { + return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 6889764afe69..20b95b0661d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -42,7 +42,6 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class); private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private final Consumer commitWorkFn; private final WeightedBoundedQueue commitQueue; @@ -53,9 +52,7 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private StreamingApplianceWorkCommitter( Consumer commitWorkFn, Consumer onCommitComplete) { this.commitWorkFn = commitWorkFn; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore()); this.commitWorkers = Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() @@ -73,10 +70,9 @@ public static StreamingApplianceWorkCommitter create( } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { if (!commitWorkers.isShutdown()) { - commitWorkers.submit(this::commitLoop); + commitWorkers.execute(this::commitLoop); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index bf1007bc4bfb..85fa1d67c6c3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; @@ -46,7 +47,6 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private static final String NO_BACKEND_WORKER_TOKEN = ""; private final Supplier> commitWorkStreamFactory; @@ -61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { Supplier> commitWorkStreamFactory, int numCommitSenders, Consumer onCommitComplete, - String backendWorkerToken) { + String backendWorkerToken, + WeightedSemaphore commitByteSemaphore) { this.commitWorkStreamFactory = commitWorkStreamFactory; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore); this.commitSenders = Executors.newFixedThreadPool( numCommitSenders, @@ -90,12 +89,11 @@ public static Builder builder() { } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { Preconditions.checkState( isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); + commitSenders.execute(this::streamingCommitLoop); } } @@ -166,6 +164,8 @@ private void streamingCommitLoop() { return; } } + + // take() blocks until a value is available in the commitQueue. Preconditions.checkNotNull(initialCommit); if (initialCommit.work().isFailed()) { @@ -258,6 +258,8 @@ public interface Builder { Builder setCommitWorkStreamFactory( Supplier> commitWorkStreamFactory); + Builder setCommitByteSemaphore(WeightedSemaphore commitByteSemaphore); + Builder setNumCommitSenders(int numCommitSenders); Builder setOnCommitComplete(Consumer onCommitComplete); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java index 4f035c88774c..c71001fbeee7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -30,27 +31,29 @@ @RunWith(JUnit4.class) public class WeightBoundedQueueTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final int MAX_WEIGHT = 10; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Test public void testPut_hasCapacity() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue = 1; queue.put(insertedValue); - assertEquals(insertedValue, queue.queuedElementsWeight()); + assertEquals(insertedValue, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); assertEquals(insertedValue, (int) queue.poll()); } @Test public void testPut_noCapacity() throws InterruptedException { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); // Insert value that takes all the capacity into the queue. queue.put(MAX_WEIGHT); @@ -71,7 +74,7 @@ public void testPut_noCapacity() throws InterruptedException { // Should only see the first value in the queue, since the queue is at capacity. thread2 // should be blocked. - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); // Poll the queue, pulling off the only value inside and freeing up the capacity in the queue. @@ -80,14 +83,15 @@ public void testPut_noCapacity() throws InterruptedException { // Wait for the putThread which was previously blocked due to the queue being at capacity. putThread.join(); - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); } @Test public void testPoll() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue1 = 1; int insertedValue2 = 2; @@ -95,7 +99,7 @@ public void testPoll() { queue.put(insertedValue1); queue.put(insertedValue2); - assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight()); + assertEquals(insertedValue1 + insertedValue2, weightedSemaphore.currentWeight()); assertEquals(2, queue.size()); assertEquals(insertedValue1, (int) queue.poll()); assertEquals(1, queue.size()); @@ -104,7 +108,8 @@ public void testPoll() { @Test public void testPoll_withTimeout() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int pollWaitTimeMillis = 10000; int insertedValue1 = 1; @@ -132,7 +137,8 @@ public void testPoll_withTimeout() throws InterruptedException { @Test public void testPoll_withTimeout_timesOut() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int defaultPollResult = -10; int pollWaitTimeMillis = 100; int insertedValue1 = 1; @@ -144,13 +150,17 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { Thread pollThread = new Thread( () -> { - int polled; + @Nullable Integer polled; try { polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS); - pollResult.set(polled); + if (polled != null) { + pollResult.set(polled); + } } catch (InterruptedException e) { throw new RuntimeException(e); } + + assertNull(polled); }); pollThread.start(); @@ -164,7 +174,8 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { @Test public void testPoll_emptyQueue() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); assertNull(queue.poll()); } @@ -172,7 +183,8 @@ public void testPoll_emptyQueue() { @Test public void testTake() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); AtomicInteger value = new AtomicInteger(); // Should block until value is available @@ -194,4 +206,39 @@ public void testTake() throws InterruptedException { assertEquals(MAX_WEIGHT, value.get()); } + + @Test + public void testPut_sharedWeigher() throws InterruptedException { + WeightedSemaphore weigher = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue1 = WeightedBoundedQueue.create(weigher); + WeightedBoundedQueue queue2 = WeightedBoundedQueue.create(weigher); + + // Insert value that takes all the weight into the queue1. + queue1.put(MAX_WEIGHT); + + // Try to insert a value into the queue2. This will block since there is no capacity in the + // weigher. + Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT)); + putThread.start(); + // Should only see the first value in the queue, since the queue is at capacity. putThread + // should be blocked. The weight should be the same however, since queue1 and queue2 are sharing + // the weigher. + Thread.sleep(100); + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue1.size()); + assertEquals(0, queue2.size()); + + // Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher. + queue1.poll(); + + // Wait for the putThread which was previously blocked due to the weigher being at capacity. + putThread.join(); + + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue2.size()); + queue2.poll(); + assertEquals(0, queue2.size()); + assertEquals(0, weigher.currentWeight()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 546a2883e3b2..c05a4dd340dd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -121,6 +121,7 @@ public void setUp() throws IOException { private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { return StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setOnCommitComplete(onCommitComplete) .build(); @@ -342,6 +343,7 @@ public void testMultipleCommitSendersSingleStream() { Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); workCommitter = StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setNumCommitSenders(5) .setOnCommitComplete(completeCommits::add)