diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 6f1bb0847bc8..4c3ffd08a0b7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -87,17 +87,14 @@ import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages; -import org.apache.beam.runners.dataflow.worker.streaming.Commit; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState; import org.apache.beam.runners.dataflow.worker.streaming.KeyCommitTooLargeException; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; -import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.Work.State; import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; -import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; @@ -110,9 +107,13 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; @@ -217,9 +218,6 @@ public class StreamingDataflowWorker { final WindmillStateCache stateCache; // Maps from computation ids to per-computation state. private final ConcurrentMap computationMap; - private final WeightedBoundedQueue commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); // Cache of tokens to commit callbacks. // Using Cache with time eviction policy helps us to prevent memory leak when callback ids are // discarded by Dataflow service and calling commitCallback is best-effort. 
@@ -234,8 +232,6 @@ public class StreamingDataflowWorker { private final BoundedQueueExecutor workUnitExecutor; private final WindmillServerStub windmillServer; private final Thread dispatchThread; - @VisibleForTesting final ImmutableList commitThreads; - private final AtomicLong activeCommitBytes = new AtomicLong(); private final AtomicLong previousTimeAtMaxThreads = new AtomicLong(); private final AtomicBoolean running = new AtomicBoolean(); private final SideInputStateFetcher sideInputStateFetcher; @@ -296,6 +292,7 @@ public class StreamingDataflowWorker { private final DataflowExecutionStateSampler sampler = DataflowExecutionStateSampler.instance(); private final ActiveWorkRefresher activeWorkRefresher; + private final WorkCommitter workCommitter; private StreamingDataflowWorker( WindmillServerStub windmillServer, @@ -403,29 +400,6 @@ private StreamingDataflowWorker( dispatchThread.setPriority(Thread.MIN_PRIORITY); dispatchThread.setName("DispatchThread"); - int numCommitThreads = 1; - if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() > 0) { - numCommitThreads = options.getWindmillServiceCommitThreads(); - } - - ImmutableList.Builder commitThreadsBuilder = ImmutableList.builder(); - for (int i = 0; i < numCommitThreads; ++i) { - Thread commitThread = - new Thread( - () -> { - if (windmillServiceEnabled) { - streamingCommitLoop(); - } else { - commitLoop(); - } - }); - commitThread.setDaemon(true); - commitThread.setPriority(Thread.MAX_PRIORITY); - commitThread.setName("CommitThread " + i); - commitThreadsBuilder.add(commitThread); - } - commitThreads = commitThreadsBuilder.build(); - this.publishCounters = publishCounters; this.clientId = clientId; this.windmillServer = windmillServer; @@ -438,6 +412,21 @@ private StreamingDataflowWorker( this.sideInputStateFetcher = new SideInputStateFetcher(metricTrackingWindmillServer::getSideInputData, options); + int numCommitThreads = 1; + if (windmillServiceEnabled && 
options.getWindmillServiceCommitThreads() > 0) { + numCommitThreads = options.getWindmillServiceCommitThreads(); + } + + this.workCommitter = + windmillServiceEnabled + ? StreamingEngineWorkCommitter.create( + WindmillStreamPool.create( + NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream) + ::getCloseableStream, + numCommitThreads, + this::onCompleteCommit) + : StreamingApplianceWorkCommitter.create( + windmillServer::commitWork, this::onCompleteCommit); // Register standard file systems. FileSystems.setDefaultPipelineOptions(options); @@ -705,6 +694,11 @@ public boolean workExecutorIsEmpty() { return workUnitExecutor.executorQueueIsEmpty(); } + @VisibleForTesting + int numCommitThreads() { + return workCommitter.parallelism(); + } + @SuppressWarnings("FutureReturnValueIgnored") public void start() { running.set(true); @@ -716,7 +710,6 @@ public void start() { memoryMonitorThread.start(); dispatchThread.start(); - commitThreads.forEach(Thread::start); sampler.start(); // Periodically report workers counters and other updates. @@ -778,7 +771,7 @@ public void start() { TimeUnit.SECONDS); scheduledExecutors.add(statusPageTimer); } - + workCommitter.start(); reportHarnessStartup(); } @@ -834,12 +827,8 @@ public void stop() { running.set(false); dispatchThread.interrupt(); dispatchThread.join(); - // We need to interrupt the commitThreads in case they are blocking on pulling - // from the commitQueue. 
- commitThreads.forEach(Thread::interrupt); - for (Thread commitThread : commitThreads) { - commitThread.join(); - } + + workCommitter.stop(); memoryMonitor.stop(); memoryMonitorThread.join(); workUnitExecutor.shutdown(); @@ -1086,7 +1075,7 @@ private void process( if (workItem.getSourceState().getOnlyFinalize()) { outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); work.setState(State.COMMIT_QUEUED); - commitQueue.put(Commit.create(outputBuilder.build(), computationState, work)); + workCommitter.commit(Commit.create(outputBuilder.build(), computationState, work)); return; } @@ -1315,7 +1304,7 @@ private void process( commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize); } - commitQueue.put(Commit.create(commitRequest, computationState, work)); + workCommitter.commit(Commit.create(commitRequest, computationState, work)); // Compute shuffle and state byte statistics these will be flushed asynchronously. long stateBytesWritten = @@ -1444,163 +1433,21 @@ private WorkItemCommitRequest buildWorkItemTruncationRequest( return outputBuilder.build(); } - private void commitLoop() { - Map computationRequestMap = - new HashMap<>(); - while (running.get()) { - computationRequestMap.clear(); - Windmill.CommitWorkRequest.Builder commitRequestBuilder = - Windmill.CommitWorkRequest.newBuilder(); - long commitBytes = 0; - // Block until we have a commit, then batch with additional commits. 
- Commit commit = null; - try { - commit = commitQueue.take(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - continue; - } - while (commit != null) { - ComputationState computationState = commit.computationState(); - commit.work().setState(Work.State.COMMITTING); - Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = - computationRequestMap.get(computationState); - if (computationRequestBuilder == null) { - computationRequestBuilder = commitRequestBuilder.addRequestsBuilder(); - computationRequestBuilder.setComputationId(computationState.getComputationId()); - computationRequestMap.put(computationState, computationRequestBuilder); - } - computationRequestBuilder.addRequests(commit.request()); - // Send the request if we've exceeded the bytes or there is no more - // pending work. commitBytes is a long, so this cannot overflow. - commitBytes += commit.getSize(); - if (commitBytes >= TARGET_COMMIT_BUNDLE_BYTES) { - break; - } - commit = commitQueue.poll(); - } - Windmill.CommitWorkRequest commitRequest = commitRequestBuilder.build(); - LOG.trace("Commit: {}", commitRequest); - activeCommitBytes.addAndGet(commitBytes); - windmillServer.commitWork(commitRequest); - activeCommitBytes.addAndGet(-commitBytes); - for (Map.Entry entry : - computationRequestMap.entrySet()) { - ComputationState computationState = entry.getKey(); - for (Windmill.WorkItemCommitRequest workRequest : entry.getValue().getRequestsList()) { - computationState.completeWorkAndScheduleNextWorkForKey( - ShardedKey.create(workRequest.getKey(), workRequest.getShardingKey()), - WorkId.builder() - .setCacheToken(workRequest.getCacheToken()) - .setWorkToken(workRequest.getWorkToken()) - .build()); - } - } - } - } - - // Adds the commit to the commitStream if it fits, returning true iff it is consumed. 
- private boolean addCommitToStream(Commit commit, CommitWorkStream commitStream) { - Preconditions.checkNotNull(commit); - final ComputationState state = commit.computationState(); - final Windmill.WorkItemCommitRequest request = commit.request(); - // Drop commits for failed work. Such commits will be dropped by Windmill anyway. - if (commit.work().isFailed()) { + private void onCompleteCommit(CompleteCommit completeCommit) { + if (completeCommit.status() != Windmill.CommitStatus.OK) { readerCache.invalidateReader( WindmillComputationKey.create( - state.getComputationId(), request.getKey(), request.getShardingKey())); + completeCommit.computationId(), completeCommit.shardedKey())); stateCache - .forComputation(state.getComputationId()) - .invalidate(request.getKey(), request.getShardingKey()); - state.completeWorkAndScheduleNextWorkForKey( - ShardedKey.create(request.getKey(), request.getShardingKey()), - WorkId.builder() - .setWorkToken(request.getWorkToken()) - .setCacheToken(request.getCacheToken()) - .build()); - return true; - } - - final int size = commit.getSize(); - commit.work().setState(Work.State.COMMITTING); - activeCommitBytes.addAndGet(size); - if (commitStream.commitWorkItem( - state.getComputationId(), - request, - (Windmill.CommitStatus status) -> { - if (status != Windmill.CommitStatus.OK) { - readerCache.invalidateReader( - WindmillComputationKey.create( - state.getComputationId(), request.getKey(), request.getShardingKey())); - stateCache - .forComputation(state.getComputationId()) - .invalidate(request.getKey(), request.getShardingKey()); - } - activeCommitBytes.addAndGet(-size); - state.completeWorkAndScheduleNextWorkForKey( - ShardedKey.create(request.getKey(), request.getShardingKey()), - WorkId.builder() - .setCacheToken(request.getCacheToken()) - .setWorkToken(request.getWorkToken()) - .build()); - })) { - return true; - } else { - // Back out the stats changes since the commit wasn't consumed. 
- commit.work().setState(Work.State.COMMIT_QUEUED); - activeCommitBytes.addAndGet(-size); - return false; + .forComputation(completeCommit.computationId()) + .invalidate(completeCommit.shardedKey()); } - } - // Helper to batch additional commits into the commit stream as long as they fit. - // Returns a commit that was removed from the queue but not consumed or null. - private Commit batchCommitsToStream(CommitWorkStream commitStream) { - int commits = 1; - while (running.get()) { - Commit commit; - try { - if (commits < 5) { - commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS); - } else { - commit = commitQueue.poll(); - } - } catch (InterruptedException e) { - // Continue processing until !running.get() - continue; - } - if (commit == null || !addCommitToStream(commit, commitStream)) { - return commit; - } - commits++; - } - return null; - } - - private void streamingCommitLoop() { - WindmillStreamPool streamPool = - WindmillStreamPool.create( - NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream); - Commit initialCommit = null; - while (running.get()) { - if (initialCommit == null) { - try { - initialCommit = commitQueue.take(); - } catch (InterruptedException e) { - continue; - } - } - // We initialize the commit stream only after we have a commit to make sure it is fresh. - CommitWorkStream commitStream = streamPool.getStream(); - if (!addCommitToStream(initialCommit, commitStream)) { - throw new AssertionError("Initial commit on flushed stream should always be accepted."); - } - // Batch additional commits to the stream and possibly make an un-batched commit the next - // initial commit. 
- initialCommit = batchCommitsToStream(commitStream); - commitStream.flush(); - streamPool.releaseStream(commitStream); - } + Optional.ofNullable(computationMap.get(completeCommit.computationId())) + .ifPresent( + state -> + state.completeWorkAndScheduleNextWorkForKey( + completeCommit.shardedKey(), completeCommit.workId())); } private Windmill.GetWorkResponse getWork() { @@ -2094,7 +1941,7 @@ public void appendSummaryHtml(PrintWriter writer) { writer.println(workUnitExecutor.summaryHtml()); writer.print("Active commit: "); - appendHumanizedBytes(activeCommitBytes.get(), writer); + appendHumanizedBytes(workCommitter.currentActiveCommitBytes(), writer); writer.println("
"); metricTrackingWindmillServer.printHtml(writer); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java index a01b1d297c22..274fa3aff026 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker; import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat; @@ -29,6 +30,10 @@ public static WindmillComputationKey create( return new AutoValue_WindmillComputationKey(computationId, key, shardingKey); } + public static WindmillComputationKey create(String computationId, ShardedKey shardedKey) { + return create(computationId, shardedKey.key(), shardedKey.shardingKey()); + } + public abstract String computationId(); public abstract ByteString key(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java new file mode 100644 index 000000000000..e76cc3659653 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.annotations.Internal; + +/** + * Wrapper for a {@link WindmillStream} that allows callers to tie an action after the stream is + * finished being used. Has an option for closing code to be a no-op. 
+ */ +@Internal +@AutoValue +public abstract class CloseableStream implements AutoCloseable { + public static CloseableStream create( + StreamT stream, Runnable onClose) { + return new AutoValue_CloseableStream<>(stream, onClose); + } + + public abstract StreamT stream(); + + abstract Runnable onClose(); + + @Override + public void close() throws Exception { + onClose().run(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java index 9f1b67edc1e0..0e4e085c066c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java @@ -25,6 +25,7 @@ import java.util.function.Supplier; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; @@ -36,6 +37,7 @@ *

The pool holds a fixed total number of streams, and keeps each stream open for a specified * time to allow for better load-balancing. */ +@Internal @ThreadSafe public class WindmillStreamPool { @@ -131,6 +133,11 @@ public StreamT getStream() { } } + public CloseableStream getCloseableStream() { + StreamT stream = getStream(); + return CloseableStream.create(stream, () -> releaseStream(stream)); + } + private synchronized WindmillStreamPool.StreamData createAndCacheStream(int cacheKey) { WindmillStreamPool.StreamData newStreamData = new WindmillStreamPool.StreamData<>(streamSupplier.get()); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java similarity index 81% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java index 946897967561..b840d22a3434 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java @@ -15,13 +15,17 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.runners.dataflow.worker.streaming; +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; /** Value class for a queued commit. */ +@Internal @AutoValue public abstract class Commit { @@ -31,6 +35,10 @@ public static Commit create( return new AutoValue_Commit(request, computationState, work); } + public final String computationId() { + return computationState().getComputationId(); + } + public abstract WorkItemCommitRequest request(); public abstract ComputationState computationState(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java new file mode 100644 index 000000000000..64fec71b0006 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.WorkId; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; + +/** + * A {@link Commit} is marked as complete when it has been attempted to be committed back to + * Streaming Engine/Appliance via {@link + * org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub#commitWorkStream(StreamObserver)} + * for Streaming Engine or {@link + * org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub#commitWork(Windmill.CommitWorkRequest, + * StreamObserver)} for Streaming Appliance. 
+ */ +@Internal +@AutoValue +public abstract class CompleteCommit { + + public static CompleteCommit create(Commit commit, CommitStatus commitStatus) { + return new AutoValue_CompleteCommit( + commit.computationId(), + ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()), + WorkId.builder() + .setWorkToken(commit.request().getWorkToken()) + .setCacheToken(commit.request().getCacheToken()) + .build(), + commitStatus); + } + + public static CompleteCommit create( + String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) { + return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status); + } + + public static CompleteCommit forFailedWork(Commit commit) { + return create(commit, CommitStatus.ABORTED); + } + + public abstract String computationId(); + + public abstract ShardedKey shardedKey(); + + public abstract WorkId workId(); + + public abstract CommitStatus status(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java new file mode 100644 index 000000000000..344f04cfd00b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.WorkId; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Streaming appliance implementation of {@link WorkCommitter}. 
*/ +@Internal +@ThreadSafe +public final class StreamingApplianceWorkCommitter implements WorkCommitter { + private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class); + private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; + private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB + + private final Consumer commitWorkFn; + private final WeightedBoundedQueue commitQueue; + private final ExecutorService commitWorkers; + private final AtomicLong activeCommitBytes; + private final Consumer onCommitComplete; + + private StreamingApplianceWorkCommitter( + Consumer commitWorkFn, Consumer onCommitComplete) { + this.commitWorkFn = commitWorkFn; + this.commitQueue = + WeightedBoundedQueue.create( + MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitWorkers = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setPriority(Thread.MAX_PRIORITY) + .setNameFormat("CommitThread-%d") + .build()); + this.activeCommitBytes = new AtomicLong(); + this.onCommitComplete = onCommitComplete; + } + + public static StreamingApplianceWorkCommitter create( + Consumer commitWork, Consumer onCommitComplete) { + return new StreamingApplianceWorkCommitter(commitWork, onCommitComplete); + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public void start() { + if (!commitWorkers.isShutdown()) { + commitWorkers.submit(this::commitLoop); + } + } + + @Override + public void commit(Commit commit) { + commitQueue.put(commit); + } + + @Override + public long currentActiveCommitBytes() { + return activeCommitBytes.get(); + } + + @Override + public void stop() { + commitWorkers.shutdownNow(); + } + + @Override + public int parallelism() { + return 1; + } + + private void commitLoop() { + Map computationRequestMap = + new HashMap<>(); + while (true) { + computationRequestMap.clear(); + CommitWorkRequest.Builder commitRequestBuilder = 
CommitWorkRequest.newBuilder(); + long commitBytes = 0; + // Block until we have a commit, then batch with additional commits. + Commit commit; + try { + commit = commitQueue.take(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + continue; + } + while (commit != null) { + ComputationState computationState = commit.computationState(); + commit.work().setState(Work.State.COMMITTING); + Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder = + computationRequestMap.get(computationState); + if (computationRequestBuilder == null) { + computationRequestBuilder = commitRequestBuilder.addRequestsBuilder(); + computationRequestBuilder.setComputationId(computationState.getComputationId()); + computationRequestMap.put(computationState, computationRequestBuilder); + } + computationRequestBuilder.addRequests(commit.request()); + // Send the request if we've exceeded the bytes or there is no more + // pending work. commitBytes is a long, so this cannot overflow. + commitBytes += commit.getSize(); + if (commitBytes >= TARGET_COMMIT_BUNDLE_BYTES) { + break; + } + commit = commitQueue.poll(); + } + commitWork(commitRequestBuilder.build(), commitBytes); + completeWork(computationRequestMap); + } + } + + private void commitWork(CommitWorkRequest commitRequest, long commitBytes) { + LOG.trace("Commit: {}", commitRequest); + activeCommitBytes.addAndGet(commitBytes); + commitWorkFn.accept(commitRequest); + activeCommitBytes.addAndGet(-commitBytes); + } + + private void completeWork( + Map committedWork) { + for (Map.Entry entry : + committedWork.entrySet()) { + for (Windmill.WorkItemCommitRequest workRequest : entry.getValue().getRequestsList()) { + // Appliance errors are propagated by exception on entire batch. 
+ onCommitComplete.accept( + CompleteCommit.create( + entry.getKey().getComputationId(), + ShardedKey.create(workRequest.getKey(), workRequest.getShardingKey()), + WorkId.builder() + .setCacheToken(workRequest.getCacheToken()) + .setWorkToken(workRequest.getWorkToken()) + .build(), + Windmill.CommitStatus.OK)); + } + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java new file mode 100644 index 000000000000..f6088acf0115 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Streaming engine implementation of {@link WorkCommitter}. Commits work back to Streaming Engine + * backend. 
+ */ +@Internal +@ThreadSafe +public final class StreamingEngineWorkCommitter implements WorkCommitter { + private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); + private static final int TARGET_COMMIT_BATCH_KEYS = 5; + private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB + + private final Supplier> commitWorkStreamFactory; + private final WeightedBoundedQueue commitQueue; + private final ExecutorService commitSenders; + private final AtomicLong activeCommitBytes; + private final Consumer onCommitComplete; + private final int numCommitSenders; + + private StreamingEngineWorkCommitter( + Supplier> commitWorkStreamFactory, + int numCommitSenders, + Consumer onCommitComplete) { + this.commitWorkStreamFactory = commitWorkStreamFactory; + this.commitQueue = + WeightedBoundedQueue.create( + MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitSenders = + Executors.newFixedThreadPool( + numCommitSenders, + new ThreadFactoryBuilder() + .setDaemon(true) + .setPriority(Thread.MAX_PRIORITY) + .setNameFormat("CommitThread-%d") + .build()); + this.activeCommitBytes = new AtomicLong(); + this.onCommitComplete = onCommitComplete; + this.numCommitSenders = numCommitSenders; + } + + public static StreamingEngineWorkCommitter create( + Supplier> commitWorkStreamFactory, + int numCommitSenders, + Consumer onCommitComplete) { + return new StreamingEngineWorkCommitter( + commitWorkStreamFactory, numCommitSenders, onCommitComplete); + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public void start() { + if (!commitSenders.isShutdown()) { + for (int i = 0; i < numCommitSenders; i++) { + commitSenders.submit(this::streamingCommitLoop); + } + } + } + + @Override + public void commit(Commit commit) { + commitQueue.put(commit); + } + + @Override + public long currentActiveCommitBytes() { + return activeCommitBytes.get(); + } + + @Override + public void stop() { + if 
(!commitSenders.isTerminated() || !commitSenders.isShutdown()) { + commitSenders.shutdown(); + try { + commitSenders.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOG.warn("Could not shut down commitSenders gracefully, forcing shutdown.", e); + } + commitSenders.shutdownNow(); + } + drainCommitQueue(); + } + + private void drainCommitQueue() { + Commit queuedCommit = commitQueue.poll(); + while (queuedCommit != null) { + failCommit(queuedCommit); + queuedCommit = commitQueue.poll(); + } + } + + private void failCommit(Commit commit) { + commit.work().setFailed(); + onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + } + + @Override + public int parallelism() { + return numCommitSenders; + } + + private void streamingCommitLoop() { + @Nullable Commit initialCommit = null; + try { + while (true) { + if (initialCommit == null) { + try { + // Block until we have a commit or are shutting down. + initialCommit = commitQueue.take(); + } catch (InterruptedException e) { + continue; + } + } + + if (initialCommit.work().isFailed()) { + onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit)); + initialCommit = null; + continue; + } + + try (CloseableStream closeableCommitStream = + commitWorkStreamFactory.get()) { + CommitWorkStream commitStream = closeableCommitStream.stream(); + if (!tryAddToCommitStream(initialCommit, commitStream)) { + throw new AssertionError("Initial commit on flushed stream should always be accepted."); + } + // Batch additional commits to the stream and possibly make an un-batched commit the next + // initial commit. + initialCommit = batchCommitsToStream(commitStream); + commitStream.flush(); + } catch (Exception e) { + LOG.error("Error occurred fetching a CommitWorkStream.", e); + } + } + } finally { + if (initialCommit != null) { + failCommit(initialCommit); + } + } + } + + /** Adds the commit to the commitStream if it fits, returning true if it is consumed. 
*/ + private boolean tryAddToCommitStream(Commit commit, CommitWorkStream commitStream) { + Preconditions.checkNotNull(commit); + commit.work().setState(Work.State.COMMITTING); + activeCommitBytes.addAndGet(commit.getSize()); + boolean isCommitAccepted = + commitStream.commitWorkItem( + commit.computationId(), + commit.request(), + (commitStatus) -> { + onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); + activeCommitBytes.addAndGet(-commit.getSize()); + }); + + // Since the commit was not accepted, revert the changes made above. + if (!isCommitAccepted) { + commit.work().setState(Work.State.COMMIT_QUEUED); + activeCommitBytes.addAndGet(-commit.getSize()); + } + + return isCommitAccepted; + } + + // Helper to batch additional commits into the commit stream as long as they fit. + // Returns a commit that was removed from the queue but not consumed or null. + private Commit batchCommitsToStream(CommitWorkStream commitStream) { + int commits = 1; + while (true) { + Commit commit; + try { + if (commits < TARGET_COMMIT_BATCH_KEYS) { + commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS); + } else { + commit = commitQueue.poll(); + } + } catch (InterruptedException e) { + // Continue processing until !running.get() + continue; + } + + if (commit == null) { + return null; + } + + // Drop commits for failed work. Such commits will be dropped by Windmill anyway. 
+ if (commit.work().isFailed()) { + onCommitComplete.accept(CompleteCommit.forFailedWork(commit)); + continue; + } + + if (!tryAddToCommitStream(commit, commitStream)) { + return commit; + } + commits++; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java new file mode 100644 index 000000000000..11a4c00db9d3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; + +/** + * Commits {@link org.apache.beam.runners.dataflow.worker.streaming.Work} that has completed user + * processing back to persistence layer. + */ +@Internal +@ThreadSafe +public interface WorkCommitter { + + /** Starts internal processing of commits. 
+ */ + void start(); + + /** + * Add a commit to {@link WorkCommitter}. This may block the calling thread depending on + * underlying implementations, and persisting to the persistence layer may be done asynchronously. + */ + void commit(Commit commit); + + /** Number of bytes currently trying to be committed to the backing persistence layer. */ + long currentActiveCommitBytes(); + + /** + * Stops internal processing of commits. In progress and subsequent commits may be canceled or + * dropped. + */ + void stop(); + + /** + * Number of internal workers {@link WorkCommitter} uses to commit work to the backing persistence + * layer. + */ + int parallelism(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java index 0d4e7c6b645c..85c74fe8591d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.WindmillComputationKey; import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.util.Weighted; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; @@ -318,6 +319,10 @@ public void invalidate(ByteString processingKey, long shardingKey) { keyIndex.remove(key); } + public final void invalidate(ShardedKey shardedKey) { + invalidate(shardedKey.key(), shardedKey.shardingKey()); + } + /** * Returns a
per-computation, per-key view of the state cache. Access to the cached data for * this key is not thread-safe. Callers should ensure that there is only a single ForKey object diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index e4985193d1cf..89939d5d3413 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -45,6 +46,7 @@ import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; +import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest; @@ -74,11 +76,12 @@ import org.slf4j.LoggerFactory; /** An in-memory Windmill server that offers provided work and data. */ -public class FakeWindmillServer extends WindmillServerStub { +public final class FakeWindmillServer extends WindmillServerStub { private static final Logger LOG = LoggerFactory.getLogger(FakeWindmillServer.class); private final ResponseQueue workToOffer; private final ResponseQueue dataToOffer; private final ResponseQueue commitsToOffer; + private final Map streamingCommitsToOffer; // Keys are work tokens. 
private final Map commitsReceived; private final ArrayList statsReceived; @@ -109,6 +112,7 @@ public FakeWindmillServer( commitsToOffer = new ResponseQueue() .returnByDefault(CommitWorkResponse.getDefaultInstance()); + streamingCommitsToOffer = new HashMap<>(); commitsReceived = new ConcurrentHashMap<>(); exceptions = new LinkedBlockingQueue<>(); expectedExceptionCount = new AtomicInteger(); @@ -139,6 +143,10 @@ public void sendFailedHeartbeats(List res return commitsToOffer; } + public Map whenCommitWorkStreamCalled() { + return streamingCommitsToOffer; + } + @Override public Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request) { LOG.debug("getWorkRequest: {}", request.toString()); @@ -376,7 +384,15 @@ public boolean commitWorkItem( droppedStreamingCommits.put(request.getWorkToken(), onDone); } else { commitsReceived.put(request.getWorkToken(), request); - onDone.accept(Windmill.CommitStatus.OK); + onDone.accept( + Optional.ofNullable( + streamingCommitsToOffer.remove( + WorkId.builder() + .setWorkToken(request.getWorkToken()) + .setCacheToken(request.getCacheToken()) + .build())) + // Default to CommitStatus.OK + .orElse(Windmill.CommitStatus.OK)); } // Return true to indicate the request was accepted even if we are dropping the commit // to simulate a dropped commit. @@ -502,32 +518,32 @@ public void setIsReady(boolean ready) { this.isReady = ready; } - static class ResponseQueue { + public static class ResponseQueue { private final Queue> responses = new ConcurrentLinkedQueue<>(); Duration sleep = Duration.ZERO; private Function defaultResponse; // (Fluent) interface for response producers, accessible from tests. 
- ResponseQueue thenAnswer(Function mapFun) { + public ResponseQueue thenAnswer(Function mapFun) { responses.add(mapFun); return this; } - ResponseQueue thenReturn(U response) { + public ResponseQueue thenReturn(U response) { return thenAnswer((request) -> response); } - ResponseQueue answerByDefault(Function mapFun) { + public ResponseQueue answerByDefault(Function mapFun) { defaultResponse = mapFun; return this; } - ResponseQueue returnByDefault(U response) { + public ResponseQueue returnByDefault(U response) { return answerByDefault((request) -> response); } - ResponseQueue delayEachResponseBy(Duration sleep) { + public ResponseQueue delayEachResponseBy(Duration sleep) { this.sleep = sleep; return this; } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index d00ea64d7d4d..d8ead447e8e5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -3894,7 +3894,7 @@ private void runNumCommitThreadsTest(int configNumCommitThreads, int expectedNum options.setWindmillServiceCommitThreads(configNumCommitThreads); StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */); worker.start(); - assertEquals(expectedNumCommitThreads, worker.commitThreads.size()); + assertEquals(expectedNumCommitThreads, worker.numCommitThreads()); worker.stop(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java new file mode 100644 index 000000000000..cfad61385476 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotNull; + +import com.google.api.services.dataflow.model.MapTask; +import com.google.common.truth.Correspondence; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.FakeWindmillServer; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Instant; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class StreamingApplianceWorkCommitterTest { + @Rule public ErrorCollector errorCollector = new ErrorCollector(); + private FakeWindmillServer fakeWindmillServer; + private StreamingApplianceWorkCommitter workCommitter; + + private static Work createMockWork(long workToken, Consumer processWorkFn) { + return Work.create( + Windmill.WorkItem.newBuilder() + .setKey(ByteString.EMPTY) + .setWorkToken(workToken) + .setShardingKey(workToken) + .setCacheToken(workToken) + .build(), + Instant::now, + Collections.emptyList(), + processWorkFn); + } + + private static ComputationState createComputationState(String computationId) { + return new 
ComputationState( + computationId, + new MapTask().setSystemName("system").setStageName("stage"), + Mockito.mock(BoundedQueueExecutor.class), + ImmutableMap.of(), + null); + } + + private StreamingApplianceWorkCommitter createWorkCommitter( + Consumer onCommitComplete) { + return StreamingApplianceWorkCommitter.create(fakeWindmillServer::commitWork, onCommitComplete); + } + + @Before + public void setUp() { + fakeWindmillServer = + new FakeWindmillServer( + errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class))); + } + + @After + public void cleanUp() { + workCommitter.stop(); + } + + @Test + public void testCommit() { + List completeCommits = new ArrayList<>(); + workCommitter = createWorkCommitter(completeCommits::add); + List commits = new ArrayList<>(); + for (int i = 1; i <= 5; i++) { + Work work = createMockWork(i, ignored -> {}); + Windmill.WorkItemCommitRequest commitRequest = + Windmill.WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()) + .build(); + commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work)); + } + + workCommitter.start(); + commits.forEach(workCommitter::commit); + + Map committed = + fakeWindmillServer.waitForAndGetCommits(commits.size()); + + for (Commit commit : commits) { + Windmill.WorkItemCommitRequest request = + committed.get(commit.work().getWorkItem().getWorkToken()); + assertNotNull(request); + assertThat(request).isEqualTo(commit.request()); + } + + assertThat(completeCommits).hasSize(commits.size()); + assertThat(completeCommits) + .comparingElementsUsing( + Correspondence.from( + (CompleteCommit completeCommit, Commit commit) -> + completeCommit.computationId().equals(commit.computationId()) + && completeCommit.status() == Windmill.CommitStatus.OK + && 
completeCommit.workId().equals(commit.work().id()) + && completeCommit + .shardedKey() + .equals( + ShardedKey.create( + commit.request().getKey(), commit.request().getShardingKey())), + "expected to equal")) + .containsExactlyElementsIn(commits); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java new file mode 100644 index 000000000000..1bf2e44f9f0e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus.OK; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.api.services.dataflow.model.MapTask; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.FakeWindmillServer; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.WorkId; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; 
+import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class StreamingEngineWorkCommitterTest { + + @Rule public ErrorCollector errorCollector = new ErrorCollector(); + private StreamingEngineWorkCommitter workCommitter; + private FakeWindmillServer fakeWindmillServer; + private Supplier> commitWorkStreamFactory; + + private static Work createMockWork(long workToken, Consumer processWorkFn) { + return Work.create( + Windmill.WorkItem.newBuilder() + .setKey(ByteString.EMPTY) + .setWorkToken(workToken) + .setShardingKey(workToken) + .setCacheToken(workToken) + .build(), + Instant::now, + Collections.emptyList(), + processWorkFn); + } + + private static ComputationState createComputationState(String computationId) { + return new ComputationState( + computationId, + new MapTask().setSystemName("system").setStageName("stage"), + Mockito.mock(BoundedQueueExecutor.class), + ImmutableMap.of(), + null); + } + + private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) { + if (commit.work().isFailed()) { + return CompleteCommit.forFailedWork(commit); + } + + return CompleteCommit.create(commit, status); + } + + @Before + public void setUp() throws IOException { + fakeWindmillServer = + new FakeWindmillServer( + errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class))); + commitWorkStreamFactory = + WindmillStreamPool.create( + 1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream) + ::getCloseableStream; + } + + @After + public void cleanUp() { + workCommitter.stop(); + } + + private StreamingEngineWorkCommitter createWorkCommitter( + Consumer onCommitComplete) { + return StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 1, onCommitComplete); + } + + @Test + public void testCommit_sendsCommitsToStreamingEngine() { + Set completeCommits = new HashSet<>(); + workCommitter = createWorkCommitter(completeCommits::add); + List commits = new ArrayList<>(); + for (int i = 1; i <= 5; i++) { + Work 
work = createMockWork(i, ignored -> {}); + WorkItemCommitRequest commitRequest = + WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()) + .build(); + commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work)); + } + + workCommitter.start(); + commits.parallelStream().forEach(workCommitter::commit); + + Map committed = + fakeWindmillServer.waitForAndGetCommits(commits.size()); + + for (Commit commit : commits) { + WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + assertNotNull(request); + assertThat(request).isEqualTo(commit.request()); + assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + } + } + + @Test + public void testCommit_handlesFailedCommits() { + Set completeCommits = new HashSet<>(); + workCommitter = createWorkCommitter(completeCommits::add); + List commits = new ArrayList<>(); + for (int i = 1; i <= 10; i++) { + Work work = createMockWork(i, ignored -> {}); + // Fail half of the work. 
+ if (i % 2 == 0) { + work.setFailed(); + } + WorkItemCommitRequest commitRequest = + WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()) + .build(); + commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work)); + } + + workCommitter.start(); + commits.parallelStream().forEach(workCommitter::commit); + + Map committed = + fakeWindmillServer.waitForAndGetCommits(commits.size() / 2); + + for (Commit commit : commits) { + if (commit.work().isFailed()) { + assertThat(completeCommits) + .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED)); + assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken()); + } else { + assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); + assertThat(committed) + .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); + } + } + } + + @Test + public void testCommit_handlesCompleteCommits_commitStatusNotOK() { + Set completeCommits = new HashSet<>(); + workCommitter = createWorkCommitter(completeCommits::add); + Map expectedCommitStatus = new HashMap<>(); + Random commitStatusSelector = new Random(); + int commitStatusSelectorBound = Windmill.CommitStatus.values().length - 1; + // Compute the CommitStatus randomly, to test plumbing of different commitStatuses to + // StreamingEngine. + Function computeCommitStatusForTest = + work -> { + Windmill.CommitStatus commitStatus = + work.getWorkItem().getWorkToken() % 2 == 0 + ? 
Windmill.CommitStatus.values()[ + commitStatusSelector.nextInt(commitStatusSelectorBound)] + : OK; + expectedCommitStatus.put(work.id(), commitStatus); + return commitStatus; + }; + + List commits = new ArrayList<>(); + for (int i = 1; i <= 10; i++) { + Work work = createMockWork(i, ignored -> {}); + WorkItemCommitRequest commitRequest = + WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()) + .build(); + commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work)); + fakeWindmillServer + .whenCommitWorkStreamCalled() + .put(work.id(), computeCommitStatusForTest.apply(work)); + } + + workCommitter.start(); + commits.parallelStream().forEach(workCommitter::commit); + + Map committed = + fakeWindmillServer.waitForAndGetCommits(commits.size()); + + for (Commit commit : commits) { + WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken()); + assertNotNull(request); + assertThat(request).isEqualTo(commit.request()); + assertThat(completeCommits) + .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); + } + assertThat(completeCommits.size()).isEqualTo(commits.size()); + } + + @Test + public void testStop_drainsCommitQueue() { + // Use this fake to queue up commits on the committer. 
+ Supplier fakeCommitWorkStream = + () -> + new CommitWorkStream() { + @Override + public boolean commitWorkItem( + String computation, + WorkItemCommitRequest request, + Consumer onDone) { + return false; + } + + @Override + public void flush() {} + + @Override + public void close() {} + + @Override + public boolean awaitTermination(int time, TimeUnit unit) { + return false; + } + + @Override + public Instant startTime() { + return Instant.now(); + } + }; + commitWorkStreamFactory = + WindmillStreamPool.create(1, Duration.standardMinutes(1), fakeCommitWorkStream) + ::getCloseableStream; + + Set completeCommits = new HashSet<>(); + workCommitter = createWorkCommitter(completeCommits::add); + + List commits = new ArrayList<>(); + for (int i = 1; i <= 10; i++) { + Work work = createMockWork(i, ignored -> {}); + WorkItemCommitRequest commitRequest = + WorkItemCommitRequest.newBuilder() + .setKey(work.getWorkItem().getKey()) + .setShardingKey(work.getWorkItem().getShardingKey()) + .setWorkToken(work.getWorkItem().getWorkToken()) + .setCacheToken(work.getWorkItem().getCacheToken()) + .build(); + commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work)); + } + + workCommitter.start(); + commits.parallelStream().forEach(workCommitter::commit); + workCommitter.stop(); + + assertThat(commits.size()).isEqualTo(completeCommits.size()); + for (CompleteCommit completeCommit : completeCommits) { + assertThat(completeCommit.status()).isEqualTo(Windmill.CommitStatus.ABORTED); + } + + for (Commit commit : commits) { + assertTrue(commit.work().isFailed()); + } + } +}