diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
index 0e929249b3a1..800504f44515 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
@@ -27,7 +27,7 @@
 import javax.annotation.concurrent.GuardedBy;
 import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
@@ -239,25 +239,37 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request)
   }
 
   /** Tells windmill processing is ongoing for the given keys. */
-  public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
-    activeHeartbeats.set(active.size());
+  public void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats) {
+    activeHeartbeats.set(heartbeats.size());
     try {
       if (useStreamingRequests) {
         // With streaming requests, always send the request even when it is empty, to ensure that
         // we trigger health checks for the stream even when it is idle.
         GetDataStream stream = streamPool.getStream();
         try {
-          stream.refreshActiveWork(active);
+          stream.refreshActiveWork(heartbeats);
         } finally {
           streamPool.releaseStream(stream);
         }
-      } else if (!active.isEmpty()) {
+      } else if (!heartbeats.isEmpty()) {
+        // This code path is only used by appliance which sends heartbeats (used to refresh active
+        // work) as KeyedGetDataRequests. So we must translate the HeartbeatRequest to a
+        // KeyedGetDataRequest here regardless of the value of sendKeyedGetDataRequests.
         Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder();
-        for (Map.Entry<String, List<KeyedGetDataRequest>> entry : active.entrySet()) {
-          builder.addRequests(
-              Windmill.ComputationGetDataRequest.newBuilder()
-                  .setComputationId(entry.getKey())
-                  .addAllRequests(entry.getValue()));
+        for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
+          Windmill.ComputationGetDataRequest.Builder perComputationBuilder =
+              Windmill.ComputationGetDataRequest.newBuilder();
+          perComputationBuilder.setComputationId(entry.getKey());
+          for (HeartbeatRequest request : entry.getValue()) {
+            perComputationBuilder.addRequests(
+                Windmill.KeyedGetDataRequest.newBuilder()
+                    .setShardingKey(request.getShardingKey())
+                    .setWorkToken(request.getWorkToken())
+                    .setCacheToken(request.getCacheToken())
+                    .addAllLatencyAttribution(request.getLatencyAttributionList())
+                    .build());
+          }
+          builder.addRequests(perComputationBuilder.build());
         }
         server.getData(builder.build());
       }
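Note: a minimal caller-side sketch of the new heartbeat-based signature (computation id and token values here are made up for illustration; `metricTrackingWindmillServer` is the stub above). On the appliance branch each HeartbeatRequest is re-encoded as a KeyedGetDataRequest, so callers build the same map either way:

    Map<String, List<Windmill.HeartbeatRequest>> heartbeats = new HashMap<>();
    heartbeats.put(
        "computation-1",
        Collections.singletonList(
            Windmill.HeartbeatRequest.newBuilder()
                .setShardingKey(12345L)
                .setWorkToken(1L)
                .setCacheToken(7L)
                .build()));
    metricTrackingWindmillServer.refreshActiveWork(heartbeats);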
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
index d0931e02cc87..be0bccec0265 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java
@@ -112,6 +112,14 @@ protected PubsubReaderIterator(Windmill.WorkItem work) {
       super(work);
     }
 
+    @Override
+    public boolean advance() throws IOException {
+      if (context.workIsFailed()) {
+        return false;
+      }
+      return super.advance();
+    }
+
     @Override
     protected WindowedValue<T> decodeMessage(Windmill.Message message) throws IOException {
       T value;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index f2d7c02729c5..a95e78288819 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -85,6 +85,7 @@
 import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
+import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.streaming.Commit;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
@@ -422,6 +423,7 @@ public void run() {
     this.publishCounters = publishCounters;
     this.windmillServer = options.getWindmillServerStub();
+    this.windmillServer.setProcessHeartbeatResponses(this::handleHeartbeatResponses);
     this.metricTrackingWindmillServer =
         new MetricTrackingWindmillServerStub(windmillServer, memoryMonitor, windmillServiceEnabled);
     this.metricTrackingWindmillServer.start();
@@ -982,6 +984,9 @@ private void process(
     String counterName = "dataflow_source_bytes_processed-" + mapTask.getSystemName();
 
     try {
+      if (work.isFailed()) {
+        throw new WorkItemCancelledException(workItem.getShardingKey());
+      }
       executionState =
          computationState.getExecutionStateQueue().poll();
       if (executionState == null) {
         MutableNetwork<Node, Edge> mapTaskNetwork = mapTaskToNetwork.apply(mapTask);
@@ -1098,7 +1103,8 @@ public void close() {
                   work.setState(State.PROCESSING);
                 }
               };
-            });
+            },
+            work::isFailed);
       SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcher.byteTrackingView();
 
       // If the read output KVs, then we can decode Windmill's byte key into a userland
@@ -1136,12 +1142,16 @@ public void close() {
           synchronizedProcessingTime,
           stateReader,
           localSideInputStateFetcher,
-          outputBuilder);
+          outputBuilder,
+          work::isFailed);
 
       // Blocks while executing work.
       executionState.workExecutor().execute();
 
-      // Reports source bytes processed to workitemcommitrequest if available.
+      if (work.isFailed()) {
+        throw new WorkItemCancelledException(workItem.getShardingKey());
+      }
+      // Reports source bytes processed to WorkItemCommitRequest if available.
       try {
         long sourceBytesProcessed = 0;
         HashMap<String, ElementCounter> counters =
@@ -1234,6 +1244,12 @@ public void close() {
                 + "Work will not be retried locally.",
             computationId,
             key.toStringUtf8());
+      } else if (WorkItemCancelledException.isWorkItemCancelledException(t)) {
+        LOG.debug(
+            "Execution of work for computation '{}' on key '{}' failed. "
+                + "Work will not be retried locally.",
+            computationId,
+            workItem.getShardingKey());
       } else {
         LastExceptionDataProvider.reportException(t);
         LOG.debug("Failed work: {}", work);
@@ -1369,6 +1385,10 @@ private void commitLoop() {
   // Adds the commit to the commitStream if it fits, returning true iff it is consumed.
   private boolean addCommitToStream(Commit commit, CommitWorkStream commitStream) {
     Preconditions.checkNotNull(commit);
+    // Drop commits for failed work. Such commits will be dropped by Windmill anyway.
+    if (commit.work().isFailed()) {
+      return true;
+    }
     final ComputationState state = commit.computationState();
     final Windmill.WorkItemCommitRequest request = commit.request();
     final int size = commit.getSize();
@@ -1896,6 +1916,25 @@ private void sendWorkerUpdatesToDataflowService(
     }
   }
 
+  public void handleHeartbeatResponses(List<Windmill.ComputationHeartbeatResponse> responses) {
+    for (Windmill.ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
+      // Maps sharding key to (work token, cache token) for work that should be marked failed.
+      Map<Long, List<FailedTokens>> failedWork = new HashMap<>();
+      for (Windmill.HeartbeatResponse heartbeatResponse :
+          computationHeartbeatResponse.getHeartbeatResponsesList()) {
+        if (heartbeatResponse.getFailed()) {
+          failedWork
+              .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new ArrayList<>())
+              .add(
+                  new FailedTokens(
+                      heartbeatResponse.getWorkToken(), heartbeatResponse.getCacheToken()));
+        }
+      }
+      ComputationState state = computationMap.get(computationHeartbeatResponse.getComputationId());
+      if (state != null) state.failWork(failedWork);
+    }
+  }
+
   /**
    * Sends a GetData request to Windmill for all sufficiently old active work.
    *
@@ -1904,15 +1943,15 @@ private void sendWorkerUpdatesToDataflowService(
    * StreamingDataflowWorkerOptions#getActiveWorkRefreshPeriodMillis}.
    */
   private void refreshActiveWork() {
-    Map<String, List<KeyedGetDataRequest>> active = new HashMap<>();
+    Map<String, List<HeartbeatRequest>> heartbeats = new HashMap<>();
     Instant refreshDeadline =
         clock.get().minus(Duration.millis(options.getActiveWorkRefreshPeriodMillis()));
 
     for (Map.Entry<String, ComputationState> entry : computationMap.entrySet()) {
-      active.put(entry.getKey(), entry.getValue().getKeysToRefresh(refreshDeadline, sampler));
+      heartbeats.put(entry.getKey(), entry.getValue().getKeyHeartbeats(refreshDeadline, sampler));
     }
 
-    metricTrackingWindmillServer.refreshActiveWork(active);
+    metricTrackingWindmillServer.refreshActiveWork(heartbeats);
   }
 
   private void invalidateStuckCommits() {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
index d630601c28a3..83cf49112a8d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java
@@ -112,6 +112,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext<
   private @Nullable UnboundedSource.UnboundedReader<?> activeReader;
   private volatile long backlogBytes;
+  private Supplier<Boolean> workIsFailed;
 
   public StreamingModeExecutionContext(
       CounterFactory counterFactory,
@@ -135,6 +136,7 @@ public StreamingModeExecutionContext(
     this.stateNameMap = ImmutableMap.copyOf(stateNameMap);
     this.stateCache = stateCache;
     this.backlogBytes = UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
+    this.workIsFailed = () -> Boolean.FALSE;
   }
 
   @VisibleForTesting
@@ -142,6 +144,10 @@ public long getBacklogBytes() {
     return backlogBytes;
   }
 
+  public boolean workIsFailed() {
+    return workIsFailed.get();
+  }
+
   public void start(
       @Nullable Object key,
       Windmill.WorkItem work,
@@ -150,9 +156,11 @@ public void start(
       @Nullable Instant synchronizedProcessingTime,
       WindmillStateReader stateReader,
       SideInputStateFetcher sideInputStateFetcher,
-      Windmill.WorkItemCommitRequest.Builder outputBuilder) {
+      Windmill.WorkItemCommitRequest.Builder outputBuilder,
+      @Nullable Supplier<Boolean> workFailed) {
     this.key = key;
     this.work = work;
+    this.workIsFailed = (workFailed != null) ? workFailed : () -> Boolean.FALSE;
     this.computationKey =
         WindmillComputationKey.create(computationId, work.getKey(), work.getShardingKey());
     this.sideInputStateFetcher = sideInputStateFetcher;
@@ -429,7 +437,7 @@ void writePCollectionViewData(
 
   /**
    * Execution states in Streaming are shared between multiple map-task executors. Thus this class
-   * needs to be thread safe for multiple writers. A single stage could have have multiple executors
+   * needs to be thread safe for multiple writers. A single stage could have multiple executors
    * running concurrently.
    */
   public static class StreamingModeExecutionState
@@ -670,7 +678,7 @@ class StepContext extends DataflowExecutionContext.DataflowStepContext
     private NavigableSet<TimerData> modifiedUserSynchronizedProcessingTimersOrdered = null;
     // A list of timer keys that were modified by user processing earlier in this bundle. This
     // serves a tombstone, so
-    // that we know not to fire any bundle tiemrs that were moddified.
+    // that we know not to fire any bundle timers that were modified.
     private Table<String, StateNamespace, TimerData> modifiedUserTimerKeys = null;
 
     public StepContext(DataflowOperationContext operationContext) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
index e4e56a96c15a..4aac93ceb3fa 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java
@@ -99,6 +99,14 @@ class UngroupedWindmillReaderIterator extends WindmillReaderIteratorBase {
       super(work);
     }
 
+    @Override
+    public boolean advance() throws IOException {
+      if (context.workIsFailed()) {
+        return false;
+      }
+      return super.advance();
+    }
+
     @Override
     protected WindowedValue<T> decodeMessage(Windmill.Message message) throws IOException {
       Instant timestampMillis =
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
new file mode 100644
index 000000000000..934977fe0985
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker;
+
+/** Indicates that the work item was cancelled and should not be retried. */
+@SuppressWarnings({
+  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
+})
+public class WorkItemCancelledException extends RuntimeException {
+  public WorkItemCancelledException(long sharding_key) {
+    super("Work item cancelled for key " + sharding_key);
+  }
+
+  /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */
+  public static boolean isWorkItemCancelledException(Throwable t) {
+    while (t != null) {
+      if (t instanceof WorkItemCancelledException) {
+        return true;
+      }
+      t = t.getCause();
+    }
+    return false;
+  }
+}
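Note: a short usage sketch of the unwrapping helper above (the wrapping RuntimeException is hypothetical). User code and executors often rethrow wrapped exceptions, which is why the check walks the full cause chain rather than testing only the outermost throwable:

    try {
      throw new RuntimeException(new WorkItemCancelledException(12345L));
    } catch (RuntimeException e) {
      if (WorkItemCancelledException.isWorkItemCancelledException(e)) {
        // Treat as cancellation: drop the work item instead of reporting a processing failure.
      }
    }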
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
index a9050236efc8..2dc3494af5e2 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java
@@ -836,7 +836,8 @@ public boolean advance() throws IOException {
       while (true) {
         if (elemsRead >= maxElems
             || Instant.now().isAfter(endTime)
-            || context.isSinkFullHintSet()) {
+            || context.isSinkFullHintSet()
+            || context.workIsFailed()) {
           return false;
         }
         try {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index 16266de9d47c..54942dfeee1f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -23,6 +23,7 @@
 import java.util.ArrayDeque;
 import java.util.Deque;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
@@ -34,8 +35,10 @@
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -50,7 +53,8 @@
  * activate, queue, and complete {@link Work} (including invalidating stuck {@link Work}).
  */
 @ThreadSafe
-final class ActiveWorkState {
+@Internal
+public final class ActiveWorkState {
   private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkState.class);
 
   /* The max number of keys in COMMITTING or COMMIT_QUEUED status to be shown.*/
@@ -120,6 +124,50 @@ synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work w
     return ActivateWorkResult.QUEUED;
   }
 
+  public static final class FailedTokens {
+    public long workToken;
+    public long cacheToken;
+
+    public FailedTokens(long workToken, long cacheToken) {
+      this.workToken = workToken;
+      this.cacheToken = cacheToken;
+    }
+  }
+
+  /**
+   * Fails any active work matching an element of the input Map.
+   *
+   * @param failedWork a map from sharding_key to tokens for the corresponding work.
+   */
+  synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
+    // Note we can't construct a ShardedKey and look it up in activeWork directly since
+    // HeartbeatResponse doesn't include the user key.
+    for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
+      List<FailedTokens> failedTokens = failedWork.get(entry.getKey().shardingKey());
+      if (failedTokens == null) continue;
+      for (FailedTokens failedToken : failedTokens) {
+        for (Work queuedWork : entry.getValue()) {
+          WorkItem workItem = queuedWork.getWorkItem();
+          if (workItem.getWorkToken() == failedToken.workToken
+              && workItem.getCacheToken() == failedToken.cacheToken) {
+            LOG.debug(
+                "Failing work "
+                    + computationStateCache.getComputation()
+                    + " "
+                    + entry.getKey().shardingKey()
+                    + " "
+                    + failedToken.workToken
+                    + " "
+                    + failedToken.cacheToken
+                    + ". The work will be retried and is not lost.");
+            queuedWork.setFailed();
+            break;
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Removes the complete work from the {@link Queue}. The {@link Work} is marked as completed
    * if its workToken matches the one that is passed in. Returns the next {@link Work} in the {@link
@@ -211,14 +259,14 @@ private synchronized ImmutableMap<ShardedKey, Long> getStuckCommitsAt(
     return stuckCommits.build();
   }
 
-  synchronized ImmutableList<KeyedGetDataRequest> getKeysToRefresh(
+  synchronized ImmutableList<HeartbeatRequest> getKeyHeartbeats(
       Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
     return activeWork.entrySet().stream()
-        .flatMap(entry -> toKeyedGetDataRequestStream(entry, refreshDeadline, sampler))
+        .flatMap(entry -> toHeartbeatRequestStream(entry, refreshDeadline, sampler))
         .collect(toImmutableList());
   }
 
-  private static Stream<KeyedGetDataRequest> toKeyedGetDataRequestStream(
+  private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
       Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
       Instant refreshDeadline,
       DataflowExecutionStateSampler sampler) {
@@ -227,12 +275,14 @@ private static Stream<KeyedGetDataRequest> toKeyedGetDataRequestStream(
     return workQueue.stream()
         .filter(work -> work.getStartTime().isBefore(refreshDeadline))
+        // Don't send heartbeats for queued work we already know is failed.
+        .filter(work -> !work.isFailed())
         .map(
             work ->
-                Windmill.KeyedGetDataRequest.newBuilder()
-                    .setKey(shardedKey.key())
+                Windmill.HeartbeatRequest.newBuilder()
                     .setShardingKey(shardedKey.shardingKey())
                     .setWorkToken(work.getWorkItem().getWorkToken())
+                    .setCacheToken(work.getWorkItem().getCacheToken())
                     .addAllLatencyAttribution(
                         work.getLatencyAttributions(true, work.getLatencyTrackingId(), sampler))
                     .build());
@@ -250,7 +300,7 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
     for (Map.Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
       Queue<Work> workQueue = Preconditions.checkNotNull(entry.getValue());
       Work activeWork = Preconditions.checkNotNull(workQueue.peek());
-      Windmill.WorkItem workItem = activeWork.getWorkItem();
+      WorkItem workItem = activeWork.getWorkItem();
       if (activeWork.isCommitPending()) {
         if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {
           continue;
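Note: a hedged sketch of how a failed heartbeat response feeds failWorkForKey (here `heartbeatResponse` and `activeWorkState` stand in for a parsed Windmill response and the per-computation state; the real wiring lives in StreamingDataflowWorker.handleHeartbeatResponses above). Both the work token and the cache token must match before queued work is marked failed:

    Map<Long, List<ActiveWorkState.FailedTokens>> failedWork = new HashMap<>();
    failedWork
        .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new ArrayList<>())
        .add(
            new ActiveWorkState.FailedTokens(
                heartbeatResponse.getWorkToken(), heartbeatResponse.getCacheToken()));
    activeWorkState.failWorkForKey(failedWork);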
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index 4ac1d8bc9fac..8207a6ef2f09 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -19,12 +19,14 @@
 import com.google.api.services.dataflow.model.MapTask;
 import java.io.PrintWriter;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
+import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -98,6 +100,10 @@ public boolean activateWork(ShardedKey shardedKey, Work work) {
     }
   }
 
+  public void failWork(Map<Long, List<FailedTokens>> failedWork) {
+    activeWorkState.failWorkForKey(failedWork);
+  }
+
   /**
    * Marks the work for the given shardedKey as complete. Schedules queued work for the key if any.
    */
@@ -120,10 +126,10 @@ private void forceExecute(Work work) {
     executor.forceExecute(work, work.getWorkItem().getSerializedSize());
   }
 
-  /** Adds any work started before the refreshDeadline to the GetDataRequest builder. */
-  public ImmutableList<KeyedGetDataRequest> getKeysToRefresh(
+  /** Gets HeartbeatRequests for any work started before refreshDeadline. */
+  public ImmutableList<HeartbeatRequest> getKeyHeartbeats(
       Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
-    return activeWorkState.getKeysToRefresh(refreshDeadline, sampler);
+    return activeWorkState.getKeyHeartbeats(refreshDeadline, sampler);
   }
 
   public void printActiveWork(PrintWriter writer) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 3a77a8322b4b..69f2a0dcee76 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -50,6 +50,8 @@ public class Work implements Runnable {
   private final Consumer<Work> processWorkFn;
   private TimedState currentState;
 
+  private boolean isFailed;
+
   private Work(Windmill.WorkItem workItem, Supplier<Instant> clock, Consumer<Work> processWorkFn) {
     this.workItem = workItem;
     this.clock = clock;
@@ -57,6 +59,7 @@ private Work(Windmill.WorkItem workItem, Supplier<Instant> clock, Consumer<Work>
     this.startTime = clock.get();
     this.totalDurationPerState = new EnumMap<>(Windmill.LatencyAttribution.State.class);
     this.currentState = TimedState.initialState(startTime);
+    this.isFailed = false;
   }
 
   public static Work create(
@@ -95,6 +98,10 @@ public void setState(State state) {
     this.currentState = TimedState.create(state, now);
   }
 
+  public void setFailed() {
+    this.isFailed = true;
+  }
+
   public boolean isCommitPending() {
     return currentState.isCommitPending();
   }
@@ -180,6 +187,10 @@ private static LatencyAttribution.Builder addActiveLatencyBreakdownToBuilder(
     return builder;
   }
 
+  public boolean isFailed() {
+    return isFailed;
+  }
+
   boolean isStuckCommittingAt(Instant stuckCommitDeadline) {
     return currentState.state() == Work.State.COMMITTING
         && currentState.startTime().isBefore(stuckCommitDeadline);
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index c327e68d7e91..25581bee2089 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -19,8 +19,11 @@
 
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.util.List;
 import java.util.Set;
+import java.util.function.Consumer;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
@@ -79,6 +82,9 @@ public abstract GetWorkStream getWorkStream(
   @Override
   public void appendSummaryHtml(PrintWriter writer) {}
 
+  public void setProcessHeartbeatResponses(
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {}
+
   /** Generic Exception type for implementors to use to represent errors while making RPCs. */
   public static final class RpcException extends RuntimeException {
     public RpcException(Throwable cause) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index fa1f797a1911..7c22f4fb5765 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -23,6 +23,7 @@
 import java.util.function.Consumer;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.joda.time.Instant;
 
@@ -59,7 +60,9 @@ Windmill.KeyedGetDataResponse requestKeyedData(
     Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
 
     /** Tells windmill processing is ongoing for the given keys. */
-    void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active);
+    void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats);
+
+    void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse> responses);
   }
 
   /** Interface for streaming CommitWorkRequests to Windmill. */
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index a04a961ca9c2..b6600e04a09d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -32,10 +32,15 @@
 import java.util.concurrent.ConcurrentLinkedDeque;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
 import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -64,6 +69,10 @@ public final class GrpcGetDataStream
   private final ThrottleTimer getDataThrottleTimer;
   private final JobHeader jobHeader;
   private final int streamingRpcBatchLimit;
+  // If true, then active work refreshes will be sent as KeyedGetDataRequests. Otherwise, use the
+  // newer ComputationHeartbeatRequests.
+  private final boolean sendKeyedGetDataRequests;
+  private Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses;
 
   private GrpcGetDataStream(
       Function<StreamObserver<StreamingGetDataResponse>, StreamObserver<StreamingGetDataRequest>>
          startGetDataRpcFn,
@@ -75,7 +84,9 @@ private GrpcGetDataStream(
       ThrottleTimer getDataThrottleTimer,
       JobHeader jobHeader,
       AtomicLong idGenerator,
-      int streamingRpcBatchLimit) {
+      int streamingRpcBatchLimit,
+      boolean sendKeyedGetDataRequests,
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
     super(
         startGetDataRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures);
     this.idGenerator = idGenerator;
@@ -84,6 +95,8 @@ private GrpcGetDataStream(
     this.streamingRpcBatchLimit = streamingRpcBatchLimit;
     this.batches = new ConcurrentLinkedDeque<>();
     this.pending = new ConcurrentHashMap<>();
+    this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
+    this.processHeartbeatResponses = processHeartbeatResponses;
   }
 
   public static GrpcGetDataStream create(
@@ -96,7 +109,9 @@ public static GrpcGetDataStream create(
       ThrottleTimer getDataThrottleTimer,
       JobHeader jobHeader,
       AtomicLong idGenerator,
-      int streamingRpcBatchLimit) {
+      int streamingRpcBatchLimit,
+      boolean sendKeyedGetDataRequests,
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
     GrpcGetDataStream getDataStream =
         new GrpcGetDataStream(
             startGetDataRpcFn,
@@ -107,7 +122,9 @@ public static GrpcGetDataStream create(
             getDataThrottleTimer,
             jobHeader,
             idGenerator,
-            streamingRpcBatchLimit);
+            streamingRpcBatchLimit,
+            sendKeyedGetDataRequests,
+            processHeartbeatResponses);
     getDataStream.startStream();
     return getDataStream;
   }
@@ -138,6 +155,7 @@ protected void onResponse(StreamingGetDataResponse chunk) {
     checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount());
     checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1);
     getDataThrottleTimer.stop();
+    onHeartbeatResponse(chunk.getComputationHeartbeatResponseList());
 
     for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
       AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
@@ -171,30 +189,71 @@ public GlobalData requestGlobalData(GlobalDataRequest request) {
   }
 
   @Override
-  public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
-    long builderBytes = 0;
+  public void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats) {
     StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
-    for (Map.Entry<String, List<KeyedGetDataRequest>> entry : active.entrySet()) {
-      for (KeyedGetDataRequest request : entry.getValue()) {
-        // Calculate the bytes with some overhead for proto encoding.
-        long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
-        if (builderBytes > 0
-            && (builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE
-                || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
-          send(builder.build());
-          builderBytes = 0;
-          builder.clear();
+    if (sendKeyedGetDataRequests) {
+      long builderBytes = 0;
+      for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
+        for (HeartbeatRequest request : entry.getValue()) {
+          // Calculate the bytes with some overhead for proto encoding.
+          long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
+          if (builderBytes > 0
+              && (builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE
+                  || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
+            send(builder.build());
+            builderBytes = 0;
+            builder.clear();
+          }
+          builderBytes += bytes;
+          builder.addStateRequest(
+              ComputationGetDataRequest.newBuilder()
+                  .setComputationId(entry.getKey())
+                  .addRequests(
+                      Windmill.KeyedGetDataRequest.newBuilder()
+                          .setShardingKey(request.getShardingKey())
+                          .setWorkToken(request.getWorkToken())
+                          .setCacheToken(request.getCacheToken())
+                          .addAllLatencyAttribution(request.getLatencyAttributionList())
+                          .build()));
         }
-        builderBytes += bytes;
-        builder.addStateRequest(
-            ComputationGetDataRequest.newBuilder()
-                .setComputationId(entry.getKey())
-                .addRequests(request));
+      }
+
+      if (builderBytes > 0) {
+        send(builder.build());
+      }
+    } else {
+      // No translation necessary, but we must still respect `RPC_STREAM_CHUNK_SIZE`.
+      long builderBytes = 0;
+      for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
+        ComputationHeartbeatRequest.Builder computationHeartbeatBuilder =
+            ComputationHeartbeatRequest.newBuilder().setComputationId(entry.getKey());
+        for (HeartbeatRequest request : entry.getValue()) {
+          long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
+          if (builderBytes > 0
+              && builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) {
+            if (computationHeartbeatBuilder.getHeartbeatRequestsCount() > 0) {
+              builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build());
+            }
+            send(builder.build());
+            builderBytes = 0;
+            builder.clear();
+            computationHeartbeatBuilder.clear().setComputationId(entry.getKey());
+          }
+          builderBytes += bytes;
+          computationHeartbeatBuilder.addHeartbeatRequests(request);
+        }
+        builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build());
+      }
+
+      if (builderBytes > 0) {
+        send(builder.build());
+      }
     }
-    if (builderBytes > 0) {
-      send(builder.build());
-    }
+  }
+
+  @Override
+  public void onHeartbeatResponse(List<ComputationHeartbeatResponse> responses) {
+    processHeartbeatResponses.accept(responses);
   }
 
   @Override
@@ -277,7 +336,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept
       waitForSendLatch.await();
     }
     // Finalize the batch so that no additional requests will be added. Leave the batch in the
-    // queue so that a subsequent batch will wait for it's completion.
+    // queue so that a subsequent batch will wait for its completion.
     synchronized (batches) {
       verify(batch == batches.peekFirst());
       batch.markFinalized();
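Note: the two refreshActiveWork branches above share the same flush rule. A self-contained sketch of that rule (generic item sizes stand in for serialized request bytes; `budgetBytes` plays the role of RPC_STREAM_CHUNK_SIZE, and java.util imports are assumed): start a new batch whenever adding the next item would overflow the budget, then flush whatever remains.

    static List<List<Long>> batchBySize(List<Long> itemBytes, long budgetBytes) {
      List<List<Long>> batches = new ArrayList<>();
      List<Long> current = new ArrayList<>();
      long used = 0;
      for (long bytes : itemBytes) {
        // Mirrors `builderBytes > 0 && builderBytes + bytes > RPC_STREAM_CHUNK_SIZE`:
        // a single oversized item is still sent alone rather than dropped.
        if (used > 0 && used + bytes > budgetBytes) {
          batches.add(current);
          current = new ArrayList<>();
          used = 0;
        }
        current.add(bytes);
        used += bytes;
      }
      if (used > 0) {
        batches.add(current); // Mirrors the trailing send(builder.build()).
      }
      return batches;
    }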
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
index 3a881df71462..9f0126a9cc69 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java
@@ -28,13 +28,17 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Supplier;
 import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.DataflowRunner;
 import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions;
 import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
 import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
@@ -93,6 +97,10 @@ public final class GrpcWindmillServer extends WindmillServerStub {
   private final StreamingEngineThrottleTimers throttleTimers;
   private Duration maxBackoff;
   private @Nullable WindmillApplianceGrpc.WindmillApplianceBlockingStub syncApplianceStub;
+  // If true, then active work refreshes will be sent as KeyedGetDataRequests. Otherwise, use the
+  // newer ComputationHeartbeatRequests.
+  private final boolean sendKeyedGetDataRequests;
+  private Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses;
 
   private GrpcWindmillServer(
       StreamingDataflowWorkerOptions options, GrpcDispatcherClient grpcDispatcherClient) {
@@ -118,9 +126,21 @@ private GrpcWindmillServer(
     this.dispatcherClient = grpcDispatcherClient;
     this.syncApplianceStub = null;
+    this.sendKeyedGetDataRequests =
+        !options.isEnableStreamingEngine()
+            || !DataflowRunner.hasExperiment(
+                options, "streaming_engine_send_new_heartbeat_requests");
+    this.processHeartbeatResponses = (responses) -> {};
   }
 
-  private static StreamingDataflowWorkerOptions testOptions(boolean enableStreamingEngine) {
+  @Override
+  public void setProcessHeartbeatResponses(
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
+    this.processHeartbeatResponses = processHeartbeatResponses;
+  }
+
+  private static StreamingDataflowWorkerOptions testOptions(
+      boolean enableStreamingEngine, List<String> additionalExperiments) {
     StreamingDataflowWorkerOptions options =
         PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class);
     options.setProject("project");
@@ -131,6 +151,7 @@ private static StreamingDataflowWorkerOptions testOptions(boolean enableStreamin
     if (enableStreamingEngine) {
       experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT);
     }
+    experiments.addAll(additionalExperiments);
     options.setExperiments(experiments);
 
     options.setWindmillServiceStreamingRpcBatchLimit(Integer.MAX_VALUE);
@@ -162,7 +183,7 @@ public static GrpcWindmillServer create(StreamingDataflowWorkerOptions workerOpt
   }
 
   @VisibleForTesting
-  static GrpcWindmillServer newTestInstance(String name) {
+  static GrpcWindmillServer newTestInstance(String name, List<String> experiments) {
     ManagedChannel inProcessChannel = inProcessChannel(name);
     CloudWindmillServiceV1Alpha1Stub stub =
         CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel);
@@ -173,14 +194,15 @@ static GrpcWindmillServer newTestInstance(String name) {
         WindmillStubFactory.inProcessStubFactory(name, unused -> inProcessChannel),
         dispatcherStubs,
         dispatcherEndpoints);
-    return new GrpcWindmillServer(testOptions(/* enableStreamingEngine= */ true), dispatcherClient);
+    return new GrpcWindmillServer(
+        testOptions(/* enableStreamingEngine= */ true, experiments), dispatcherClient);
   }
 
   @VisibleForTesting
   static GrpcWindmillServer newApplianceTestInstance(Channel channel) {
     GrpcWindmillServer testServer =
         new GrpcWindmillServer(
-            testOptions(/* enableStreamingEngine= */ false),
+            testOptions(/* enableStreamingEngine= */ false, new ArrayList<>()),
             // No-op, Appliance does not use Dispatcher to call Streaming Engine.
            GrpcDispatcherClient.create(WindmillStubFactory.inProcessStubFactory("test")));
 
     testServer.syncApplianceStub = createWindmillApplianceStubWithDeadlineInterceptor(channel);
@@ -319,7 +341,10 @@ public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver rece
   @Override
   public GetDataStream getDataStream() {
     return windmillStreamFactory.createGetDataStream(
-        dispatcherClient.getDispatcherStub(), throttleTimers.getDataThrottleTimer());
+        dispatcherClient.getDispatcherStub(),
+        throttleTimers.getDataThrottleTimer(),
+        sendKeyedGetDataRequests,
+        this.processHeartbeatResponses);
   }
 
   @Override
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 099be8db0fda..7dc43e791e31 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -21,6 +21,7 @@
 
 import com.google.auto.value.AutoBuilder;
 import java.io.PrintWriter;
+import java.util.List;
 import java.util.Set;
 import java.util.Timer;
 import java.util.TimerTask;
@@ -32,6 +33,7 @@
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
@@ -152,7 +154,10 @@ public GetWorkStream createDirectGetWorkStream(
   }
 
   public GetDataStream createGetDataStream(
-      CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer getDataThrottleTimer) {
+      CloudWindmillServiceV1Alpha1Stub stub,
+      ThrottleTimer getDataThrottleTimer,
+      boolean sendKeyedGetDataRequests,
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
     return GrpcGetDataStream.create(
         responseObserver -> withDeadline(stub).getDataStream(responseObserver),
         grpcBackOff.get(),
@@ -162,7 +167,14 @@ public GetDataStream createGetDataStream(
         getDataThrottleTimer,
         jobHeader,
         streamIdGenerator,
-        streamingRpcBatchLimit);
+        streamingRpcBatchLimit,
+        sendKeyedGetDataRequests,
+        processHeartbeatResponses);
+  }
+
+  public GetDataStream createGetDataStream(
+      CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer getDataThrottleTimer) {
+    return createGetDataStream(stub, getDataThrottleTimer, false, (response) -> {});
   }
 
   public CommitWorkStream createCommitWorkStream(
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index 6c1239d6ebd2..5a9e5443a506 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -296,6 +296,11 @@ private ForComputation(String computation) {
       this.computation = computation;
     }
 
+    /** Returns the computation associated with this cache. */
+    public String getComputation() {
+      return this.computation;
+    }
+
     /** Invalidate all cache entries for this computation and {@code processingKey}. */
     public void invalidate(ByteString processingKey, long shardingKey) {
       WindmillComputationKey key =
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
index c28939c59ee2..637b838c7fe2 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java
@@ -39,6 +39,7 @@
 import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException;
 import org.apache.beam.runners.dataflow.worker.MetricTrackingWindmillServerStub;
 import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
+import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -123,6 +124,7 @@ public class WindmillStateReader {
   private final MetricTrackingWindmillServerStub metricTrackingWindmillServerStub;
   private final ConcurrentHashMap<StateTag<?>, CoderAndFuture<?>> waiting;
   private long bytesRead = 0L;
+  private final Supplier<Boolean> workItemIsFailed;
 
   public WindmillStateReader(
       MetricTrackingWindmillServerStub metricTrackingWindmillServerStub,
@@ -130,7 +132,8 @@ public WindmillStateReader(
       ByteString key,
       long shardingKey,
       long workToken,
-      Supplier<AutoCloseable> readWrapperSupplier) {
+      Supplier<AutoCloseable> readWrapperSupplier,
+      Supplier<Boolean> workItemIsFailed) {
     this.metricTrackingWindmillServerStub = metricTrackingWindmillServerStub;
     this.computation = computation;
     this.key = key;
@@ -139,6 +142,7 @@ public WindmillStateReader(
     this.readWrapperSupplier = readWrapperSupplier;
     this.waiting = new ConcurrentHashMap<>();
     this.pendingLookups = new ConcurrentLinkedQueue<>();
+    this.workItemIsFailed = workItemIsFailed;
   }
 
   public WindmillStateReader(
@@ -147,7 +151,14 @@ public WindmillStateReader(
       ByteString key,
       long shardingKey,
       long workToken) {
-    this(metricTrackingWindmillServerStub, computation, key, shardingKey, workToken, () -> null);
+    this(
+        metricTrackingWindmillServerStub,
+        computation,
+        key,
+        shardingKey,
+        workToken,
+        () -> null,
+        () -> Boolean.FALSE);
   }
 
   private <T> Future<T> stateFuture(StateTag<T> stateTag, @Nullable Coder<T> coder) {
@@ -404,6 +415,9 @@ public void performReads() {
 
   private KeyedGetDataResponse tryGetDataFromWindmill(HashSet<StateTag<?>> stateTags)
       throws Exception {
+    if (workItemIsFailed.get()) {
+      throw new WorkItemCancelledException(shardingKey);
+    }
     KeyedGetDataRequest keyedGetDataRequest = createRequest(stateTags);
     try (AutoCloseable ignored = readWrapperSupplier.get()) {
       return Optional.ofNullable(
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index a434b2001207..2cfec6d3139a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -46,8 +46,11 @@
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State;
@@ -80,9 +83,10 @@ class FakeWindmillServer extends WindmillServerStub {
   private final ErrorCollector errorCollector;
   private final ConcurrentHashMap<Long, Consumer<CommitStatus>> droppedStreamingCommits;
   private int commitsRequested = 0;
-  private List<GetDataRequest> getDataRequests = new ArrayList<>();
+  private final List<GetDataRequest> getDataRequests = new ArrayList<>();
   private boolean isReady = true;
   private boolean dropStreamingCommits = false;
+  private Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses;
 
   public FakeWindmillServer(ErrorCollector errorCollector) {
     workToOffer =
@@ -91,7 +95,7 @@ public FakeWindmillServer(ErrorCollector errorCollector) {
     dataToOffer =
         new ResponseQueue()
             .returnByDefault(GetDataResponse.getDefaultInstance())
-            // Sleep for a little bit to ensure that *-windmill-read state-sampled counters show up.
+            // Sleep for a bit to ensure that *-windmill-read state-sampled counters show up.
            .delayEachResponseBy(Duration.millis(500));
     commitsToOffer =
         new ResponseQueue()
@@ -102,6 +106,13 @@ public FakeWindmillServer(ErrorCollector errorCollector) {
     this.errorCollector = errorCollector;
     statsReceived = new ArrayList<>();
     droppedStreamingCommits = new ConcurrentHashMap<>();
+    processHeartbeatResponses = (responses) -> {};
+  }
+
+  @Override
+  public void setProcessHeartbeatResponses(
+      Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses) {
+    this.processHeartbeatResponses = processHeartbeatResponses;
   }
 
   public void setDropStreamingCommits(boolean dropStreamingCommits) {
@@ -116,6 +127,10 @@ public ResponseQueue whenGetDataCalled() {
     return dataToOffer;
   }
 
+  public void sendFailedHeartbeats(List<ComputationHeartbeatResponse> responses) {
+    getDataStream().onHeartbeatResponse(responses);
+  }
+
   public ResponseQueue whenCommitWorkCalled() {
     return commitsToOffer;
   }
@@ -304,17 +319,23 @@ public Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request)
   }
 
   @Override
-  public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
+  public void refreshActiveWork(Map<String, List<HeartbeatRequest>> heartbeats) {
     Windmill.GetDataRequest.Builder builder = Windmill.GetDataRequest.newBuilder();
-    for (Map.Entry<String, List<KeyedGetDataRequest>> entry : active.entrySet()) {
-      builder.addRequests(
-          ComputationGetDataRequest.newBuilder()
+    for (Map.Entry<String, List<HeartbeatRequest>> entry : heartbeats.entrySet()) {
+      builder.addComputationHeartbeatRequest(
+          ComputationHeartbeatRequest.newBuilder()
               .setComputationId(entry.getKey())
-              .addAllRequests(entry.getValue()));
+              .addAllHeartbeatRequests(entry.getValue()));
     }
+
     getData(builder.build());
   }
 
+  @Override
+  public void onHeartbeatResponse(List<ComputationHeartbeatResponse> responses) {
+    processHeartbeatResponses.accept(responses);
+  }
+
   @Override
   public void close() {}
 
@@ -383,6 +404,18 @@ public void waitForEmptyWorkQueue() {
     }
   }
 
+  public Map<Long, WorkItemCommitRequest> waitForAndGetCommitsWithTimeout(
+      int numCommits, Duration timeout) {
+    LOG.debug("waitForAndGetCommitsWithTimeout: {} {}", numCommits, timeout);
+    Instant waitStart = Instant.now();
+    while (commitsReceived.size() < commitsRequested + numCommits
+        && Instant.now().isBefore(waitStart.plus(timeout))) {
+      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
+    }
+    commitsRequested += numCommits;
+    return commitsReceived;
+  }
+
   public Map<Long, WorkItemCommitRequest> waitForAndGetCommits(int numCommits) {
     LOG.debug("waitForAndGetCommitsRequest: {}", numCommits);
     int maxTries = 10;
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index 31a9af9004a8..9526c96fd04e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -109,9 +109,12 @@
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.InputMessageBundle;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
@@ -577,8 +580,9 @@ private Windmill.GetWorkResponse makeInput(
   }
 
   /**
-   * Returns a {@link org.apache.beam.runners.dataflow.windmill.Windmill.WorkItemCommitRequest}
-   * builder parsed from the provided text format proto.
+   * Returns a {@link
+   * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest} builder parsed
+   * from the provided text format proto.
    */
   private WorkItemCommitRequest.Builder parseCommitRequest(String output) throws Exception {
     WorkItemCommitRequest.Builder builder = Windmill.WorkItemCommitRequest.newBuilder();
@@ -3258,6 +3262,49 @@ public void testActiveWorkRefresh() throws Exception {
     assertThat(server.numGetDataRequests(), greaterThan(0));
   }
 
+  @Test
+  public void testActiveWorkFailure() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            makeSourceInstruction(StringUtf8Coder.of()),
+            makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
+            makeSinkInstruction(StringUtf8Coder.of(), 0));
+
+    FakeWindmillServer server = new FakeWindmillServer(errorCollector);
+    StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server);
+    options.setActiveWorkRefreshPeriodMillis(100);
+    StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
+    worker.start();
+
+    // Queue up two work items for the same key.
+    server
+        .whenGetWorkCalled()
+        .thenReturn(makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY))
+        .thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY));
+    server.waitForEmptyWorkQueue();
+
+    // Mock Windmill sending a heartbeat response failing the second work item while the first
+    // is still processing.
+    ComputationHeartbeatResponse.Builder failedHeartbeat =
+        ComputationHeartbeatResponse.newBuilder();
+    failedHeartbeat
+        .setComputationId(DEFAULT_COMPUTATION_ID)
+        .addHeartbeatResponsesBuilder()
+        .setCacheToken(3)
+        .setWorkToken(1)
+        .setShardingKey(DEFAULT_SHARDING_KEY)
+        .setFailed(true);
+    server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
+
+    // Release the blocked calls.
+    BlockingFn.blocker.countDown();
+    Map<Long, Windmill.WorkItemCommitRequest> commits =
+        server.waitForAndGetCommitsWithTimeout(2, Duration.standardSeconds(5));
+    assertEquals(1, commits.size());
+
+    worker.stop();
+  }
+
   @Test
   public void testLatencyAttributionProtobufsPopulated() {
     FakeClock clock = new FakeClock();
@@ -3573,7 +3620,10 @@ public void testDoFnActiveMessageMetadataReportedOnHeartbeat() throws Exception
     Windmill.GetDataRequest heartbeat = server.getGetDataRequests().get(2);
 
     for (LatencyAttribution la :
-        heartbeat.getRequests(0).getRequests(0).getLatencyAttributionList()) {
+        heartbeat
+            .getComputationHeartbeatRequest(0)
+            .getHeartbeatRequests(0)
+            .getLatencyAttributionList()) {
       if (la.getState() == State.ACTIVE) {
         assertTrue(la.getActiveLatencyBreakdownCount() > 0);
         assertTrue(la.getActiveLatencyBreakdown(0).hasActiveMessageMetadata());
@@ -3768,7 +3818,7 @@ public void testStuckCommit() throws Exception {
     server
         .whenGetWorkCalled()
         .thenReturn(makeInput(1, TimeUnit.MILLISECONDS.toMicros(1), DEFAULT_KEY_STRING, 1));
-    // Ensure that the this work item processes.
+    // Ensure that this work item processes.
     Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(1);
     // Now ensure that nothing happens if a dropped commit actually completes.
     droppedCommits.values().iterator().next().accept(CommitStatus.OK);
@@ -4129,7 +4179,7 @@ public void run() {
               FakeClock.this.schedule(Duration.millis(unit.toMillis(delay)), this);
             }
           });
-      FakeClock.this.sleep(Duration.ZERO); // Execute work that has an intial delay of zero.
+      FakeClock.this.sleep(Duration.ZERO); // Execute work that has an initial delay of zero.
       return null;
     }
   }
@@ -4167,6 +4217,7 @@ Duration getLatencyAttributionDuration(long workToken, LatencyAttribution.State
     }
 
     boolean isActiveWorkRefresh(GetDataRequest request) {
+      if (request.getComputationHeartbeatRequestCount() > 0) return true;
       for (ComputationGetDataRequest computationRequest : request.getRequestsList()) {
         if (!computationRequest.getComputationId().equals(DEFAULT_COMPUTATION_ID)) {
           return false;
@@ -4203,6 +4254,21 @@ GetDataResponse getData(GetDataRequest request) {
           }
         }
       }
+      for (ComputationHeartbeatRequest heartbeatRequest :
+          request.getComputationHeartbeatRequestList()) {
+        for (HeartbeatRequest heartbeat : heartbeatRequest.getHeartbeatRequestsList()) {
+          for (LatencyAttribution la : heartbeat.getLatencyAttributionList()) {
+            EnumMap<LatencyAttribution.State, Duration> durations =
+                totalDurations.computeIfAbsent(
+                    heartbeat.getWorkToken(),
+                    (Long workToken) ->
+                        new EnumMap<LatencyAttribution.State, Duration>(
+                            LatencyAttribution.State.class));
+            Duration cur = Duration.millis(la.getTotalDurationMillis());
+            durations.compute(la.getState(), (s, d) -> d == null || d.isShorterThan(cur) ? cur : d);
+ durations.compute(la.getState(), (s, d) -> d == null || d.isShorterThan(cur) ? cur : d); + } + } + } return EMPTY_DATA_RESPONDER.apply(request); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 60ecaa3e37e0..451ec649aa23 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -137,7 +137,8 @@ public void testTimerInternalsSetTimer() { null, // synchronized processing time stateReader, sideInputStateFetcher, - outputBuilder); + outputBuilder, + null); TimerInternals timerInternals = stepContext.timerInternals(); @@ -187,7 +188,8 @@ public void testTimerInternalsProcessingTimeSkew() { null, // synchronized processing time stateReader, sideInputStateFetcher, - outputBuilder); + outputBuilder, + null); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 6fa2ffe711f8..b488641d1ca5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -37,6 +37,7 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -90,9 +91,12 @@ import org.apache.beam.runners.dataflow.worker.counters.CounterSet; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; +import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; @@ -613,7 +617,8 @@ public void testReadUnboundedReader() throws Exception { null, // synchronized processing time null, // StateReader null, // StateFetcher - Windmill.WorkItemCommitRequest.newBuilder()); + Windmill.WorkItemCommitRequest.newBuilder(), + null); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>> reader = @@ -931,4 +936,79 @@ public void testGetReaderProgressThrowing() { assertNull(progress.getRemainingParallelism()); logged.verifyWarn("remaining parallelism"); } + +
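+ // Verifies that an unbounded source reader stops advancing once its work item is failed.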
+ @Test + public void testFailedWorkItemsAbort() throws Exception { + CounterSet counterSet = new CounterSet(); + StreamingModeExecutionStateRegistry executionStateRegistry = + new StreamingModeExecutionStateRegistry(null); + StreamingModeExecutionContext context = + new StreamingModeExecutionContext( + counterSet, + "computationId", + new ReaderCache(Duration.standardMinutes(1), Runnable::run), + /*stateNameMap=*/ ImmutableMap.of(), + new WindmillStateCache(options.getWorkerCacheMb()).forComputation("computationId"), + StreamingStepMetricsContainer.createRegistry(), + new DataflowExecutionStateTracker( + ExecutionStateSampler.newForTest(), + executionStateRegistry.getState( + NameContext.forStage("stageName"), "other", null, NoopProfileScope.NOOP), + counterSet, + PipelineOptionsFactory.create(), + "test-work-item-id"), + executionStateRegistry, + Long.MAX_VALUE); + + options.setNumWorkers(5); + int maxElements = 100; + DataflowPipelineDebugOptions debugOptions = options.as(DataflowPipelineDebugOptions.class); + debugOptions.setUnboundedReaderMaxElements(maxElements); + + ByteString state = ByteString.EMPTY; + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. + .setWorkToken(0) + .setCacheToken(1) + .setSourceState( + Windmill.SourceState.newBuilder().setState(state).build()) // Source state. + .build(); + Work dummyWork = Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {}); + + context.start( + "key", + workItem, + new Instant(0), // input watermark + null, // output watermark + null, // synchronized processing time + null, // StateReader + null, // StateFetcher + Windmill.WorkItemCommitRequest.newBuilder(), + dummyWork::isFailed); + + @SuppressWarnings({"unchecked", "rawtypes"}) + NativeReader<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>> reader = + (NativeReader) + WorkerCustomSources.create( + (CloudObject) + serializeToCloudSource(new TestCountingSource(Integer.MAX_VALUE), options) + .getSpec(), + options, + context); + + NativeReaderIterator<WindowedValue<ValueWithRecordId<KV<Integer, Integer>>>> readerIterator = + reader.iterator(); + int numReads = 0; + while ((numReads == 0) ? readerIterator.start() : readerIterator.advance()) { + WindowedValue<ValueWithRecordId<KV<Integer, Integer>>> value = readerIterator.getCurrent(); + assertEquals(KV.of(0, numReads), value.getValue().getValue()); + numReads++; + // Fail the work item after reading two elements.
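+ // setFailed() flips the dummyWork::isFailed supplier passed to context.start() above, so + // the next advance() call is expected to return false.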
+ if (numReads == 2) { + dummyWork.setFailed(); + } + } + assertThat(numReads, equalTo(2)); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index ea57f687fd95..de30fd0f8d5d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -36,7 +36,7 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -239,7 +239,7 @@ public void testInvalidateStuckCommits() { } @Test - public void testGetKeysToRefresh() { + public void testGetKeyHeartbeats() { Instant refreshDeadline = Instant.now(); Work freshWork = createWork(createWorkItem(3L)); @@ -254,47 +254,51 @@ activeWorkState.activateWorkForKey(shardedKey1, freshWork); activeWorkState.activateWorkForKey(shardedKey2, refreshableWork2); - ImmutableList<KeyedGetDataRequest> requests = - activeWorkState.getKeysToRefresh(refreshDeadline, DataflowExecutionStateSampler.instance()); + ImmutableList<HeartbeatRequest> requests = + activeWorkState.getKeyHeartbeats(refreshDeadline, DataflowExecutionStateSampler.instance()); - ImmutableList<GetDataRequestKeyShardingKeyAndWorkToken> expected = + ImmutableList<HeartbeatRequestShardingKeyWorkTokenAndCacheToken> expected = ImmutableList.of( - GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey1, refreshableWork1), - GetDataRequestKeyShardingKeyAndWorkToken.from(shardedKey2, refreshableWork2)); + HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey1, refreshableWork1), + HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey2, refreshableWork2)); - ImmutableList<GetDataRequestKeyShardingKeyAndWorkToken> actual = + ImmutableList<HeartbeatRequestShardingKeyWorkTokenAndCacheToken> actual = requests.stream() - .map(GetDataRequestKeyShardingKeyAndWorkToken::from) + .map(HeartbeatRequestShardingKeyWorkTokenAndCacheToken::from) .collect(toImmutableList()); assertThat(actual).containsExactlyElementsIn(expected); } @AutoValue - abstract static class GetDataRequestKeyShardingKeyAndWorkToken { + abstract static class HeartbeatRequestShardingKeyWorkTokenAndCacheToken { - private static GetDataRequestKeyShardingKeyAndWorkToken create( - ByteString key, long shardingKey, long workToken) { - return new AutoValue_ActiveWorkStateTest_GetDataRequestKeyShardingKeyAndWorkToken( - key, shardingKey, workToken); + private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken create( + long shardingKey, long workToken, long cacheToken) { + return new AutoValue_ActiveWorkStateTest_HeartbeatRequestShardingKeyWorkTokenAndCacheToken( + shardingKey, workToken, cacheToken); } - private static GetDataRequestKeyShardingKeyAndWorkToken from( - KeyedGetDataRequest keyedGetDataRequest) { + private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from(
+ HeartbeatRequest heartbeatRequest) { return create( - keyedGetDataRequest.getKey(), - keyedGetDataRequest.getShardingKey(), - keyedGetDataRequest.getWorkToken()); + heartbeatRequest.getShardingKey(), + heartbeatRequest.getWorkToken(), + heartbeatRequest.getCacheToken()); } - private static GetDataRequestKeyShardingKeyAndWorkToken from(ShardedKey shardedKey, Work work) { - return create(shardedKey.key(), shardedKey.shardingKey(), work.getWorkItem().getWorkToken()); + private static HeartbeatRequestShardingKeyWorkTokenAndCacheToken from( + ShardedKey shardedKey, Work work) { + return create( + shardedKey.shardingKey(), + work.getWorkItem().getWorkToken(), + work.getWorkItem().getCacheToken()); } - abstract ByteString key(); - abstract long shardingKey(); abstract long workToken(); + + abstract long cacheToken(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 5f8a452a0433..0ea253027679 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -44,6 +44,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkStreamTimingInfo; @@ -51,6 +52,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; @@ -126,7 +128,7 @@ public void setUp() throws Exception { .build() .start(); - this.client = GrpcWindmillServer.newTestInstance(name); + this.client = GrpcWindmillServer.newTestInstance(name, new ArrayList<>()); } @After @@ -744,7 +746,7 @@ public void onCompleted() { while (true) { Thread.sleep(100); int tmpErrorsBeforeClose = errorsBeforeClose.get(); - // wait for at least 1 errors before close + // wait for at least 1 error before close if (tmpErrorsBeforeClose > 0) { break; } @@ -765,7 +767,7 @@ public void onCompleted() { while (true) { Thread.sleep(100); int tmpErrorsAfterClose = errorsAfterClose.get(); - // wait for at least 1 errors after close + // wait for at least 1 error after close if (tmpErrorsAfterClose > 0) { break; } @@ -786,22 +788,36 @@ public void onCompleted() { assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); } - private List<KeyedGetDataRequest>
makeHeartbeatRequest(List<String> keys) { + private List<KeyedGetDataRequest> makeGetDataHeartbeatRequest(List<String> keys) { List<KeyedGetDataRequest> result = new ArrayList<>(); for (String key : keys) { result.add( Windmill.KeyedGetDataRequest.newBuilder() - .setKey(ByteString.copyFromUtf8(key)) + .setShardingKey(key.hashCode()) .setWorkToken(0) + .setCacheToken(0) + .build()); + } + return result; + } + + private List<HeartbeatRequest> makeHeartbeatRequest(List<String> keys) { + List<HeartbeatRequest> result = new ArrayList<>(); + for (String key : keys) { + result.add( + Windmill.HeartbeatRequest.newBuilder() + .setShardingKey(key.hashCode()) + .setWorkToken(0) + .setCacheToken(0) .build()); } return result; } @Test - public void testStreamingGetDataHeartbeats() throws Exception { + public void testStreamingGetDataHeartbeatsAsKeyedGetDataRequests() throws Exception { // This server records the heartbeats observed but doesn't respond. - final Map<String, List<KeyedGetDataRequest>> heartbeats = new HashMap<>(); + final Map<String, List<KeyedGetDataRequest>> getDataHeartbeats = new HashMap<>(); serviceRegistry.addService( new CloudWindmillServiceV1Alpha1ImplBase() { @Override @@ -826,16 +842,17 @@ public void onNext(StreamingGetDataRequest chunk) { .build())); sawHeader = true; } else { - LOG.info("Received {} heartbeats", chunk.getStateRequestCount()); + LOG.info("Received {} getDataHeartbeats", chunk.getStateRequestCount()); errorCollector.checkThat( chunk.getSerializedSize(), Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE)); errorCollector.checkThat(chunk.getRequestIdCount(), Matchers.is(0)); - synchronized (heartbeats) { + synchronized (getDataHeartbeats) { for (ComputationGetDataRequest request : chunk.getStateRequestList()) { errorCollector.checkThat(request.getRequestsCount(), Matchers.is(1)); - heartbeats.putIfAbsent(request.getComputationId(), new ArrayList<>()); - heartbeats + getDataHeartbeats.putIfAbsent( + request.getComputationId(), new ArrayList<>()); + getDataHeartbeats .get(request.getComputationId()) .add(request.getRequestsList().get(0)); } @@ -857,7 +874,6 @@ public void onCompleted() { } }); - Map<String, List<KeyedGetDataRequest>> activeMap = new HashMap<>(); List<String> computation1Keys = new ArrayList<>(); List<String> computation2Keys = new ArrayList<>(); @@ -865,22 +881,141 @@ computation1Keys.add("Computation1Key" + i); computation2Keys.add("Computation2Key" + largeString(i * 20)); } - activeMap.put("Computation1", makeHeartbeatRequest(computation1Keys)); - activeMap.put("Computation2", makeHeartbeatRequest(computation2Keys)); + // We're adding HeartbeatRequests to refreshActiveWork, but expecting to get back + // KeyedGetDataRequests, so make a Map of both types.
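+ // Each HeartbeatRequest sent should come back as a KeyedGetDataRequest carrying the same + // shardingKey, workToken, and cacheToken, one per ComputationGetDataRequest.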
+ Map<String, List<KeyedGetDataRequest>> expectedKeyedGetDataRequests = new HashMap<>(); + expectedKeyedGetDataRequests.put("Computation1", makeGetDataHeartbeatRequest(computation1Keys)); + expectedKeyedGetDataRequests.put("Computation2", makeGetDataHeartbeatRequest(computation2Keys)); + Map<String, List<HeartbeatRequest>> heartbeatsToRefresh = new HashMap<>(); + heartbeatsToRefresh.put("Computation1", makeHeartbeatRequest(computation1Keys)); + heartbeatsToRefresh.put("Computation2", makeHeartbeatRequest(computation2Keys)); GetDataStream stream = client.getDataStream(); - stream.refreshActiveWork(activeMap); + stream.refreshActiveWork(heartbeatsToRefresh); stream.close(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); - while (true) { + boolean receivedAllGetDataHeartbeats = false; + while (!receivedAllGetDataHeartbeats) { Thread.sleep(100); - synchronized (heartbeats) { - if (heartbeats.size() != activeMap.size()) { + synchronized (getDataHeartbeats) { + if (getDataHeartbeats.size() != expectedKeyedGetDataRequests.size()) { continue; } - assertEquals(heartbeats, activeMap); - break; + assertEquals(expectedKeyedGetDataRequests, getDataHeartbeats); + receivedAllGetDataHeartbeats = true; + } + } + } + + @Test + public void testStreamingGetDataHeartbeatsAsHeartbeatRequests() throws Exception { + // Create a client and server different from the ones in setUp() so we can add an experiment + // to the options passed in. + this.server = + InProcessServerBuilder.forName("TestServer") + .fallbackHandlerRegistry(serviceRegistry) + .executor(Executors.newFixedThreadPool(1)) + .build() + .start(); + this.client = + GrpcWindmillServer.newTestInstance( + "TestServer", + Collections.singletonList("streaming_engine_send_new_heartbeat_requests")); + // This server records the heartbeats observed but doesn't respond. + final List<ComputationHeartbeatRequest> receivedHeartbeats = new ArrayList<>(); + + serviceRegistry.addService( + new CloudWindmillServiceV1Alpha1ImplBase() { + @Override + public StreamObserver<StreamingGetDataRequest> getDataStream( + StreamObserver<StreamingGetDataResponse> responseObserver) { + return new StreamObserver<StreamingGetDataRequest>() { + boolean sawHeader = false; + + @Override + public void onNext(StreamingGetDataRequest chunk) { + try { + if (!sawHeader) { + LOG.info("Received header"); + errorCollector.checkThat( + chunk.getHeader(), + Matchers.equalTo( + JobHeader.newBuilder() + .setJobId("job") + .setProjectId("project") + .setWorkerId("worker") + .build())); + sawHeader = true; + } else { + LOG.info( + "Received {} computationHeartbeatRequests", + chunk.getComputationHeartbeatRequestCount()); + errorCollector.checkThat( + chunk.getSerializedSize(), Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE)); + errorCollector.checkThat(chunk.getRequestIdCount(), Matchers.is(0)); + + synchronized (receivedHeartbeats) { + receivedHeartbeats.addAll(chunk.getComputationHeartbeatRequestList()); + } + } + } catch (Exception e) { + errorCollector.addError(e); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + }); + + List<String> computation1Keys = new ArrayList<>(); + List<String> computation2Keys = new ArrayList<>(); + + // When sending heartbeats as HeartbeatRequest protos, all keys for the same computation should + // be batched into the same ComputationHeartbeatRequest. Compare to the KeyedGetDataRequest + // version in the test above, which only sends one key per ComputationGetDataRequest.
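+ // This batched form is only sent because the client above was created with the + // "streaming_engine_send_new_heartbeat_requests" experiment; without it, heartbeats go out as + // KeyedGetDataRequests as in the previous test.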
+ List<ComputationHeartbeatRequest> expectedHeartbeats = new ArrayList<>(); + ComputationHeartbeatRequest.Builder comp1Builder = + ComputationHeartbeatRequest.newBuilder().setComputationId("Computation1"); + ComputationHeartbeatRequest.Builder comp2Builder = + ComputationHeartbeatRequest.newBuilder().setComputationId("Computation2"); + for (int i = 0; i < 100; ++i) { + String computation1Key = "Computation1Key" + i; + computation1Keys.add(computation1Key); + comp1Builder.addHeartbeatRequests( + makeHeartbeatRequest(Collections.singletonList(computation1Key)).get(0)); + String computation2Key = "Computation2Key" + largeString(i * 20); + computation2Keys.add(computation2Key); + comp2Builder.addHeartbeatRequests( + makeHeartbeatRequest(Collections.singletonList(computation2Key)).get(0)); + } + expectedHeartbeats.add(comp1Builder.build()); + expectedHeartbeats.add(comp2Builder.build()); + Map<String, List<HeartbeatRequest>> heartbeatRequestMap = new HashMap<>(); + heartbeatRequestMap.put("Computation1", makeHeartbeatRequest(computation1Keys)); + heartbeatRequestMap.put("Computation2", makeHeartbeatRequest(computation2Keys)); + + GetDataStream stream = client.getDataStream(); + stream.refreshActiveWork(heartbeatRequestMap); + stream.close(); + assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); + + boolean receivedAllHeartbeatRequests = false; + while (!receivedAllHeartbeatRequests) { + Thread.sleep(100); + synchronized (receivedHeartbeats) { + if (receivedHeartbeats.size() != expectedHeartbeats.size()) { + continue; + } + assertEquals(expectedHeartbeats, receivedHeartbeats); + receivedAllHeartbeatRequests = true; } } } @@ -888,7 +1023,7 @@ @Test public void testThrottleSignal() throws Exception { // This server responds with work items until the throttleMessage limit is hit at which point it - // returns RESROUCE_EXHAUSTED errors for throttleTime msecs after which it resumes sending + // returns RESOURCE_EXHAUSTED errors for throttleTime msecs after which it resumes sending work items. final int throttleTime = 2000; final int throttleMessage = 15; diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 6aaeb57001e0..0c824ca301b3 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -477,9 +477,10 @@ // GetData message KeyedGetDataRequest { - required bytes key = 1; + optional bytes key = 1; required fixed64 work_token = 2; optional fixed64 sharding_key = 6; + optional fixed64 cache_token = 11; repeated TagValue values_to_fetch = 3; repeated TagValuePrefixRequest tag_value_prefixes_to_fetch = 10; repeated TagBag bags_to_fetch = 8; @@ -507,6 +508,8 @@ message GetDataRequest { // Assigned worker id for the instance. optional string worker_id = 6; + // SE only. Will only be set by a compatible client. + repeated ComputationHeartbeatRequest computation_heartbeat_request = 7; // DEPRECATED repeated GlobalDataId global_data_to_fetch = 2; } @@ -536,6 +539,44 @@ message ComputationGetDataResponse { message GetDataResponse { repeated ComputationGetDataResponse data = 1; repeated GlobalData global_data = 2; + // Only set if ComputationHeartbeatRequest was sent; prior versions do not + // expect a response for heartbeats. SE only.
+ repeated ComputationHeartbeatResponse computation_heartbeat_response = 3; +} + +// Heartbeats +// +// Heartbeats are sent over the GetData stream in Streaming Engine and +// indicate work items that the user worker has previously received from +// GetWork but not yet committed with CommitWork. +// Note that implicit heartbeats not expecting a response may be sent as +// special KeyedGetDataRequests; see the function KeyedGetDataRequestIsHeartbeat. +// SE only. +message HeartbeatRequest { + optional fixed64 sharding_key = 1; + optional fixed64 work_token = 2; + optional fixed64 cache_token = 3; + repeated LatencyAttribution latency_attribution = 4; +} + +// Responses for heartbeat requests, indicating which work is no longer valid +// on the windmill worker and may be dropped/cancelled in the client. +// SE only. +message HeartbeatResponse { + optional fixed64 sharding_key = 1; + optional fixed64 work_token = 2; + optional fixed64 cache_token = 3; + optional bool failed = 4; +} + +message ComputationHeartbeatRequest { + optional string computation_id = 1; + repeated HeartbeatRequest heartbeat_requests = 2; +} + +message ComputationHeartbeatResponse { + optional string computation_id = 1; + repeated HeartbeatResponse heartbeat_responses = 2; } // CommitWork @@ -772,6 +813,8 @@ message StreamingGetDataRequest { repeated fixed64 request_id = 1; repeated GlobalDataRequest global_data_request = 3; repeated ComputationGetDataRequest state_request = 4; + // Will only be set by a compatible client. + repeated ComputationHeartbeatRequest computation_heartbeat_request = 5; } message StreamingGetDataResponse { @@ -784,6 +827,12 @@ repeated bytes serialized_response = 2; // Remaining bytes field applies only to the last serialized_response optional int64 remaining_bytes_for_response = 3; + + // Only set if ComputationHeartbeatRequest was sent; prior versions do not + // expect a response for heartbeats. + repeated ComputationHeartbeatResponse computation_heartbeat_response = 5; + + reserved 4; } message StreamingCommitWorkRequest {