check for cachetoken representing a retry before activating and completing work #29082

Merged · 12 commits · Feb 13, 2024
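In brief, going by the classes each hunk touches (apparently StreamingDataflowWorker.java, ActiveWorkState.java, and ComputationState.java): the change replaces the FailedTokens AutoValue in ActiveWorkState with the shared WorkId value type (a work token plus a cache token), switches failed-work bookkeeping from Map<Long, List<FailedTokens>> to a Guava Multimap<Long, WorkId>, and rewrites activateWorkForKey to walk the work queue with an explicit Iterator so queued work made stale by a newer work token for the same cache token is removed in place before the new work is queued.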
@@ -87,7 +87,6 @@
 import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
-import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.streaming.Commit;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
@@ -105,6 +104,7 @@
 import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter;
 import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
@@ -1971,20 +1971,19 @@ private void sendWorkerUpdatesToDataflowService(
     }
   }

-  public void handleHeartbeatResponses(List<Windmill.ComputationHeartbeatResponse> responses) {
-    for (Windmill.ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
+  public void handleHeartbeatResponses(List<ComputationHeartbeatResponse> responses) {
+    for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
       // Maps sharding key to (work token, cache token) for work that should be marked failed.
-      Map<Long, List<FailedTokens>> failedWork = new HashMap<>();
+      Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
       for (Windmill.HeartbeatResponse heartbeatResponse :
           computationHeartbeatResponse.getHeartbeatResponsesList()) {
         if (heartbeatResponse.getFailed()) {
-          failedWork
-              .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new ArrayList<>())
-              .add(
-                  FailedTokens.newBuilder()
-                      .setWorkToken(heartbeatResponse.getWorkToken())
-                      .setCacheToken(heartbeatResponse.getCacheToken())
-                      .build());
+          failedWork.put(
+              heartbeatResponse.getShardingKey(),
+              WorkId.builder()
+                  .setWorkToken(heartbeatResponse.getWorkToken())
+                  .setCacheToken(heartbeatResponse.getCacheToken())
+                  .build());
         }
       }
       ComputationState state = computationMap.get(computationHeartbeatResponse.getComputationId());
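The new code above builds WorkId values via WorkId.builder(), but WorkId's definition is not part of this diff. As a rough stand-in: the accessor names workToken()/cacheToken() and the builder setters are taken from calls visible in this diff, while everything else (constructor shape, equals/hashCode, the guess that the real class is an @AutoValue like the FailedTokens it replaces) is assumption:

    // Hypothetical stand-in for WorkId; the real class is not shown here. It
    // pairs the token that identifies a specific work attempt (workToken) with
    // the cache token that stays stable across retries of the same work.
    public final class WorkId {
      private final long workToken;
      private final long cacheToken;

      private WorkId(long workToken, long cacheToken) {
        this.workToken = workToken;
        this.cacheToken = cacheToken;
      }

      public static Builder builder() {
        return new Builder();
      }

      public long workToken() {
        return workToken;
      }

      public long cacheToken() {
        return cacheToken;
      }

      @Override
      public boolean equals(Object o) {
        if (!(o instanceof WorkId)) {
          return false;
        }
        WorkId that = (WorkId) o;
        return workToken == that.workToken && cacheToken == that.cacheToken;
      }

      @Override
      public int hashCode() {
        return java.util.Objects.hash(workToken, cacheToken);
      }

      public static final class Builder {
        private long workToken;
        private long cacheToken;

        public Builder setWorkToken(long value) {
          this.workToken = value;
          return this;
        }

        public Builder setCacheToken(long value) {
          this.cacheToken = value;
          return this;
        }

        public WorkId build() {
          return new WorkId(workToken, cacheToken);
        }
      }
    }

Value equality over both tokens is what makes the duplicate check queuedWork.id().equals(work.id()) in activateWorkForKey below work.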
@@ -19,18 +19,16 @@

 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;

-import com.google.auto.value.AutoValue;
 import java.io.PrintWriter;
 import java.util.ArrayDeque;
+import java.util.Collection;
 import java.util.Deque;
 import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
+import java.util.Iterator;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.Queue;
-import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.stream.Stream;
@@ -48,6 +46,7 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
@@ -107,6 +106,29 @@ private static String elapsedString(Instant start, Instant end) {
     return activeFor.toString().substring(2);
   }

+  private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
+      Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
+      Instant refreshDeadline,
+      DataflowExecutionStateSampler sampler) {
+    ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey();
+    Deque<Work> workQueue = shardedKeyAndWorkQueue.getValue();
+
+    return workQueue.stream()
+        .filter(work -> work.getStartTime().isBefore(refreshDeadline))
+        // Don't send heartbeats for queued work we already know is failed.
+        .filter(work -> !work.isFailed())
+        .map(
+            work ->
+                Windmill.HeartbeatRequest.newBuilder()
+                    .setShardingKey(shardedKey.shardingKey())
+                    .setWorkToken(work.getWorkItem().getWorkToken())
+                    .setCacheToken(work.getWorkItem().getCacheToken())
+                    .addAllLatencyAttribution(
+                        work.getLatencyAttributions(
+                            /* isHeartbeat= */ true, work.getLatencyTrackingId(), sampler))
+                    .build());
+  }
+
   /**
    * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 4 {@link
    * ActivateWorkResult}
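For context, this relocated helper is consumed by getKeyHeartbeats, whose body is collapsed later in this diff; only its signature and the terminal collect(toImmutableList()) are visible. A sketch of how the pieces presumably fit together (the parameter names are assumptions):

    // Sketch: each active (sharded key -> work queue) entry is flattened into
    // heartbeat requests for work that started before the refresh deadline and
    // is not already known to be failed.
    synchronized ImmutableList<HeartbeatRequest> getKeyHeartbeats(
        Instant refreshDeadline, DataflowExecutionStateSampler sampler) {
      return activeWork.entrySet().stream()
          .flatMap(entry -> toHeartbeatRequestStream(entry, refreshDeadline, sampler))
          .collect(toImmutableList());
    }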
@@ -136,79 +158,57 @@ synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work w
     }

     // Check to see if we have this work token queued.
-    // This set is for adding remove-able WorkItems if they exist in the workQueue. We add them to
-    // this set since a ConcurrentModificationException will be thrown if we modify the workQueue
-    // and then resume iteration.
-    Set<WorkId> queuedWorkToRemove = new HashSet<>();
-    for (Work queuedWork : workQueue) {
+    Iterator<Work> workIterator = workQueue.iterator();
+    while (workIterator.hasNext()) {
+      Work queuedWork = workIterator.next();
       if (queuedWork.id().equals(work.id())) {
         return ActivateWorkResult.DUPLICATE;
       }
       if (queuedWork.id().cacheToken() == work.id().cacheToken()) {
         if (work.id().workToken() > queuedWork.id().workToken()) {
-          queuedWorkToRemove.add(queuedWork.id());
+          // Check to see if the queuedWork is active. We only want to remove it if it is NOT
+          // currently active.
+          if (!queuedWork.equals(workQueue.peek())) {
+            workIterator.remove();
+            decrementActiveWorkBudget(queuedWork);
+          }
+          // Continue here to possibly remove more non-active stale work that is queued.
        } else {
          return ActivateWorkResult.STALE;
        }
      }
    }

-    workQueue.removeIf(
-        queuedWork ->
-            queuedWorkToRemove.contains(queuedWork.id()) && !queuedWork.equals(workQueue.peek()));
-
    // Queue the work for later processing.
    workQueue.addLast(work);
    incrementActiveWorkBudget(work);
    return ActivateWorkResult.QUEUED;
  }
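The removed comment explains the old workaround: structurally modifying the Deque while a for-each loop is iterating it trips the collection's fail-fast check, so stale ids were first collected into a Set and removed afterwards with removeIf. Removing through the Iterator itself sidesteps that, and also lets the code call decrementActiveWorkBudget once per removed item, which the removeIf path shown above did not do. A minimal, self-contained illustration of the difference (toy values, not Beam code):

    import java.util.ArrayDeque;
    import java.util.ConcurrentModificationException;
    import java.util.Deque;
    import java.util.Iterator;
    import java.util.List;

    public class IteratorRemoveDemo {
      public static void main(String[] args) {
        // Removing through the deque while a for-each loop is in flight is
        // detected on a best-effort basis and generally fails fast.
        Deque<Integer> workQueue = new ArrayDeque<>(List.of(1, 2, 3));
        try {
          for (Integer token : workQueue) {
            if (token == 2) {
              workQueue.remove(token); // structural change behind the iterator's back
            }
          }
          System.out.println("no exception (fail-fast detection is only best-effort)");
        } catch (ConcurrentModificationException e) {
          System.out.println("for-each + Deque.remove -> ConcurrentModificationException");
        }

        // Removing through the iterator itself is supported by ArrayDeque's
        // iterator; this is the pattern the rewritten activateWorkForKey uses.
        Deque<Integer> workQueue2 = new ArrayDeque<>(List.of(1, 2, 3));
        Iterator<Integer> it = workQueue2.iterator();
        while (it.hasNext()) {
          if (it.next() == 2) {
            it.remove();
          }
        }
        System.out.println(workQueue2); // prints [1, 3]
      }
    }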

-  @AutoValue
-  public abstract static class FailedTokens {
-    public static Builder newBuilder() {
-      return new AutoValue_ActiveWorkState_FailedTokens.Builder();
-    }
-
-    public abstract long workToken();
-
-    public abstract long cacheToken();
-
-    @AutoValue.Builder
-    public abstract static class Builder {
-      public abstract Builder setWorkToken(long value);
-
-      public abstract Builder setCacheToken(long value);
-
-      public abstract FailedTokens build();
-    }
-  }

   /**
    * Fails any active work matching an element of the input Map.
    *
    * @param failedWork a map from sharding_key to tokens for the corresponding work.
    */
-  synchronized void failWorkForKey(Map<Long, List<FailedTokens>> failedWork) {
+  synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
     // Note we can't construct a ShardedKey and look it up in activeWork directly since
     // HeartbeatResponse doesn't include the user key.
     for (Entry<ShardedKey, Deque<Work>> entry : activeWork.entrySet()) {
-      List<FailedTokens> failedTokens = failedWork.get(entry.getKey().shardingKey());
-      if (failedTokens == null) continue;
-      for (FailedTokens failedToken : failedTokens) {
+      Collection<WorkId> failedWorkIds = failedWork.get(entry.getKey().shardingKey());
+      for (WorkId failedWorkId : failedWorkIds) {
         for (Work queuedWork : entry.getValue()) {
           WorkItem workItem = queuedWork.getWorkItem();
-          if (workItem.getWorkToken() == failedToken.workToken()
-              && workItem.getCacheToken() == failedToken.cacheToken()) {
+          if (workItem.getWorkToken() == failedWorkId.workToken()
+              && workItem.getCacheToken() == failedWorkId.cacheToken()) {
             LOG.debug(
                 "Failing work "
                     + computationStateCache.getComputation()
                     + " "
                     + entry.getKey().shardingKey()
                     + " "
-                    + failedToken.workToken()
+                    + failedWorkId.workToken()
                     + " "
-                    + failedToken.cacheToken()
+                    + failedWorkId.cacheToken()
                     + ". The work will be retried and is not lost.");
             queuedWork.setFailed();
             break;
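A detail worth noting in the rewritten loop: the old Map.get could return null for a sharding key with no failures, hence the removed null check. Guava's Multimap.get instead returns an empty collection for absent keys, so the new code can iterate unconditionally. A small demonstration with stock Guava (Beam itself uses the vendored copy; the class and key/value choices here are illustrative only):

    import com.google.common.collect.ArrayListMultimap;
    import com.google.common.collect.Multimap;

    public class MultimapGetDemo {
      public static void main(String[] args) {
        Multimap<Long, String> failedWork = ArrayListMultimap.create();
        failedWork.put(7L, "workToken=1 cacheToken=9");
        failedWork.put(7L, "workToken=2 cacheToken=9");

        // Present key: all values for the key, with no computeIfAbsent bookkeeping.
        System.out.println(failedWork.get(7L));  // [workToken=1 cacheToken=9, workToken=2 cacheToken=9]

        // Absent key: an empty collection rather than null, so callers such as
        // failWorkForKey can loop over the result without a null check.
        System.out.println(failedWork.get(42L)); // []
      }
    }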
@@ -328,29 +328,6 @@ synchronized ImmutableList<HeartbeatRequest> getKeyHeartbeats(
         .collect(toImmutableList());
   }

-  private static Stream<HeartbeatRequest> toHeartbeatRequestStream(
-      Entry<ShardedKey, Deque<Work>> shardedKeyAndWorkQueue,
-      Instant refreshDeadline,
-      DataflowExecutionStateSampler sampler) {
-    ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey();
-    Deque<Work> workQueue = shardedKeyAndWorkQueue.getValue();
-
-    return workQueue.stream()
-        .filter(work -> work.getStartTime().isBefore(refreshDeadline))
-        // Don't send heartbeats for queued work we already know is failed.
-        .filter(work -> !work.isFailed())
-        .map(
-            work ->
-                Windmill.HeartbeatRequest.newBuilder()
-                    .setShardingKey(shardedKey.shardingKey())
-                    .setWorkToken(work.getWorkItem().getWorkToken())
-                    .setCacheToken(work.getWorkItem().getCacheToken())
-                    .addAllLatencyAttribution(
-                        work.getLatencyAttributions(
-                            /* isHeartbeat= */ true, work.getLatencyTrackingId(), sampler))
-                    .build());
-  }
-
   /**
    * Returns the current aggregate {@link GetWorkBudget} that is active on the user worker. Active
    * means that the work is received from Windmill, being processed or queued to be processed in
@@ -19,18 +19,17 @@

 import com.google.api.services.dataflow.model.MapTask;
 import java.io.PrintWriter;
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
-import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.FailedTokens;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Instant;

 /**
@@ -103,7 +102,7 @@ public boolean activateWork(ShardedKey shardedKey, Work work) {
     }
   }

-  public void failWork(Map<Long, List<FailedTokens>> failedWork) {
+  public void failWork(Multimap<Long, WorkId> failedWork) {
     activeWorkState.failWorkForKey(failedWork);
   }
