Skip to content

Commit

Permalink
extract semaphore logic out of WeightBoundedQueue to allow for sharing the weigher (#32905)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu authored Nov 6, 2024
1 parent 738a76d commit 7089321
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
Expand Down Expand Up @@ -199,6 +200,7 @@ private StreamingDataflowWorker(
this.workCommitter =
windmillServiceEnabled
? StreamingEngineWorkCommitter.builder()
.setCommitByteSemaphore(Commits.maxCommitByteSemaphore())
.setCommitWorkStreamFactory(
WindmillStreamPool.create(
numCommitThreads,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,49 +18,40 @@
package org.apache.beam.runners.dataflow.worker.streaming;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;

/** Bounded set of queues, with a maximum total weight. */
/** Queue bounded by a {@link WeightedSemaphore}. */
public final class WeightedBoundedQueue<V> {

private final LinkedBlockingQueue<V> queue;
private final int maxWeight;
private final Semaphore limit;
private final Function<V, Integer> weigher;
private final WeightedSemaphore<V> weightedSemaphore;

private WeightedBoundedQueue(
LinkedBlockingQueue<V> linkedBlockingQueue,
int maxWeight,
Semaphore limit,
Function<V, Integer> weigher) {
LinkedBlockingQueue<V> linkedBlockingQueue, WeightedSemaphore<V> weightedSemaphore) {
this.queue = linkedBlockingQueue;
this.maxWeight = maxWeight;
this.limit = limit;
this.weigher = weigher;
this.weightedSemaphore = weightedSemaphore;
}

public static <V> WeightedBoundedQueue<V> create(int maxWeight, Function<V, Integer> weigherFn) {
return new WeightedBoundedQueue<>(
new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn);
public static <V> WeightedBoundedQueue<V> create(WeightedSemaphore<V> weightedSemaphore) {
return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weightedSemaphore);
}

/**
* Adds the value to the queue, blocking if this would cause the overall weight to exceed the
* limit.
*/
public void put(V value) {
limit.acquireUninterruptibly(weigher.apply(value));
weightedSemaphore.acquireUninterruptibly(value);
queue.add(value);
}

/** Returns and removes the next value, or null if there is no such value. */
public @Nullable V poll() {
V result = queue.poll();
@Nullable V result = queue.poll();
if (result != null) {
limit.release(weigher.apply(result));
weightedSemaphore.release(result);
}
return result;
}
Expand All @@ -76,26 +67,22 @@ public void put(V value) {
* @throws InterruptedException if interrupted while waiting
*/
public @Nullable V poll(long timeout, TimeUnit unit) throws InterruptedException {
V result = queue.poll(timeout, unit);
@Nullable V result = queue.poll(timeout, unit);
if (result != null) {
limit.release(weigher.apply(result));
weightedSemaphore.release(result);
}
return result;
}

/** Returns and removes the next value, or blocks until one is available. */
public @Nullable V take() throws InterruptedException {
public V take() throws InterruptedException {
V result = queue.take();
limit.release(weigher.apply(result));
weightedSemaphore.release(result);
return result;
}

/** Returns the current weight of the queue. */
public int queuedElementsWeight() {
return maxWeight - limit.availablePermits();
}

public int size() {
@VisibleForTesting
int size() {
return queue.size();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming;

import java.util.concurrent.Semaphore;
import java.util.function.Function;

/**
 * Limiter that charges every value a number of permits computed by a caller-supplied weigher.
 *
 * <p>Permits are held by a fair {@link Semaphore} created with {@code maxWeight} permits. The
 * weight charged for a single value is capped at {@code maxWeight}, so one request never asks for
 * more permits than the semaphore was created with.
 *
 * @param <V> type of the values being weighed
 */
public final class WeightedSemaphore<V> {
  private final int maxWeight;
  private final Semaphore limit;
  private final Function<V, Integer> weigher;

  private WeightedSemaphore(int maxWeight, Semaphore limit, Function<V, Integer> weigher) {
    this.weigher = weigher;
    this.limit = limit;
    this.maxWeight = maxWeight;
  }

  /**
   * Creates a semaphore with {@code maxWeight} permits that are handed out in FIFO (fair) order.
   *
   * @param maxWeight total number of permits, and the cap applied to any single value's weight
   * @param weigherFn computes the weight of a value
   */
  public static <V> WeightedSemaphore<V> create(int maxWeight, Function<V, Integer> weigherFn) {
    Semaphore fairPermits = new Semaphore(maxWeight, /* fair= */ true);
    return new WeightedSemaphore<>(maxWeight, fairPermits, weigherFn);
  }

  /** Returns the weight currently held: {@code maxWeight} minus the permits still available. */
  public int currentWeight() {
    int available = limit.availablePermits();
    return maxWeight - available;
  }

  /**
   * Acquires the permits for {@code value}, blocking without responding to interruption until
   * they are available.
   */
  public void acquireUninterruptibly(V value) {
    int permits = computePermits(value);
    limit.acquireUninterruptibly(permits);
  }

  /**
   * Returns the permits for {@code value}. The weigher is applied again here, so the amount
   * released matches the amount acquired only if it reports the same weight both times.
   */
  public void release(V value) {
    int permits = computePermits(value);
    limit.release(permits);
  }

  /** Weight of {@code value} as reported by the weigher, capped at {@code maxWeight}. */
  private int computePermits(V value) {
    int weight = weigher.apply(value);
    return weight < maxWeight ? weight : maxWeight;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.windmill.client.commits;

import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;

/** Static helpers for commits; the private constructor prevents instantiation. */
@Internal
public final class Commits {

  /** Upper bound, in bytes, on commits queued on the user worker (500MB). */
  @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 * 1024 * 1024;

  private Commits() {}

  /**
   * Returns a {@link WeightedSemaphore} with a capacity of {@link #MAX_QUEUED_COMMITS_BYTES} that
   * weighs each {@link Commit} by its {@code getSize()} value.
   */
  public static WeightedSemaphore<Commit> maxCommitByteSemaphore() {
    WeightedSemaphore<Commit> commitByteSemaphore =
        WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize);
    return commitByteSemaphore;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
public final class StreamingApplianceWorkCommitter implements WorkCommitter {
private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class);
private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB

private final Consumer<CommitWorkRequest> commitWorkFn;
private final WeightedBoundedQueue<Commit> commitQueue;
Expand All @@ -53,9 +52,7 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter {
private StreamingApplianceWorkCommitter(
Consumer<CommitWorkRequest> commitWorkFn, Consumer<CompleteCommit> onCommitComplete) {
this.commitWorkFn = commitWorkFn;
this.commitQueue =
WeightedBoundedQueue.create(
MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
this.commitQueue = WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore());
this.commitWorkers =
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
Expand All @@ -73,10 +70,9 @@ public static StreamingApplianceWorkCommitter create(
}

@Override
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
if (!commitWorkers.isShutdown()) {
commitWorkers.submit(this::commitLoop);
commitWorkers.execute(this::commitLoop);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
Expand All @@ -46,7 +47,6 @@
public final class StreamingEngineWorkCommitter implements WorkCommitter {
private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
private static final int TARGET_COMMIT_BATCH_KEYS = 5;
private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
private static final String NO_BACKEND_WORKER_TOKEN = "";

private final Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
Expand All @@ -61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter {
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
int numCommitSenders,
Consumer<CompleteCommit> onCommitComplete,
String backendWorkerToken) {
String backendWorkerToken,
WeightedSemaphore<Commit> commitByteSemaphore) {
this.commitWorkStreamFactory = commitWorkStreamFactory;
this.commitQueue =
WeightedBoundedQueue.create(
MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore);
this.commitSenders =
Executors.newFixedThreadPool(
numCommitSenders,
Expand All @@ -90,12 +89,11 @@ public static Builder builder() {
}

@Override
@SuppressWarnings("FutureReturnValueIgnored")
public void start() {
Preconditions.checkState(
isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start().");
for (int i = 0; i < numCommitSenders; i++) {
commitSenders.submit(this::streamingCommitLoop);
commitSenders.execute(this::streamingCommitLoop);
}
}

Expand Down Expand Up @@ -166,6 +164,8 @@ private void streamingCommitLoop() {
return;
}
}

// take() blocks until a value is available in the commitQueue.
Preconditions.checkNotNull(initialCommit);

if (initialCommit.work().isFailed()) {
Expand Down Expand Up @@ -258,6 +258,8 @@ public interface Builder {
Builder setCommitWorkStreamFactory(
Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory);

Builder setCommitByteSemaphore(WeightedSemaphore<Commit> commitByteSemaphore);

Builder setNumCommitSenders(int numCommitSenders);

Builder setOnCommitComplete(Consumer<CompleteCommit> onCommitComplete);
Expand Down
Loading

0 comments on commit 7089321

Please sign in to comment.