From f93a67a6ae888cad793920d8dc7ebb266169d661 Mon Sep 17 00:00:00 2001 From: martin trieu Date: Fri, 31 May 2024 13:41:55 -0700 Subject: [PATCH] remove processing/scheduling logic from StreamingDataflowWorker (#31317) * use work processing context in work processing * break out work scheduling/processing from StreamingDataflowWorker --- .../runners/dataflow/worker/PubsubReader.java | 2 +- .../runners/dataflow/worker/ReaderCache.java | 4 +- .../worker/StreamingDataflowWorker.java | 614 +++--------------- .../worker/StreamingModeExecutionContext.java | 176 ++--- .../worker/UngroupedWindmillReader.java | 2 +- .../worker/WindmillTimerInternals.java | 39 +- .../worker/WindowingWindmillReader.java | 2 +- .../dataflow/worker/WorkerCustomSources.java | 2 +- .../worker/streaming/ActiveWorkState.java | 77 +-- .../worker/streaming/ComputationState.java | 52 +- .../streaming/ComputationStateCache.java | 1 + .../streaming/ComputationWorkExecutor.java | 118 ++++ .../worker/streaming/ExecutableWork.java | 48 ++ .../worker/streaming/ExecutionState.java | 54 -- .../dataflow/worker/streaming/Watermarks.java | 69 ++ .../dataflow/worker/streaming/Work.java | 289 ++++++--- .../dataflow/worker/streaming/WorkId.java | 8 + .../sideinput/SideInputStateFetcher.java | 2 + .../util/common/worker/MapTaskExecutor.java | 2 +- .../client/grpc/GrpcDirectGetWorkStream.java | 53 +- .../client/grpc/GrpcDispatcherClient.java | 32 +- .../grpc/GrpcWindmillStreamFactory.java | 11 +- .../client/grpc/StreamingEngineClient.java | 161 +++-- .../client/grpc/WindmillStreamSender.java | 37 +- .../windmill/state/WindmillStateReader.java | 32 +- .../windmill/work/ProcessWorkItemClient.java | 52 -- ...mProcessor.java => WorkItemScheduler.java} | 29 +- .../ComputationWorkExecutorFactory.java | 291 +++++++++ .../processing/StreamingCommitFinalizer.java | 85 +++ .../processing/StreamingWorkScheduler.java | 428 ++++++++++++ .../failures/WorkFailureProcessor.java | 16 +- .../dataflow/worker/PubsubReaderTest.java | 2 +- .../worker/StreamingDataflowWorkerTest.java | 172 ++--- .../StreamingModeExecutionContextTest.java | 40 +- .../worker/WorkerCustomSourcesTest.java | 83 ++- .../worker/streaming/ActiveWorkStateTest.java | 213 +++--- .../streaming/ComputationStateCacheTest.java | 45 +- .../StreamingApplianceWorkCommitterTest.java | 19 +- .../StreamingEngineWorkCommitterTest.java | 28 +- .../grpc/StreamingEngineClientTest.java | 37 +- .../client/grpc/WindmillStreamSenderTest.java | 42 +- .../EvenGetWorkBudgetDistributorTest.java | 14 +- .../failures/WorkFailureProcessorTest.java | 75 ++- .../DispatchedActiveWorkRefresherTest.java | 54 +- 44 files changed, 2212 insertions(+), 1400 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java delete mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Watermarks.java delete mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java rename 
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/{WorkItemProcessor.java => WorkItemScheduler.java} (61%) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizer.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java index be0bccec0265..5d2c491ff72b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java @@ -104,7 +104,7 @@ public NativeReader create( @Override public NativeReaderIterator> iterator() throws IOException { - return new PubsubReaderIterator(context.getWork()); + return new PubsubReaderIterator(context.getWorkItem()); } class PubsubReaderIterator extends WindmillReaderIteratorBase { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ReaderCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ReaderCache.java index 01010863f1ee..fa18432bb358 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ReaderCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ReaderCache.java @@ -21,6 +21,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; @@ -39,7 +40,8 @@ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -class ReaderCache { +@Internal +public class ReaderCache { private static final Logger LOG = LoggerFactory.getLogger(ReaderCache.class); private final Executor invalidationExecutor; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 52b8adb5615a..b809c2ebe58b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker; -import static org.apache.beam.runners.dataflow.DataflowRunner.hasExperiment; import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel; import 
static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; @@ -25,7 +24,6 @@ import com.google.api.services.dataflow.model.MapTask; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -43,30 +41,15 @@ import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.core.metrics.MetricsLogger; import org.apache.beam.runners.dataflow.DataflowRunner; -import org.apache.beam.runners.dataflow.internal.CustomSources; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; -import org.apache.beam.runners.dataflow.util.CloudObject; -import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; -import org.apache.beam.runners.dataflow.worker.counters.NameContext; -import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge; -import org.apache.beam.runners.dataflow.worker.graph.MapTaskToNetworkFunction; -import org.apache.beam.runners.dataflow.worker.graph.Networks; -import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode; -import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node; -import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode; -import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; -import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler; import org.apache.beam.runners.dataflow.worker.status.DebugCapture; import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ComputationStateCache; -import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState; -import org.apache.beam.runners.dataflow.worker.streaming.KeyCommitTooLargeException; -import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; -import org.apache.beam.runners.dataflow.worker.streaming.Work.State; import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingApplianceComputationConfigFetcher; @@ -75,22 +58,16 @@ import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingWorkerStatusPages; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingWorkerStatusReporter; -import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; -import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; -import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; -import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; -import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; @@ -104,15 +81,13 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; -import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.StreamingApplianceFailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.StreamingEngineFailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ActiveWorkRefresher; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ActiveWorkRefreshers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.sdk.fn.JvmInitializers; @@ -120,16 +95,11 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsEnvironment; -import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.sdk.util.construction.CoderTranslation; -import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.*; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.graph.MutableNetwork; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; @@ -163,23 +133,12 @@ public class StreamingDataflowWorker { * readers to stop producing more. This can be disabled with 'disable_limiting_bundle_sink_bytes' * experiment. */ - static final int MAX_SINK_BYTES = 10_000_000; + public static final int MAX_SINK_BYTES = 10_000_000; private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class); /** The idGenerator to generate unique id globally. */ private static final IdGenerator ID_GENERATOR = IdGenerators.decrementingLongs(); - /** - * Function which converts map tasks to their network representation for execution. - * - *
- *   <ul>
- *     <li>Translate the map task to a network representation.
- *     <li>Remove flatten instructions by rewiring edges.
- *   </ul>
- */ - private static final Function> MAP_TASK_TO_BASE_NETWORK_FN = - new MapTaskToNetworkFunction(ID_GENERATOR); - private static final int DEFAULT_STATUS_PORT = 8081; // Maximum size of the result of a GetWork request. private static final long MAX_GET_WORK_FETCH_BYTES = 64L << 20; // 64m @@ -192,53 +151,28 @@ public class StreamingDataflowWorker { private final StreamingWorkerStatusPages statusPages; private final ComputationConfig.Fetcher configFetcher; private final ComputationStateCache computationStateCache; - // Maps from computation ids to per-computation state. - // Cache of tokens to commit callbacks. - // Using Cache with time eviction policy helps us to prevent memory leak when callback ids are - // discarded by Dataflow service and calling commitCallback is best-effort. - private final Cache commitCallbacks = - CacheBuilder.newBuilder().expireAfterWrite(5L, TimeUnit.MINUTES).build(); - private final BoundedQueueExecutor workUnitExecutor; private final WindmillServerStub windmillServer; private final Thread dispatchThread; private final AtomicBoolean running = new AtomicBoolean(); - private final SideInputStateFetcher sideInputStateFetcher; private final DataflowWorkerHarnessOptions options; - private final boolean windmillServiceEnabled; private final long clientId; private final MetricTrackingWindmillServerStub metricTrackingWindmillServer; - - // Map from stage name to StageInfo containing metrics container registry and per stage counters. - private final ConcurrentMap stageInfoMap; - private final MemoryMonitor memoryMonitor; private final Thread memoryMonitorThread; - // Limit on bytes sinked (committed) in a work item. - private final long maxSinkBytes; // = MAX_SINK_BYTES unless disabled in options. private final ReaderCache readerCache; - private final Function> mapTaskToNetwork; - private final ReaderRegistry readerRegistry = ReaderRegistry.defaultRegistry(); - private final SinkRegistry sinkRegistry = SinkRegistry.defaultRegistry(); - private final Supplier clock; - private final DataflowMapTaskExecutorFactory mapTaskExecutorFactory; - private final HotKeyLogger hotKeyLogger; - // Possibly overridden by streaming engine config. 
- private final AtomicInteger maxWorkItemCommitBytes; private final DataflowExecutionStateSampler sampler = DataflowExecutionStateSampler.instance(); private final ActiveWorkRefresher activeWorkRefresher; private final WorkCommitter workCommitter; private final StreamingWorkerStatusReporter workerStatusReporter; - private final FailureTracker failureTracker; - private final WorkFailureProcessor workFailureProcessor; private final StreamingCounters streamingCounters; + private final StreamingWorkScheduler streamingWorkScheduler; - StreamingDataflowWorker( + private StreamingDataflowWorker( WindmillServerStub windmillServer, long clientId, ComputationConfig.Fetcher configFetcher, ComputationStateCache computationStateCache, - ConcurrentMap stageInfoMap, WindmillStateCache windmillStateCache, BoundedQueueExecutor workUnitExecutor, DataflowMapTaskExecutorFactory mapTaskExecutorFactory, @@ -252,21 +186,19 @@ public class StreamingDataflowWorker { MemoryMonitor memoryMonitor, AtomicInteger maxWorkItemCommitBytes, GrpcWindmillStreamFactory windmillStreamFactory, - Function executorSupplier) { + Function executorSupplier, + ConcurrentMap stageInfoMap) { this.configFetcher = configFetcher; this.computationStateCache = computationStateCache; - this.stageInfoMap = stageInfoMap; this.stateCache = windmillStateCache; this.readerCache = new ReaderCache( Duration.standardSeconds(options.getReaderCacheTimeoutSec()), Executors.newCachedThreadPool()); - this.mapTaskExecutorFactory = mapTaskExecutorFactory; this.options = options; - this.hotKeyLogger = hotKeyLogger; - this.clock = clock; - this.maxWorkItemCommitBytes = maxWorkItemCommitBytes; - this.windmillServiceEnabled = options.isEnableStreamingEngine(); + + boolean windmillServiceEnabled = options.isEnableStreamingEngine(); + int numCommitThreads = 1; if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() > 0) { numCommitThreads = options.getWindmillServiceCommitThreads(); @@ -285,11 +217,6 @@ public class StreamingDataflowWorker { this.workUnitExecutor = workUnitExecutor; - maxSinkBytes = - hasExperiment(options, "disable_limiting_bundle_sink_bytes") - ? Long.MAX_VALUE - : MAX_SINK_BYTES; - memoryMonitorThread = new Thread(memoryMonitor); memoryMonitorThread.setPriority(Thread.MIN_PRIORITY); memoryMonitorThread.setName("MemoryMonitor"); @@ -317,14 +244,9 @@ public class StreamingDataflowWorker { .setNumGetDataStreams(options.getWindmillGetDataStreamCount()) .build(); - this.sideInputStateFetcher = - new SideInputStateFetcher(metricTrackingWindmillServer::getSideInputData, options); - // Register standard file systems. FileSystems.setDefaultPipelineOptions(options); - this.mapTaskToNetwork = MAP_TASK_TO_BASE_NETWORK_FN; - int stuckCommitDurationMillis = windmillServiceEnabled && options.getStuckCommitDurationMillis() > 0 ? 
options.getStuckCommitDurationMillis() @@ -364,11 +286,27 @@ public class StreamingDataflowWorker { : statusPagesBuilder.build(); this.workerStatusReporter = workerStatusReporter; - this.failureTracker = failureTracker; - this.workFailureProcessor = workFailureProcessor; this.streamingCounters = streamingCounters; this.memoryMonitor = memoryMonitor; + this.streamingWorkScheduler = + StreamingWorkScheduler.create( + options, + clock, + readerCache, + mapTaskExecutorFactory, + workUnitExecutor, + stateCache::forComputation, + metricTrackingWindmillServer::getSideInputData, + failureTracker, + workFailureProcessor, + streamingCounters, + hotKeyLogger, + sampler, + maxWorkItemCommitBytes, + ID_GENERATOR, + stageInfoMap); + LOG.debug("windmillServiceEnabled: {}", windmillServiceEnabled); LOG.debug("WindmillServiceEndpoint: {}", options.getWindmillServiceEndpoint()); LOG.debug("WindmillServicePort: {}", options.getWindmillServicePort()); @@ -451,12 +389,12 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o workExecutor, options.getWindmillHarnessUpdateReportingPeriod().getMillis(), options.getPerWorkerMetricsUpdateReportingPeriodMillis()); + return new StreamingDataflowWorker( windmillServer, clientId, configFetcherAndWindmillClient.getLeft(), computationStateCache, - stageInfo, WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()), workExecutor, IntrinsicMapTaskExecutorFactory.defaultFactory(), @@ -470,7 +408,8 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o memoryMonitor, maxWorkItemCommitBytes, windmillStreamFactory, - executorSupplier); + executorSupplier, + stageInfo); } private static Pair> @@ -582,7 +521,6 @@ static StreamingDataflowWorker forTesting( 1L, configFetcher, computationStateCache, - stageInfo, stateCache, workExecutor, mapTaskExecutorFactory, @@ -596,7 +534,8 @@ static StreamingDataflowWorker forTesting( memoryMonitor, maxWorkItemCommitBytes, createWindmillStreamFactory(options, 1), - executorSupplier); + executorSupplier, + stageInfo); } private static void onPipelineConfig( @@ -633,11 +572,6 @@ private static GrpcWindmillStreamFactory createWindmillStreamFactory( .build(); } - @VisibleForTesting - final void reportPeriodicWorkerUpdatesForTest() { - workerStatusReporter.reportPeriodicWorkerUpdates(); - } - private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarnessOptions options) { return new BoundedQueueExecutor( chooseMaxThreads(options), @@ -724,27 +658,6 @@ private static void sleep(int millis) { Uninterruptibles.sleepUninterruptibly(millis, TimeUnit.MILLISECONDS); } - /** Sets the stage name and workId of the current Thread for logging. 
*/ - private static void setUpWorkLoggingContext(String workId, String computationId) { - DataflowWorkerLoggingMDC.setWorkId(workId); - DataflowWorkerLoggingMDC.setStageName(computationId); - } - - private int chooseMaximumNumberOfThreads() { - if (options.getNumberOfWorkerHarnessThreads() != 0) { - return options.getNumberOfWorkerHarnessThreads(); - } - return MAX_PROCESSING_THREADS; - } - - private int chooseMaximumBundlesOutstanding() { - int maxBundles = options.getMaxBundlesFromWindmillOutstanding(); - if (maxBundles > 0) { - return maxBundles; - } - return chooseMaximumNumberOfThreads() + 100; - } - private static int chooseMaxThreads(DataflowWorkerHarnessOptions options) { if (options.getNumberOfWorkerHarnessThreads() != 0) { return options.getNumberOfWorkerHarnessThreads(); @@ -772,6 +685,26 @@ private static void enableBigQueryMetrics() { BigQuerySinkMetrics.setSupportStreamingInsertsMetrics(true); } + @VisibleForTesting + final void reportPeriodicWorkerUpdatesForTest() { + workerStatusReporter.reportPeriodicWorkerUpdates(); + } + + private int chooseMaximumNumberOfThreads() { + if (options.getNumberOfWorkerHarnessThreads() != 0) { + return options.getNumberOfWorkerHarnessThreads(); + } + return MAX_PROCESSING_THREADS; + } + + private int chooseMaximumBundlesOutstanding() { + int maxBundles = options.getMaxBundlesFromWindmillOutstanding(); + if (maxBundles > 0) { + return maxBundles; + } + return chooseMaximumNumberOfThreads() + 100; + } + @VisibleForTesting public boolean workExecutorIsEmpty() { return workUnitExecutor.executorQueueIsEmpty(); @@ -807,7 +740,8 @@ private void startStatusPages() { statusPages.start(options); } - public void stop() { + @VisibleForTesting + void stop() { try { configFetcher.stop(); @@ -858,16 +792,20 @@ private void dispatchLoop() { final ComputationState computationState = maybeComputationState.get(); final Instant inputDataWatermark = WindmillTimeUtils.windmillToHarnessWatermark(computationWork.getInputDataWatermark()); - Preconditions.checkNotNull(inputDataWatermark); - final @Nullable Instant synchronizedProcessingTime = - WindmillTimeUtils.windmillToHarnessWatermark( - computationWork.getDependentRealtimeInputWatermark()); + Watermarks.Builder watermarks = + Watermarks.builder() + .setInputDataWatermark(Preconditions.checkNotNull(inputDataWatermark)) + .setSynchronizedProcessingTime( + WindmillTimeUtils.windmillToHarnessWatermark( + computationWork.getDependentRealtimeInputWatermark())); + for (final Windmill.WorkItem workItem : computationWork.getWorkList()) { - scheduleWorkItem( + streamingWorkScheduler.scheduleWork( computationState, - inputDataWatermark, - synchronizedProcessingTime, workItem, + watermarks.setOutputDataWatermark(workItem.getOutputDataWatermark()).build(), + Work.createProcessingContext( + computationId, metricTrackingWindmillServer::getStateData, workCommitter::commit), /* getWorkStreamLatencies= */ Collections.emptyList()); } } @@ -893,11 +831,18 @@ void streamingDispatchLoop() { .ifPresent( computationState -> { memoryMonitor.waitForResources("GetWork"); - scheduleWorkItem( + streamingWorkScheduler.scheduleWork( computationState, - inputDataWatermark, - synchronizedProcessingTime, workItem, + Watermarks.builder() + .setInputDataWatermark(inputDataWatermark) + .setSynchronizedProcessingTime(synchronizedProcessingTime) + .setOutputDataWatermark(workItem.getOutputDataWatermark()) + .build(), + Work.createProcessingContext( + computationState.getComputationId(), + metricTrackingWindmillServer::getStateData, + 
workCommitter::commit), getWorkStreamLatencies); })); try { @@ -913,413 +858,6 @@ void streamingDispatchLoop() { } } - private void scheduleWorkItem( - final ComputationState computationState, - final Instant inputDataWatermark, - final Instant synchronizedProcessingTime, - final Windmill.WorkItem workItem, - final Collection getWorkStreamLatencies) { - Preconditions.checkNotNull(inputDataWatermark); - // May be null if output watermark not yet known. - final @Nullable Instant outputDataWatermark = - WindmillTimeUtils.windmillToHarnessWatermark(workItem.getOutputDataWatermark()); - Preconditions.checkState( - outputDataWatermark == null || !outputDataWatermark.isAfter(inputDataWatermark)); - Work scheduledWork = - Work.create( - workItem, - clock, - getWorkStreamLatencies, - work -> - process( - computationState, - inputDataWatermark, - outputDataWatermark, - synchronizedProcessingTime, - work)); - computationState.activateWork( - ShardedKey.create(workItem.getKey(), workItem.getShardingKey()), scheduledWork); - } - - /** - * Extracts the userland key coder, if any, from the coder used in the initial read step of a - * stage. This encodes many assumptions about how the streaming execution context works. - */ - private @Nullable Coder extractKeyCoder(Coder readCoder) { - if (!(readCoder instanceof WindowedValueCoder)) { - throw new RuntimeException( - String.format( - "Expected coder for streaming read to be %s, but received %s", - WindowedValueCoder.class.getSimpleName(), readCoder)); - } - - // Note that TimerOrElementCoder is a backwards-compatibility class - // that is really a FakeKeyedWorkItemCoder - Coder valueCoder = ((WindowedValueCoder) readCoder).getValueCoder(); - - if (valueCoder instanceof KvCoder) { - return ((KvCoder) valueCoder).getKeyCoder(); - } - if (!(valueCoder instanceof WindmillKeyedWorkItem.FakeKeyedWorkItemCoder)) { - return null; - } - - return ((WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) valueCoder).getKeyCoder(); - } - - private void callFinalizeCallbacks(Windmill.WorkItem work) { - for (Long callbackId : work.getSourceState().getFinalizeIdsList()) { - final Runnable callback = commitCallbacks.getIfPresent(callbackId); - // NOTE: It is possible the same callback id may be removed twice if - // windmill restarts. - // TODO: It is also possible for an earlier finalized id to be lost. - // We should automatically discard all older callbacks for the same computation and key. 
- if (callback != null) { - commitCallbacks.invalidate(callbackId); - workUnitExecutor.forceExecute( - () -> { - try { - callback.run(); - } catch (Throwable t) { - LOG.error("Source checkpoint finalization failed:", t); - } - }, - 0); - } - } - } - - private Windmill.WorkItemCommitRequest.Builder initializeOutputBuilder( - final ByteString key, final Windmill.WorkItem workItem) { - return Windmill.WorkItemCommitRequest.newBuilder() - .setKey(key) - .setShardingKey(workItem.getShardingKey()) - .setWorkToken(workItem.getWorkToken()) - .setCacheToken(workItem.getCacheToken()); - } - - private void process( - final ComputationState computationState, - final Instant inputDataWatermark, - final @Nullable Instant outputDataWatermark, - final @Nullable Instant synchronizedProcessingTime, - final Work work) { - final Windmill.WorkItem workItem = work.getWorkItem(); - final String computationId = computationState.getComputationId(); - final ByteString key = workItem.getKey(); - work.setState(State.PROCESSING); - - setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); - - LOG.debug("Starting processing for {}:\n{}", computationId, work); - - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); - - // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires - // cleanup should be done before this, since we might exit early here. - callFinalizeCallbacks(workItem); - if (workItem.getSourceState().getOnlyFinalize()) { - outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); - work.setState(State.COMMIT_QUEUED); - workCommitter.commit(Commit.create(outputBuilder.build(), computationState, work)); - return; - } - - long processingStartTimeNanos = System.nanoTime(); - - final MapTask mapTask = computationState.getMapTask(); - - StageInfo stageInfo = - stageInfoMap.computeIfAbsent( - mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); - - @Nullable ExecutionState executionState = null; - String counterName = "dataflow_source_bytes_processed-" + mapTask.getSystemName(); - - try { - if (work.isFailed()) { - throw new WorkItemCancelledException(workItem.getShardingKey()); - } - executionState = computationState.acquireExecutionState().orElse(null); - if (executionState == null) { - MutableNetwork mapTaskNetwork = mapTaskToNetwork.apply(mapTask); - if (LOG.isDebugEnabled()) { - LOG.debug("Network as Graphviz .dot: {}", Networks.toDot(mapTaskNetwork)); - } - ParallelInstructionNode readNode = - (ParallelInstructionNode) - Iterables.find( - mapTaskNetwork.nodes(), - node -> - node instanceof ParallelInstructionNode - && ((ParallelInstructionNode) node).getParallelInstruction().getRead() - != null); - InstructionOutputNode readOutputNode = - (InstructionOutputNode) Iterables.getOnlyElement(mapTaskNetwork.successors(readNode)); - DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = - new DataflowExecutionContext.DataflowExecutionStateTracker( - sampler, - stageInfo - .executionStateRegistry() - .getState( - NameContext.forStage(mapTask.getStageName()), - "other", - null, - ScopedProfiler.INSTANCE.emptyScope()), - stageInfo.deltaCounters(), - options, - work.getLatencyTrackingId()); - StreamingModeExecutionContext context = - new StreamingModeExecutionContext( - streamingCounters.pendingDeltaCounters(), - computationId, - readerCache, - computationState.getTransformUserNameToStateFamily(), - stateCache.forComputation(computationId), - 
stageInfo.metricsContainerRegistry(), - executionStateTracker, - stageInfo.executionStateRegistry(), - maxSinkBytes); - DataflowMapTaskExecutor mapTaskExecutor = - mapTaskExecutorFactory.create( - mapTaskNetwork, - options, - mapTask.getStageName(), - readerRegistry, - sinkRegistry, - context, - streamingCounters.pendingDeltaCounters(), - ID_GENERATOR); - ReadOperation readOperation = mapTaskExecutor.getReadOperation(); - // Disable progress updates since its results are unused for streaming - // and involves starting a thread. - readOperation.setProgressUpdatePeriodMs(ReadOperation.DONT_UPDATE_PERIODICALLY); - Preconditions.checkState( - mapTaskExecutor.supportsRestart(), - "Streaming runner requires all operations support restart."); - - Coder readCoder; - readCoder = - CloudObjects.coderFromCloudObject( - CloudObject.fromSpec(readOutputNode.getInstructionOutput().getCodec())); - Coder keyCoder = extractKeyCoder(readCoder); - - // If using a custom source, count bytes read for autoscaling. - if (CustomSources.class - .getName() - .equals( - readNode.getParallelInstruction().getRead().getSource().getSpec().get("@type"))) { - NameContext nameContext = - NameContext.create( - mapTask.getStageName(), - readNode.getParallelInstruction().getOriginalName(), - readNode.getParallelInstruction().getSystemName(), - readNode.getParallelInstruction().getName()); - readOperation.receivers[0].addOutputCounter( - counterName, - new OutputObjectAndByteCounter( - new IntrinsicMapTaskExecutorFactory.ElementByteSizeObservableCoder<>( - readCoder), - mapTaskExecutor.getOutputCounters(), - nameContext) - .setSamplingPeriod(100) - .countBytes(counterName)); - } - - ExecutionState.Builder executionStateBuilder = - ExecutionState.builder() - .setWorkExecutor(mapTaskExecutor) - .setContext(context) - .setExecutionStateTracker(executionStateTracker); - - if (keyCoder != null) { - executionStateBuilder.setKeyCoder(keyCoder); - } - - executionState = executionStateBuilder.build(); - } - - WindmillStateReader stateReader = - new WindmillStateReader( - (request) -> - Optional.ofNullable( - metricTrackingWindmillServer.getStateData(computationId, request)), - key, - workItem.getShardingKey(), - workItem.getWorkToken(), - () -> { - work.setState(State.READING); - return () -> work.setState(State.PROCESSING); - }, - work::isFailed); - SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcher.byteTrackingView(); - - // If the read output KVs, then we can decode Windmill's byte key into a userland - // key object and provide it to the execution context for use with per-key state. - // Otherwise, we pass null. - // - // The coder type that will be present is: - // WindowedValueCoder(TimerOrElementCoder(KvCoder)) - Optional> keyCoder = executionState.keyCoder(); - @Nullable - Object executionKey = - !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); - - if (workItem.hasHotKeyInfo()) { - Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); - Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); - - // The MapTask instruction is ordered by dependencies, such that the first element is - // always going to be the shuffle task. 
- String stepName = computationState.getMapTask().getInstructions().get(0).getName(); - if (options.isHotKeyLoggingEnabled() && keyCoder.isPresent()) { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); - } else { - hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); - } - } - - executionState - .context() - .start( - executionKey, - workItem, - inputDataWatermark, - outputDataWatermark, - synchronizedProcessingTime, - stateReader, - localSideInputStateFetcher, - outputBuilder, - work::isFailed); - - // Blocks while executing work. - executionState.workExecutor().execute(); - - if (work.isFailed()) { - throw new WorkItemCancelledException(workItem.getShardingKey()); - } - // Reports source bytes processed to WorkItemCommitRequest if available. - try { - long sourceBytesProcessed = 0; - HashMap counters = - ((DataflowMapTaskExecutor) executionState.workExecutor()) - .getReadOperation() - .receivers[0] - .getOutputCounters(); - if (counters.containsKey(counterName)) { - sourceBytesProcessed = - ((OutputObjectAndByteCounter) counters.get(counterName)).getByteCount().getAndReset(); - } - outputBuilder.setSourceBytesProcessed(sourceBytesProcessed); - } catch (Exception e) { - LOG.error(e.toString()); - } - - commitCallbacks.putAll(executionState.context().flushState()); - - // Release the execution state for another thread to use. - computationState.releaseExecutionState(executionState); - executionState = null; - - // Add the output to the commit queue. - work.setState(State.COMMIT_QUEUED); - outputBuilder.addAllPerWorkItemLatencyAttributions( - work.getLatencyAttributions(false, work.getLatencyTrackingId(), sampler)); - - WorkItemCommitRequest commitRequest = outputBuilder.build(); - int byteLimit = maxWorkItemCommitBytes.get(); - int commitSize = commitRequest.getSerializedSize(); - int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize; - - // Detect overflow of integer serialized size or if the byte limit was exceeded. - streamingCounters.windmillMaxObservedWorkItemCommitBytes().addValue(estimatedCommitSize); - if (commitSize < 0 || commitSize > byteLimit) { - KeyCommitTooLargeException e = - KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest); - failureTracker.trackFailure(computationId, workItem, e); - LOG.error(e.toString()); - - // Drop the current request in favor of a new, minimal one requesting truncation. - // Messages, timers, counters, and other commit content will not be used by the service - // so we're purposefully dropping them here - commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize); - } - - workCommitter.commit(Commit.create(commitRequest, computationState, work)); - - // Compute shuffle and state byte statistics these will be flushed asynchronously. 
- long stateBytesWritten = - outputBuilder - .clearOutputMessages() - .clearPerWorkItemLatencyAttributions() - .build() - .getSerializedSize(); - long shuffleBytesRead = 0; - for (Windmill.InputMessageBundle bundle : workItem.getMessageBundlesList()) { - for (Windmill.Message message : bundle.getMessagesList()) { - shuffleBytesRead += message.getSerializedSize(); - } - } - long stateBytesRead = stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead(); - streamingCounters.windmillShuffleBytesRead().addValue(shuffleBytesRead); - streamingCounters.windmillStateBytesRead().addValue(stateBytesRead); - streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten); - - LOG.debug("Processing done for work token: {}", workItem.getWorkToken()); - } catch (Throwable t) { - if (executionState != null) { - try { - executionState.context().invalidateCache(); - executionState.workExecutor().close(); - } catch (Exception e) { - LOG.warn("Failed to close map task executor: ", e); - } finally { - // Release references to potentially large objects early. - executionState = null; - } - } - - workFailureProcessor.logAndProcessFailure( - computationId, - work, - t, - invalidWork -> - computationState.completeWorkAndScheduleNextWorkForKey( - createShardedKey(invalidWork), invalidWork.id())); - } finally { - // Update total processing time counters. Updating in finally clause ensures that - // work items causing exceptions are also accounted in time spent. - long processingTimeMsecs = - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); - stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); - - // Attribute all the processing to timers if the work item contains any timers. - // Tests show that work items rarely contain both timers and message bundles. It should - // be a fairly close approximation. - // Another option: Derive time split between messages and timers based on recent totals. - // either here or in DFE. 
- if (work.getWorkItem().hasTimers()) { - stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); - } - - sampler.resetForWorkId(work.getLatencyTrackingId()); - DataflowWorkerLoggingMDC.setWorkId(null); - DataflowWorkerLoggingMDC.setStageName(null); - } - } - - private static ShardedKey createShardedKey(Work work) { - return ShardedKey.create(work.getWorkItem().getKey(), work.getWorkItem().getShardingKey()); - } - - private WorkItemCommitRequest buildWorkItemTruncationRequest( - final ByteString key, final Windmill.WorkItem workItem, final int estimatedCommitSize) { - Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); - outputBuilder.setExceedsMaxWorkItemCommitBytes(true); - outputBuilder.setEstimatedWorkItemCommitBytes(estimatedCommitSize); - return outputBuilder.build(); - } - private void onCompleteCommit(CompleteCommit completeCommit) { if (completeCommit.status() != Windmill.CommitStatus.OK) { readerCache.invalidateReader( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 2e9e7e608a50..dd6353060abc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -34,6 +34,7 @@ import java.util.Set; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.SideInputReader; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; @@ -47,6 +48,8 @@ import org.apache.beam.runners.dataflow.worker.counters.CounterFactory; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; @@ -57,8 +60,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateInternals; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -83,19 +88,33 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** {@link DataflowExecutionContext} for use in streaming mode. */ +/** + * {@link DataflowExecutionContext} for use in streaming mode. Contains cached readers and Beam + * state pertaining to a processing its owning computation. 
Can be reused across processing + * different WorkItems for the same computation. + */ @SuppressWarnings({ + "deprecation", "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) +// TODO(m-trieu) fix nullability issues in StreamingModeExecutionContext.java +@NotThreadSafe +@Internal public class StreamingModeExecutionContext extends DataflowExecutionContext { - private static final Logger LOG = LoggerFactory.getLogger(StreamingModeExecutionContext.class); + private final String computationId; - private final Map, Map>> sideInputCache; - // Per-key cache of active Reader objects in use by this process. private final ImmutableMap stateNameMap; private final WindmillStateCache.ForComputation stateCache; private final ReaderCache readerCache; + private volatile long backlogBytes; + + /** + * Used to fetched cache side inputs for processing a single WorkItem. Cleared before processing a + * different WorkItem. + */ + private final Map, Map>> sideInputCache; + /** * The current user-facing key for this execution context. * @@ -107,13 +126,18 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext activeReader; - private volatile long backlogBytes; - private Supplier workIsFailed; + + /** + * Current reader used for processing {@link Work}. Set by calling {@link + * #setActiveReader(UnboundedReader)}, reset to null and cached when state is persisted {@link + * #flushState()}, or set to null and closed when {@link StreamingModeExecutionContext} is + * invalidated. + */ + private @Nullable UnboundedReader activeReader; public StreamingModeExecutionContext( CounterFactory counterFactory, @@ -136,43 +160,53 @@ public StreamingModeExecutionContext( this.sideInputCache = new HashMap<>(); this.stateNameMap = ImmutableMap.copyOf(stateNameMap); this.stateCache = stateCache; - this.backlogBytes = UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN; - this.workIsFailed = () -> Boolean.FALSE; + this.backlogBytes = UnboundedReader.BACKLOG_UNKNOWN; } @VisibleForTesting - public long getBacklogBytes() { + public final long getBacklogBytes() { return backlogBytes; } public boolean workIsFailed() { - return workIsFailed.get(); + return Optional.ofNullable(work).map(Work::isFailed).orElse(false); } public void start( @Nullable Object key, - Windmill.WorkItem work, - Instant inputDataWatermark, - @Nullable Instant outputDataWatermark, - @Nullable Instant synchronizedProcessingTime, + Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder, - @Nullable Supplier workFailed) { + Windmill.WorkItemCommitRequest.Builder outputBuilder) { this.key = key; this.work = work; - this.workIsFailed = (workFailed != null) ? workFailed : () -> Boolean.FALSE; - this.computationKey = - WindmillComputationKey.create(computationId, work.getKey(), work.getShardingKey()); + this.computationKey = WindmillComputationKey.create(computationId, work.getShardedKey()); this.sideInputStateFetcher = sideInputStateFetcher; this.outputBuilder = outputBuilder; this.sideInputCache.clear(); clearSinkFullHint(); + Instant processingTime = computeProcessingTime(work.getWorkItem().getTimers().getTimersList()); + + Collection stepContexts = getAllStepContexts(); + if (!stepContexts.isEmpty()) { + // This must be only created once for the workItem as token validation will fail if the same + // work token is reused. 
+ WindmillStateCache.ForKey cacheForKey = + stateCache.forKey(getComputationKey(), getWorkItem().getCacheToken(), getWorkToken()); + for (StepContext stepContext : stepContexts) { + stepContext.start(stateReader, processingTime, cacheForKey, work.watermarks()); + } + } + } + + /** + * Ensure that the processing time is greater than any fired processing time timers. Otherwise, a + * trigger could ignore the timer and orphan the window. + */ + private static Instant computeProcessingTime(List timers) { Instant processingTime = Instant.now(); - // Ensure that the processing time is greater than any fired processing time - // timers. Otherwise, a trigger could ignore the timer and orphan the window. - for (Windmill.Timer timer : work.getTimers().getTimersList()) { + for (Windmill.Timer timer : timers) { if (timer.getType() == Windmill.Timer.Type.REALTIME) { Instant inferredFiringTime = WindmillTimeUtils.windmillToHarnessTimestamp(timer.getTimestamp()) @@ -183,22 +217,7 @@ public void start( } } - Collection stepContexts = getAllStepContexts(); - if (!stepContexts.isEmpty()) { - // This must be only created once for the workItem as token validation will fail if the same - // work token is reused. - WindmillStateCache.ForKey cacheForKey = - stateCache.forKey(getComputationKey(), getWork().getCacheToken(), getWorkToken()); - for (StepContext stepContext : stepContexts) { - stepContext.start( - stateReader, - inputDataWatermark, - processingTime, - cacheForKey, - outputDataWatermark, - synchronizedProcessingTime); - } - } + return processingTime; } @Override @@ -208,7 +227,7 @@ public StepContext createStepContext(DataflowOperationContext operationContext) @Override protected SideInputReader getSideInputReader( - Iterable sideInputInfos, DataflowOperationContext operationContext) { + Iterable sideInputInfo, DataflowOperationContext operationContext) { throw new UnsupportedOperationException( "Cannot call getSideInputReader for StreamingDataflowWorker: " + "the MapTask specification should not have had any SideInputInfo descriptors " @@ -231,8 +250,8 @@ private TupleTag getInternalTag(PCollectionView view) { * until the active work item is finished. * *
<p>
If the side input was not cached, throws {@code IllegalStateException} if the state is - * {@literal CACHED_IN_WORK_ITEM} or returns {@link SideInput} which contains {@link - * Optional}. + * {@link SideInputState#CACHED_IN_WORK_ITEM} or returns {@link SideInput} which contains + * {@link Optional}. */ private SideInput fetchSideInput( PCollectionView view, @@ -285,15 +304,15 @@ private SideInput fetchSideInputFromWindmill( } public Iterable getSideInputNotifications() { - return work.getGlobalDataIdNotificationsList(); + return work.getWorkItem().getGlobalDataIdNotificationsList(); } private List getFiredTimers() { - return work.getTimers().getTimersList(); + return work.getWorkItem().getTimers().getTimersList(); } public @Nullable ByteString getSerializedKey() { - return work == null ? null : work.getKey(); + return work.getWorkItem().getKey(); } public WindmillComputationKey getComputationKey() { @@ -301,11 +320,11 @@ public WindmillComputationKey getComputationKey() { } public long getWorkToken() { - return work.getWorkToken(); + return work.getWorkItem().getWorkToken(); } - public Windmill.WorkItem getWork() { - return work; + public Windmill.WorkItem getWorkItem() { + return work.getWorkItem(); } public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { @@ -316,12 +335,12 @@ public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { * Returns cached reader for this key if one exists. The reader is removed from the cache. NOTE: * The caller is responsible for the reader and should appropriately close it as required. */ - public UnboundedSource.UnboundedReader getCachedReader() { + public UnboundedReader getCachedReader() { return readerCache.acquireReader( - getComputationKey(), getWork().getCacheToken(), getWork().getWorkToken()); + getComputationKey(), getWorkItem().getCacheToken(), getWorkItem().getWorkToken()); } - public void setActiveReader(UnboundedSource.UnboundedReader reader) { + public void setActiveReader(UnboundedReader reader) { checkState(activeReader == null, "not expected to be overwritten"); activeReader = reader; } @@ -339,18 +358,18 @@ public void invalidateCache() { } } activeReader = null; - stateCache.invalidate(key, getWork().getShardingKey()); + stateCache.invalidate(key, getWorkItem().getShardingKey()); } } - public UnboundedSource.CheckpointMark getReaderCheckpoint( + public UnboundedSource.@Nullable CheckpointMark getReaderCheckpoint( Coder coder) { try { - ByteString state = work.getSourceState().getState(); - if (state.isEmpty()) { + ByteString sourceStateState = work.getWorkItem().getSourceState().getState(); + if (sourceStateState.isEmpty()) { return null; } - return coder.decode(state.newInput(), Coder.Context.OUTER); + return coder.decode(sourceStateState.newInput(), Coder.Context.OUTER); } catch (IOException e) { throw new RuntimeException("Exception while decoding checkpoint", e); } @@ -396,7 +415,7 @@ public Map flushState() { outputBuilder.setSourceWatermark(WindmillTimeUtils.harnessToWindmillTimestamp(watermark)); backlogBytes = activeReader.getSplitBacklogBytes(); - if (backlogBytes == UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN + if (backlogBytes == UnboundedReader.BACKLOG_UNKNOWN && WorkerCustomSources.isFirstUnboundedSourceSplit(getSerializedKey())) { // Only call getTotalBacklogBytes() on the first split. 
backlogBytes = activeReader.getTotalBacklogBytes(); @@ -404,7 +423,10 @@ public Map flushState() { outputBuilder.setSourceBacklogBytes(backlogBytes); readerCache.cacheReader( - getComputationKey(), getWork().getCacheToken(), getWork().getWorkToken(), activeReader); + getComputationKey(), + getWorkItem().getCacheToken(), + getWorkItem().getWorkToken(), + activeReader); activeReader = null; } return callbacks; @@ -437,7 +459,7 @@ void writePCollectionViewData( } /** - * Execution states in Streaming are shared between multiple map-task executors. Thus this class + * Execution states in Streaming are shared between multiple map-task executors. Thus, this class * needs to be thread safe for multiple writers. A single stage could have multiple executors * running concurrently. */ @@ -566,8 +588,7 @@ public void writePCollectionViewData( Iterable data, Coder> dataCoder, W window, - Coder windowCoder) - throws IOException { + Coder windowCoder) { throw new IllegalStateException("User DoFns cannot write PCollectionView data"); } @@ -648,6 +669,7 @@ public boolean isEmpty() { } } + @NotThreadSafe class StepContext extends DataflowExecutionContext.DataflowStepContext implements StreamingModeStepContext { @@ -668,8 +690,7 @@ class StepContext extends DataflowExecutionContext.DataflowStepContext private NavigableSet modifiedUserProcessingTimersOrdered = null; private NavigableSet modifiedUserSynchronizedProcessingTimersOrdered = null; // A list of timer keys that were modified by user processing earlier in this bundle. This - // serves a tombstone, so - // that we know not to fire any bundle timers that were modified. + // serves a tombstone, so that we know not to fire any bundle timers that were modified. private Table modifiedUserTimerKeys = null; public StepContext(DataflowOperationContext operationContext) { @@ -683,17 +704,15 @@ public StepContext(DataflowOperationContext operationContext) { /** Update the {@code stateReader} used by this {@code StepContext}. 
*/ public void start( WindmillStateReader stateReader, - Instant inputDataWatermark, Instant processingTime, WindmillStateCache.ForKey cacheForKey, - @Nullable Instant outputDataWatermark, - @Nullable Instant synchronizedProcessingTime) { + Watermarks watermarks) { this.stateInternals = new WindmillStateInternals<>( key, stateFamily, stateReader, - work.getIsNewKey(), + work.getWorkItem().getIsNewKey(), cacheForKey.forFamily(stateFamily), scopedReadStateSupplier); @@ -701,20 +720,16 @@ public void start( new WindmillTimerInternals( stateFamily, WindmillNamespacePrefix.SYSTEM_NAMESPACE_PREFIX, - inputDataWatermark, processingTime, - outputDataWatermark, - synchronizedProcessingTime, + watermarks, td -> {}); this.userTimerInternals = new WindmillTimerInternals( stateFamily, WindmillNamespacePrefix.USER_NAMESPACE_PREFIX, - inputDataWatermark, processingTime, - outputDataWatermark, - synchronizedProcessingTime, + watermarks, this::onUserTimerModified); this.cachedFiredSystemTimers = null; @@ -780,11 +795,12 @@ private void onUserTimerModified(TimerData timerData) { WindmillTimerInternals.getTimerDataKey(timerData), timerData.getNamespace(), timerData); } - private boolean timerModified(TimerData timerData) { - String timerKey = WindmillTimerInternals.getTimerDataKey(timerData); + private boolean isTimerUnmodified(TimerData timerData) { @Nullable - TimerData updatedTimer = modifiedUserTimerKeys.get(timerKey, timerData.getNamespace()); - return updatedTimer != null && !updatedTimer.equals(timerData); + TimerData updatedTimer = + modifiedUserTimerKeys.get( + WindmillTimerInternals.getTimerDataKey(timerData), timerData.getNamespace()); + return updatedTimer == null || updatedTimer.equals(timerData); } public TimerData getNextFiredUserTimer(Coder windowCoder) { @@ -815,7 +831,7 @@ public TimerData getNextFiredUserTimer(Coder window while (!modifiedUserTimersOrdered.isEmpty() && modifiedUserTimersOrdered.first().compareTo(nextInBundle) <= 0) { TimerData earlierTimer = modifiedUserTimersOrdered.pollFirst(); - if (!timerModified(earlierTimer)) { + if (isTimerUnmodified(earlierTimer)) { // We must delete the timer. This prevents it from being committed to the backing store. // It also handles the // case where the timer had been set to the far future and then modified in bundle; @@ -828,7 +844,7 @@ public TimerData getNextFiredUserTimer(Coder window } // There is no earlier timer to fire, so return the next timer in the bundle. nextInBundle = cachedFiredUserTimers.next(); - if (!timerModified(nextInBundle)) { + if (isTimerUnmodified(nextInBundle)) { // User timers must be explicitly deleted when delivered, to release the implied hold. 
userTimerInternals.deleteTimer(nextInBundle); return nextInBundle; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java index 4aac93ceb3fa..cdce6b88ba66 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java @@ -91,7 +91,7 @@ public NativeReader create( @Override public NativeReaderIterator> iterator() throws IOException { - return new UngroupedWindmillReaderIterator(context.getWork()); + return new UngroupedWindmillReaderIterator(context.getWorkItem()); } class UngroupedWindmillReaderIterator extends WindmillReaderIteratorBase { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java index cb397a32e552..502fb605316a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java @@ -26,6 +26,7 @@ import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer; import org.apache.beam.sdk.coders.Coder; @@ -65,31 +66,25 @@ class WindmillTimerInternals implements TimerInternals { // though technically in Windmill this is only enforced per ID and namespace // and TimeDomain. This TimerInternals is scoped to a step and key, shared // across namespaces. 
- private Table timers = HashBasedTable.create(); + private final Table timers = HashBasedTable.create(); // Map from timer id to whether it is to be deleted or set - private Table timerStillPresent = HashBasedTable.create(); + private final Table timerStillPresent = HashBasedTable.create(); - private Instant inputDataWatermark; - private Instant processingTime; - private @Nullable Instant outputDataWatermark; - private @Nullable Instant synchronizedProcessingTime; - private String stateFamily; - private WindmillNamespacePrefix prefix; - private Consumer onTimerModified; + private final Watermarks watermarks; + private final Instant processingTime; + private final String stateFamily; + private final WindmillNamespacePrefix prefix; + private final Consumer onTimerModified; public WindmillTimerInternals( String stateFamily, // unique identifies a step WindmillNamespacePrefix prefix, // partitions user and system namespaces into "/u" and "/s" - Instant inputDataWatermark, Instant processingTime, - @Nullable Instant outputDataWatermark, - @Nullable Instant synchronizedProcessingTime, + Watermarks watermarks, Consumer onTimerModified) { - this.inputDataWatermark = checkNotNull(inputDataWatermark); + this.watermarks = watermarks; this.processingTime = checkNotNull(processingTime); - this.outputDataWatermark = outputDataWatermark; - this.synchronizedProcessingTime = synchronizedProcessingTime; this.stateFamily = stateFamily; this.prefix = prefix; this.onTimerModified = onTimerModified; @@ -97,13 +92,7 @@ public WindmillTimerInternals( public WindmillTimerInternals withPrefix(WindmillNamespacePrefix prefix) { return new WindmillTimerInternals( - stateFamily, - prefix, - inputDataWatermark, - processingTime, - outputDataWatermark, - synchronizedProcessingTime, - onTimerModified); + stateFamily, prefix, processingTime, watermarks, onTimerModified); } @Override @@ -170,7 +159,7 @@ public Instant currentProcessingTime() { @Override public @Nullable Instant currentSynchronizedProcessingTime() { - return synchronizedProcessingTime; + return watermarks.synchronizedProcessingTime(); } /** @@ -184,7 +173,7 @@ public Instant currentProcessingTime() { */ @Override public Instant currentInputWatermarkTime() { - return inputDataWatermark; + return watermarks.inputDataWatermark(); } /** @@ -198,7 +187,7 @@ public Instant currentInputWatermarkTime() { */ @Override public @Nullable Instant currentOutputWatermarkTime() { - return outputDataWatermark; + return watermarks.outputDataWatermark(); } public void persistTo(Windmill.WorkItemCommitRequest.Builder outputBuilder) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 38d00319a4c3..e7c1cdfe0e9e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -117,7 +117,7 @@ public static WindowingWindmillReader create( @Override public NativeReaderIterator>> iterator() throws IOException { final K key = keyCoder.decode(context.getSerializedKey().newInput(), Coder.Context.OUTER); - final WorkItem workItem = context.getWork(); + final WorkItem workItem = context.getWorkItem(); KeyedWorkItem keyedWorkItem = new WindmillKeyedWorkItem<>(key, 
workItem, windowCoder, windowsCoder, valueCoder); final boolean isEmptyWorkItem = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java index b965110b3ef1..8a25b2e958e8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java @@ -445,7 +445,7 @@ public NativeReaderIterator>> iterator() thro UnboundedSource splitSource = parseSource(splitIndex); - UnboundedSource.CheckpointMark checkpoint = null; + UnboundedSource.@Nullable CheckpointMark checkpoint = null; if (splitSource.getCheckpointMarkCoder() != null) { checkpoint = context.getReaderCheckpoint(splitSource.getCheckpointMarkCoder()); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index a989206408e7..3e226514d57e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -69,20 +69,20 @@ public final class ActiveWorkState { * Queue} is actively processing. */ @GuardedBy("this") - private final Map> activeWork; + private final Map> activeWork; @GuardedBy("this") private final WindmillStateCache.ForComputation computationStateCache; /** * Current budget that is being processed or queued on the user worker. Incremented when work is - * activated in {@link #activateWorkForKey(ShardedKey, Work)}, and decremented when work is + * activated in {@link #activateWorkForKey(ExecutableWork)}, and decremented when work is * completed in {@link #completeWorkAndGetNextWorkForKey(ShardedKey, WorkId)}. */ private final AtomicReference activeGetWorkBudget; private ActiveWorkState( - Map> activeWork, + Map> activeWork, WindmillStateCache.ForComputation computationStateCache) { this.activeWork = activeWork; this.computationStateCache = computationStateCache; @@ -95,7 +95,7 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState @VisibleForTesting static ActiveWorkState forTesting( - Map> activeWork, + Map> activeWork, WindmillStateCache.ForComputation computationStateCache) { return new ActiveWorkState(activeWork, computationStateCache); } @@ -107,13 +107,14 @@ private static String elapsedString(Instant start, Instant end) { } private static Stream toHeartbeatRequestStream( - Entry> shardedKeyAndWorkQueue, + Entry> shardedKeyAndWorkQueue, Instant refreshDeadline, DataflowExecutionStateSampler sampler) { ShardedKey shardedKey = shardedKeyAndWorkQueue.getKey(); - Deque workQueue = shardedKeyAndWorkQueue.getValue(); + Deque workQueue = shardedKeyAndWorkQueue.getValue(); return workQueue.stream() + .map(ExecutableWork::work) .filter(work -> work.getStartTime().isBefore(refreshDeadline)) // Don't send heartbeats for queued work we already know is failed. 
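// Such work will be retried by Windmill, so there is no need to keep its lease alive.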
.filter(work -> !work.isFailed()) @@ -124,8 +125,7 @@ private static Stream toHeartbeatRequestStream( .setWorkToken(work.getWorkItem().getWorkToken()) .setCacheToken(work.getWorkItem().getCacheToken()) .addAllLatencyAttribution( - work.getLatencyAttributions( - /* isHeartbeat= */ true, work.getLatencyTrackingId(), sampler)) + work.getLatencyAttributions(/* isHeartbeat= */ true, sampler)) .build()); } @@ -146,31 +146,32 @@ private static Stream toHeartbeatRequestStream( *

4. STALE: A work queue for the {@link ShardedKey} exists, and there is a queued {@link Work} * with a greater workToken than the passed in {@link Work}. */ - synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work work) { - Deque workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>()); + synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork) { + ShardedKey shardedKey = executableWork.work().getShardedKey(); + Deque workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>()); // This key does not have any work queued up on it. Create one, insert Work, and mark the work // to be executed. if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) { - workQueue.addLast(work); + workQueue.addLast(executableWork); activeWork.put(shardedKey, workQueue); - incrementActiveWorkBudget(work); + incrementActiveWorkBudget(executableWork.work()); return ActivateWorkResult.EXECUTE; } // Check to see if we have this work token queued. - Iterator workIterator = workQueue.iterator(); + Iterator workIterator = workQueue.iterator(); while (workIterator.hasNext()) { - Work queuedWork = workIterator.next(); - if (queuedWork.id().equals(work.id())) { + ExecutableWork queuedWork = workIterator.next(); + if (queuedWork.id().equals(executableWork.id())) { return ActivateWorkResult.DUPLICATE; } - if (queuedWork.id().cacheToken() == work.id().cacheToken()) { - if (work.id().workToken() > queuedWork.id().workToken()) { + if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()) { + if (executableWork.id().workToken() > queuedWork.id().workToken()) { // Check to see if the queuedWork is active. We only want to remove it if it is NOT // currently active. if (!queuedWork.equals(workQueue.peek())) { workIterator.remove(); - decrementActiveWorkBudget(queuedWork); + decrementActiveWorkBudget(queuedWork.work()); } // Continue here to possibly remove more non-active stale work that is queued. } else { @@ -180,8 +181,8 @@ synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work w } // Queue the work for later processing. - workQueue.addLast(work); - incrementActiveWorkBudget(work); + workQueue.addLast(executableWork); + incrementActiveWorkBudget(executableWork.work()); return ActivateWorkResult.QUEUED; } @@ -193,11 +194,11 @@ synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, Work w synchronized void failWorkForKey(Multimap failedWork) { // Note we can't construct a ShardedKey and look it up in activeWork directly since // HeartbeatResponse doesn't include the user key. - for (Entry> entry : activeWork.entrySet()) { + for (Entry> entry : activeWork.entrySet()) { Collection failedWorkIds = failedWork.get(entry.getKey().shardingKey()); for (WorkId failedWorkId : failedWorkIds) { - for (Work queuedWork : entry.getValue()) { - WorkItem workItem = queuedWork.getWorkItem(); + for (ExecutableWork queuedWork : entry.getValue()) { + WorkItem workItem = queuedWork.work().getWorkItem(); if (workItem.getWorkToken() == failedWorkId.workToken() && workItem.getCacheToken() == failedWorkId.cacheToken()) { LOG.debug( @@ -210,7 +211,7 @@ synchronized void failWorkForKey(Multimap failedWork) { + " " + failedWorkId.cacheToken() + ". 
The work will be retried and is not lost."); - queuedWork.setFailed(); + queuedWork.work().setFailed(); break; } } @@ -234,9 +235,9 @@ private void decrementActiveWorkBudget(Work work) { * ShardedKey}'s work queue, if one exists else removes the {@link ShardedKey} from {@link * #activeWork}. */ - synchronized Optional completeWorkAndGetNextWorkForKey( + synchronized Optional completeWorkAndGetNextWorkForKey( ShardedKey shardedKey, WorkId workId) { - @Nullable Queue workQueue = activeWork.get(shardedKey); + @Nullable Queue workQueue = activeWork.get(shardedKey); if (workQueue == null) { // Work may have been completed due to clearing of stuck commits. LOG.warn("Unable to complete inactive work for key {} and token {}.", shardedKey, workId); @@ -247,10 +248,10 @@ synchronized Optional completeWorkAndGetNextWorkForKey( } private synchronized void removeCompletedWorkFromQueue( - Queue workQueue, ShardedKey shardedKey, WorkId workId) { + Queue workQueue, ShardedKey shardedKey, WorkId workId) { // avoid Preconditions.checkState here to prevent eagerly evaluating the // format string parameters for the error message. - Work completedWork = workQueue.peek(); + ExecutableWork completedWork = workQueue.peek(); if (completedWork == null) { // Work may have been completed due to clearing of stuck commits. LOG.warn("Active key {} without work, expected token {}", shardedKey, workId); @@ -272,11 +273,12 @@ private synchronized void removeCompletedWorkFromQueue( // We consumed the matching work item. workQueue.remove(); - decrementActiveWorkBudget(completedWork); + decrementActiveWorkBudget(completedWork.work()); } - private synchronized Optional getNextWork(Queue workQueue, ShardedKey shardedKey) { - Optional nextWork = Optional.ofNullable(workQueue.peek()); + private synchronized Optional getNextWork( + Queue workQueue, ShardedKey shardedKey) { + Optional nextWork = Optional.ofNullable(workQueue.peek()); if (!nextWork.isPresent()) { Preconditions.checkState(workQueue == activeWork.remove(shardedKey)); } @@ -304,10 +306,11 @@ private synchronized ImmutableMap getStuckCommitsAt( // Determine the stuck commit keys but complete them outside the loop iterating over // activeWork as completeWork may delete the entry from activeWork. ImmutableMap.Builder stuckCommits = ImmutableMap.builder(); - for (Entry> entry : activeWork.entrySet()) { + for (Entry> entry : activeWork.entrySet()) { ShardedKey shardedKey = entry.getKey(); - @Nullable Work work = entry.getValue().peek(); - if (work != null) { + @Nullable ExecutableWork executableWork = entry.getValue().peek(); + if (executableWork != null) { + Work work = executableWork.work(); if (work.isStuckCommittingAt(stuckCommitDeadline)) { LOG.error( "Detected key {} stuck in COMMITTING state since {}, completing it with error.", @@ -346,9 +349,9 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) { // Use StringBuilder because we are appending in loop. 
StringBuilder activeWorkStatus = new StringBuilder(); int commitsPendingCount = 0; - for (Map.Entry> entry : activeWork.entrySet()) { - Queue workQueue = Preconditions.checkNotNull(entry.getValue()); - Work activeWork = Preconditions.checkNotNull(workQueue.peek()); + for (Map.Entry> entry : activeWork.entrySet()) { + Queue workQueue = Preconditions.checkNotNull(entry.getValue()); + Work activeWork = Preconditions.checkNotNull(workQueue.peek()).work(); WorkItem workItem = activeWork.getWorkItem(); if (activeWork.isCommitPending()) { if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java index f9466f25577b..434e78484799 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java @@ -28,6 +28,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -46,7 +47,8 @@ public class ComputationState { private final ImmutableMap transformUserNameToStateFamily; private final ActiveWorkState activeWorkState; private final BoundedQueueExecutor executor; - private final ConcurrentLinkedQueue executionStateQueue; + private final ConcurrentLinkedQueue computationWorkExecutors; + private final String sourceBytesProcessCounterName; public ComputationState( String computationId, @@ -60,8 +62,10 @@ public ComputationState( this.mapTask = mapTask; this.executor = executor; this.transformUserNameToStateFamily = ImmutableMap.copyOf(transformUserNameToStateFamily); - this.executionStateQueue = new ConcurrentLinkedQueue<>(); + this.computationWorkExecutors = new ConcurrentLinkedQueue<>(); this.activeWorkState = ActiveWorkState.create(computationStateCache); + this.sourceBytesProcessCounterName = + "dataflow_source_bytes_processed-" + mapTask.getSystemName(); } public String getComputationId() { @@ -77,19 +81,20 @@ public ImmutableMap getTransformUserNameToStateFamily() { } /** - * Cache the {@link ExecutionState} so that it can be re-used in future {@link - * #acquireExecutionState()} calls. + * Cache the {@link ComputationWorkExecutor} so that it can be re-used in future {@link + * #acquireComputationWorkExecutor()} calls. */ - public void releaseExecutionState(ExecutionState executionState) { - executionStateQueue.offer(executionState); + public void releaseComputationWorkExecutor(ComputationWorkExecutor computationWorkExecutor) { + computationWorkExecutors.offer(computationWorkExecutor); } /** - * Returns {@link ExecutionState} that was previously offered in {@link - * #releaseExecutionState(ExecutionState)} or {@link Optional#empty()} if one does not exist. 
+ * Returns {@link ComputationWorkExecutor} that was previously offered in {@link + * #releaseComputationWorkExecutor(ComputationWorkExecutor)} or {@link Optional#empty()} if one + * does not exist. */ - public Optional acquireExecutionState() { - return Optional.ofNullable(executionStateQueue.poll()); + public Optional acquireComputationWorkExecutor() { + return Optional.ofNullable(computationWorkExecutors.poll()); } /** @@ -97,8 +102,8 @@ public Optional acquireExecutionState() { * Work} if there is no active {@link Work} for the {@link ShardedKey} already processing. Returns * whether the {@link Work} will be activated, either immediately or sometime in the future. */ - public boolean activateWork(ShardedKey shardedKey, Work work) { - switch (activeWorkState.activateWorkForKey(shardedKey, work)) { + public boolean activateWork(ExecutableWork executableWork) { + switch (activeWorkState.activateWorkForKey(executableWork)) { case DUPLICATE: // Fall through intentionally. Work was not and will not be activated in these cases. case STALE: @@ -107,7 +112,7 @@ public boolean activateWork(ShardedKey shardedKey, Work work) { return true; case EXECUTE: { - execute(work); + execute(executableWork); return true; } default: @@ -134,12 +139,12 @@ public void invalidateStuckCommits(Instant stuckCommitDeadline) { stuckCommitDeadline, this::completeWorkAndScheduleNextWorkForKey); } - private void execute(Work work) { - executor.execute(work, work.getWorkItem().getSerializedSize()); + private void execute(ExecutableWork executableWork) { + executor.execute(executableWork, executableWork.work().getWorkItem().getSerializedSize()); } - private void forceExecute(Work work) { - executor.forceExecute(work, work.getWorkItem().getSerializedSize()); + private void forceExecute(ExecutableWork executableWork) { + executor.forceExecute(executableWork, executableWork.work().getWorkItem().getSerializedSize()); } /** Gets HeartbeatRequests for any work started before refreshDeadline. */ @@ -156,11 +161,16 @@ public void printActiveWork(PrintWriter writer) { activeWorkState.printActiveWork(writer, Instant.now()); } + public String sourceBytesProcessCounterName() { + return sourceBytesProcessCounterName; + } + + @VisibleForTesting public final void close() { - @Nullable ExecutionState executionState; - while ((executionState = executionStateQueue.poll()) != null) { - executionState.workExecutor().close(); + @Nullable ComputationWorkExecutor computationWorkExecutor; + while ((computationWorkExecutor = computationWorkExecutors.poll()) != null) { + computationWorkExecutor.invalidate(); } - executionStateQueue.clear(); + computationWorkExecutors.clear(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java index 33e28075487f..199ad26aed00 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java @@ -193,6 +193,7 @@ public void loadCacheForTesting( /** * Close all {@link ComputationState}(s) present in the cache, then invalidate the entire cache. 
*/ + @VisibleForTesting public void closeAndInvalidateAll() { computationCache.asMap().values().forEach(ComputationState::close); computationCache.invalidateAll(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java new file mode 100644 index 000000000000..dd34e85bc93c --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import java.util.HashMap; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.runners.core.metrics.ExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; +import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; +import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Used to process {@link Work} by executing user DoFns for a specific computation. May be reused to + * process future work items owned by a computation. + * + *

Should only be accessed by 1 thread at a time. + * + * @implNote Once closed, it cannot be reused. + */ +// TODO(m-trieu): See if this can be combined/cleaned up with StreamingModeExecutionContext as the +// separation of responsibilities is unclear. +@AutoValue +@Internal +@NotThreadSafe +public abstract class ComputationWorkExecutor { + private static final Logger LOG = LoggerFactory.getLogger(ComputationWorkExecutor.class); + + public static ComputationWorkExecutor.Builder builder() { + return new AutoValue_ComputationWorkExecutor.Builder(); + } + + public abstract DataflowWorkExecutor workExecutor(); + + public abstract StreamingModeExecutionContext context(); + + public abstract Optional> keyCoder(); + + public abstract ExecutionStateTracker executionStateTracker(); + + /** + * Executes DoFns for the Work. Blocks the calling thread until DoFn(s) have completed execution. + */ + public final void executeWork( + @Nullable Object key, + Work work, + WindmillStateReader stateReader, + SideInputStateFetcher sideInputStateFetcher, + Windmill.WorkItemCommitRequest.Builder outputBuilder) + throws Exception { + context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder); + workExecutor().execute(); + } + + /** + * Callers should only invoke invalidate() when execution of work fails. Once closed, the instance + * cannot be reused. + */ + public final void invalidate() { + context().invalidateCache(); + try { + workExecutor().close(); + } catch (Exception e) { + LOG.warn("Failed to close map task executor: ", e); + } + } + + public final long computeSourceBytesProcessed(String sourceBytesCounterName) { + HashMap counters = + ((DataflowMapTaskExecutor) workExecutor()) + .getReadOperation() + .receivers[0] + .getOutputCounters(); + + return Optional.ofNullable(counters.get(sourceBytesCounterName)) + .map(counter -> ((OutputObjectAndByteCounter) counter).getByteCount().getAndReset()) + .orElse(0L); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); + + public abstract Builder setContext(StreamingModeExecutionContext context); + + public abstract Builder setKeyCoder(Coder keyCoder); + + public abstract Builder setExecutionStateTracker(ExecutionStateTracker executionStateTracker); + + public abstract ComputationWorkExecutor build(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java new file mode 100644 index 000000000000..bdf8a7814ea3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import java.util.function.Consumer; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; + +/** {@link Work} instance and a processing function used to process the work. */ +@AutoValue +public abstract class ExecutableWork implements Runnable { + + public static ExecutableWork create(Work work, Consumer executeWorkFn) { + return new AutoValue_ExecutableWork(work, executeWorkFn); + } + + public abstract Work work(); + + abstract Consumer executeWorkFn(); + + @Override + public void run() { + executeWorkFn().accept(work()); + } + + public final WorkId id() { + return work().id(); + } + + public final Windmill.WorkItem getWorkItem() { + return work().getWorkItem(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java deleted file mode 100644 index ba35179a75b3..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutionState.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.dataflow.worker.streaming; - -import com.google.auto.value.AutoValue; -import java.util.Optional; -import org.apache.beam.runners.core.metrics.ExecutionStateTracker; -import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; -import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; -import org.apache.beam.sdk.coders.Coder; - -@AutoValue -public abstract class ExecutionState { - - public abstract DataflowWorkExecutor workExecutor(); - - public abstract StreamingModeExecutionContext context(); - - public abstract Optional> keyCoder(); - - public abstract ExecutionStateTracker executionStateTracker(); - - public static ExecutionState.Builder builder() { - return new AutoValue_ExecutionState.Builder(); - } - - @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setWorkExecutor(DataflowWorkExecutor workExecutor); - - public abstract Builder setContext(StreamingModeExecutionContext context); - - public abstract Builder setKeyCoder(Coder keyCoder); - - public abstract Builder setExecutionStateTracker(ExecutionStateTracker executionStateTracker); - - public abstract ExecutionState build(); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Watermarks.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Watermarks.java new file mode 100644 index 000000000000..db6291aa4ee8 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Watermarks.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.joda.time.Instant; + +/** Watermarks for stream pipeline processing. 
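+ *
+ * <p>The builder checks that the output data watermark, when present, is not ahead of the input
+ * data watermark. Illustrative construction, mirroring {@code GrpcDirectGetWorkStream}:
+ *
+ * <pre>{@code
+ * Watermarks watermarks =
+ *     Watermarks.builder()
+ *         .setInputDataWatermark(metadata.inputDataWatermark())
+ *         .setOutputDataWatermark(workItem.getOutputDataWatermark())
+ *         .setSynchronizedProcessingTime(metadata.synchronizedProcessingTime())
+ *         .build();
+ * }</pre>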
*/ +@AutoValue +@Internal +public abstract class Watermarks { + + public static Builder builder() { + return new AutoValue_Watermarks.Builder(); + } + + public abstract Instant inputDataWatermark(); + + public abstract @Nullable Instant synchronizedProcessingTime(); + + public abstract @Nullable Instant outputDataWatermark(); + + @AutoValue.Builder + public abstract static class Builder { + private static boolean hasValidOutputDataWatermark(Watermarks watermarks) { + @Nullable Instant outputDataWatermark = watermarks.outputDataWatermark(); + return outputDataWatermark == null + || !outputDataWatermark.isAfter(watermarks.inputDataWatermark()); + } + + public abstract Builder setInputDataWatermark(Instant value); + + public abstract Builder setSynchronizedProcessingTime(@Nullable Instant value); + + public abstract Builder setOutputDataWatermark(@Nullable Instant value); + + public final Builder setOutputDataWatermark(long outputDataWatermark) { + return setOutputDataWatermark( + WindmillTimeUtils.windmillToHarnessWatermark(outputDataWatermark)); + } + + abstract Watermarks autoBuild(); + + public final Watermarks build() { + Watermarks watermarks = autoBuild(); + Preconditions.checkState(hasValidOutputDataWatermark(watermarks)); + return watermarks; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 64b0eaf5cc08..fa46bac36b58 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -17,74 +17,152 @@ */ package org.apache.beam.runners.dataflow.worker.streaming; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; + import com.google.auto.value.AutoValue; -import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.EnumMap; import java.util.IntSummaryStatistics; -import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.ActiveMessageMetadata; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.ActiveLatencyBreakdown; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.ActiveLatencyBreakdown.ActiveElementMetadata; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.ActiveLatencyBreakdown.Distribution; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; +import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.joda.time.Duration; import org.joda.time.Instant; +/** + * Represents the state of an attempt to process a {@link WorkItem} by executing user code. + * + * @implNote Not thread safe, should not be executed or accessed by more than 1 thread at a time. + */ @NotThreadSafe -public class Work implements Runnable { - private final Windmill.WorkItem workItem; +@Internal +public class Work { + private final ShardedKey shardedKey; + private final WorkItem workItem; + private final ProcessingContext processingContext; + private final Watermarks watermarks; private final Supplier clock; private final Instant startTime; - private final Map totalDurationPerState; - private final Consumer processWorkFn; + private final Map totalDurationPerState; private final WorkId id; + private final String latencyTrackingId; private TimedState currentState; private volatile boolean isFailed; - private Work(Windmill.WorkItem workItem, Supplier clock, Consumer processWorkFn) { + private Work( + WorkItem workItem, + Watermarks watermarks, + ProcessingContext processingContext, + Supplier clock) { + this.shardedKey = ShardedKey.create(workItem.getKey(), workItem.getShardingKey()); this.workItem = workItem; + this.processingContext = processingContext; + this.watermarks = watermarks; this.clock = clock; - this.processWorkFn = processWorkFn; this.startTime = clock.get(); - this.totalDurationPerState = new EnumMap<>(Windmill.LatencyAttribution.State.class); + this.totalDurationPerState = new EnumMap<>(LatencyAttribution.State.class); + this.id = WorkId.of(workItem); + this.latencyTrackingId = + Long.toHexString(workItem.getShardingKey()) + + '-' + + Long.toHexString(workItem.getWorkToken()); this.currentState = TimedState.initialState(startTime); this.isFailed = false; - this.id = - WorkId.builder() - .setCacheToken(workItem.getCacheToken()) - .setWorkToken(workItem.getWorkToken()) - .build(); } public static Work create( - Windmill.WorkItem workItem, + WorkItem workItem, + Watermarks watermarks, + ProcessingContext processingContext, Supplier clock, - Collection getWorkStreamLatencies, - Consumer processWorkFn) { - Work work = new Work(workItem, clock, processWorkFn); + Collection getWorkStreamLatencies) { + Work work = new Work(workItem, watermarks, processingContext, clock); work.recordGetWorkStreamLatencies(getWorkStreamLatencies); return work; } - @Override - public void run() { - processWorkFn.accept(this); + public static ProcessingContext createProcessingContext( + String computationId, + BiFunction getKeyedDataFn, + Consumer workCommitter) { + return ProcessingContext.create(computationId, getKeyedDataFn, workCommitter); + } + + private static LatencyAttribution.Builder createLatencyAttributionWithActiveLatencyBreakdown( + boolean isHeartbeat, String workId, DataflowExecutionStateSampler sampler) { + LatencyAttribution.Builder latencyAttribution = LatencyAttribution.newBuilder(); + if (isHeartbeat) { + ActiveLatencyBreakdown.Builder stepBuilder = ActiveLatencyBreakdown.newBuilder(); + Optional activeMessage = + 
sampler.getActiveMessageMetadataForWorkId(workId); + if (!activeMessage.isPresent()) { + return latencyAttribution; + } + stepBuilder.setUserStepName(activeMessage.get().userStepName()); + ActiveElementMetadata.Builder activeElementBuilder = ActiveElementMetadata.newBuilder(); + activeElementBuilder.setProcessingTimeMillis( + activeMessage.get().stopwatch().elapsed().toMillis()); + stepBuilder.setActiveMessageMetadata(activeElementBuilder); + latencyAttribution.addActiveLatencyBreakdown(stepBuilder.build()); + return latencyAttribution; + } + + Map processingDistributions = + sampler.getProcessingDistributionsForWorkId(workId); + for (Entry entry : processingDistributions.entrySet()) { + ActiveLatencyBreakdown.Builder stepBuilder = ActiveLatencyBreakdown.newBuilder(); + stepBuilder.setUserStepName(entry.getKey()); + Distribution.Builder distributionBuilder = + Distribution.newBuilder() + .setCount(entry.getValue().getCount()) + .setMin(entry.getValue().getMin()) + .setMax(entry.getValue().getMax()) + .setMean((long) entry.getValue().getAverage()) + .setSum(entry.getValue().getSum()); + stepBuilder.setProcessingTimesDistribution(distributionBuilder.build()); + latencyAttribution.addActiveLatencyBreakdown(stepBuilder.build()); + } + return latencyAttribution; } - public Windmill.WorkItem getWorkItem() { + public WorkItem getWorkItem() { return workItem; } + public ShardedKey getShardedKey() { + return shardedKey; + } + + public Optional fetchKeyedState(KeyedGetDataRequest keyedGetDataRequest) { + return processingContext.keyedDataFetcher().apply(keyedGetDataRequest); + } + + public Watermarks watermarks() { + return watermarks; + } + public Instant getStartTime() { return startTime; } @@ -115,84 +193,68 @@ public Instant getStateStartTime() { } public String getLatencyTrackingId() { - StringBuilder workIdBuilder = new StringBuilder(33); - workIdBuilder.append(Long.toHexString(workItem.getShardingKey())); - workIdBuilder.append('-'); - workIdBuilder.append(Long.toHexString(workItem.getWorkToken())); - return workIdBuilder.toString(); + return latencyTrackingId; + } + + public final void queueCommit( + WorkItemCommitRequest commitRequest, ComputationState computationState) { + setState(State.COMMIT_QUEUED); + processingContext.workCommitter().accept(Commit.create(commitRequest, computationState, this)); + } + + public WindmillStateReader createWindmillStateReader() { + return WindmillStateReader.forWork(this); } public WorkId id() { return id; } - private void recordGetWorkStreamLatencies( - Collection getWorkStreamLatencies) { - for (Windmill.LatencyAttribution latency : getWorkStreamLatencies) { + private void recordGetWorkStreamLatencies(Collection getWorkStreamLatencies) { + for (LatencyAttribution latency : getWorkStreamLatencies) { totalDurationPerState.put( latency.getState(), Duration.millis(latency.getTotalDurationMillis())); } } public ImmutableList getLatencyAttributions( - boolean isHeartbeat, String workId, DataflowExecutionStateSampler sampler) { - List list = new ArrayList<>(); - for (Windmill.LatencyAttribution.State state : Windmill.LatencyAttribution.State.values()) { - Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO); - if (state == this.currentState.state().toLatencyAttributionState()) { - duration = duration.plus(new Duration(this.currentState.startTime(), clock.get())); - } - if (duration.equals(Duration.ZERO)) { - continue; - } - LatencyAttribution.Builder laBuilder = Windmill.LatencyAttribution.newBuilder(); - if (state == 
LatencyAttribution.State.ACTIVE) { - laBuilder = addActiveLatencyBreakdownToBuilder(isHeartbeat, laBuilder, workId, sampler); - } - Windmill.LatencyAttribution la = - laBuilder.setState(state).setTotalDurationMillis(duration.getMillis()).build(); - list.add(la); - } - return ImmutableList.copyOf(list); + boolean isHeartbeat, DataflowExecutionStateSampler sampler) { + return Arrays.stream(LatencyAttribution.State.values()) + .map(state -> Pair.of(state, getTotalDurationAtLatencyAttributionState(state))) + .filter( + stateAndLatencyAttribution -> + !stateAndLatencyAttribution.getValue().isEqual(Duration.ZERO)) + .map( + stateAndLatencyAttribution -> + createLatencyAttribution( + stateAndLatencyAttribution.getKey(), + isHeartbeat, + sampler, + stateAndLatencyAttribution.getValue())) + .collect(toImmutableList()); } - private static LatencyAttribution.Builder addActiveLatencyBreakdownToBuilder( - boolean isHeartbeat, - LatencyAttribution.Builder builder, - String workId, - DataflowExecutionStateSampler sampler) { - if (isHeartbeat) { - ActiveLatencyBreakdown.Builder stepBuilder = ActiveLatencyBreakdown.newBuilder(); - Optional activeMessage = - sampler.getActiveMessageMetadataForWorkId(workId); - if (!activeMessage.isPresent()) { - return builder; - } - stepBuilder.setUserStepName(activeMessage.get().userStepName()); - ActiveElementMetadata.Builder activeElementBuilder = ActiveElementMetadata.newBuilder(); - activeElementBuilder.setProcessingTimeMillis( - activeMessage.get().stopwatch().elapsed().toMillis()); - stepBuilder.setActiveMessageMetadata(activeElementBuilder); - builder.addActiveLatencyBreakdown(stepBuilder.build()); - return builder; - } + private Duration getTotalDurationAtLatencyAttributionState(LatencyAttribution.State state) { + Duration duration = totalDurationPerState.getOrDefault(state, Duration.ZERO); + return state == this.currentState.state().toLatencyAttributionState() + ? duration.plus(new Duration(this.currentState.startTime(), clock.get())) + : duration; + } - Map processingDistributions = - sampler.getProcessingDistributionsForWorkId(workId); - for (Entry entry : processingDistributions.entrySet()) { - ActiveLatencyBreakdown.Builder stepBuilder = ActiveLatencyBreakdown.newBuilder(); - stepBuilder.setUserStepName(entry.getKey()); - Distribution.Builder distributionBuilder = - Distribution.newBuilder() - .setCount(entry.getValue().getCount()) - .setMin(entry.getValue().getMin()) - .setMax(entry.getValue().getMax()) - .setMean((long) entry.getValue().getAverage()) - .setSum(entry.getValue().getSum()); - stepBuilder.setProcessingTimesDistribution(distributionBuilder.build()); - builder.addActiveLatencyBreakdown(stepBuilder.build()); - } - return builder; + private LatencyAttribution createLatencyAttribution( + LatencyAttribution.State state, + boolean isHeartbeat, + DataflowExecutionStateSampler sampler, + Duration latencyAttributionDuration) { + LatencyAttribution.Builder latencyAttribution = + state == LatencyAttribution.State.ACTIVE + ? 
createLatencyAttributionWithActiveLatencyBreakdown( + isHeartbeat, latencyTrackingId, sampler) + : LatencyAttribution.newBuilder(); + return latencyAttribution + .setState(state) + .setTotalDurationMillis(latencyAttributionDuration.getMillis()) + .build(); } public boolean isFailed() { @@ -205,24 +267,22 @@ boolean isStuckCommittingAt(Instant stuckCommitDeadline) { } public enum State { - QUEUED(Windmill.LatencyAttribution.State.QUEUED), - PROCESSING(Windmill.LatencyAttribution.State.ACTIVE), - READING(Windmill.LatencyAttribution.State.READING), - COMMIT_QUEUED(Windmill.LatencyAttribution.State.COMMITTING), - COMMITTING(Windmill.LatencyAttribution.State.COMMITTING), - GET_WORK_IN_WINDMILL_WORKER(Windmill.LatencyAttribution.State.GET_WORK_IN_WINDMILL_WORKER), - GET_WORK_IN_TRANSIT_TO_DISPATCHER( - Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_DISPATCHER), - GET_WORK_IN_TRANSIT_TO_USER_WORKER( - Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_USER_WORKER); - - private final Windmill.LatencyAttribution.State latencyAttributionState; - - State(Windmill.LatencyAttribution.State latencyAttributionState) { + QUEUED(LatencyAttribution.State.QUEUED), + PROCESSING(LatencyAttribution.State.ACTIVE), + READING(LatencyAttribution.State.READING), + COMMIT_QUEUED(LatencyAttribution.State.COMMITTING), + COMMITTING(LatencyAttribution.State.COMMITTING), + GET_WORK_IN_WINDMILL_WORKER(LatencyAttribution.State.GET_WORK_IN_WINDMILL_WORKER), + GET_WORK_IN_TRANSIT_TO_DISPATCHER(LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_DISPATCHER), + GET_WORK_IN_TRANSIT_TO_USER_WORKER(LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_USER_WORKER); + + private final LatencyAttribution.State latencyAttributionState; + + State(LatencyAttribution.State latencyAttributionState) { this.latencyAttributionState = latencyAttributionState; } - Windmill.LatencyAttribution.State toLatencyAttributionState() { + LatencyAttribution.State toLatencyAttributionState() { return latencyAttributionState; } } @@ -249,4 +309,31 @@ private boolean isCommitPending() { abstract Instant startTime(); } + + @AutoValue + public abstract static class ProcessingContext { + + private static ProcessingContext create( + String computationId, + BiFunction getKeyedDataFn, + Consumer workCommitter) { + return new AutoValue_Work_ProcessingContext( + computationId, + request -> Optional.ofNullable(getKeyedDataFn.apply(computationId, request)), + workCommitter); + } + + /** Computation that the {@link Work} belongs to. */ + public abstract String computationId(); + + /** Handles GetData requests to streaming backend. */ + public abstract Function> + keyedDataFetcher(); + + /** + * {@link WorkCommitter} that commits completed work to the backend Windmill worker handling the + * {@link WorkItem}. 
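+   *
+   * <p>Illustrative wiring, as done in {@code GrpcDirectGetWorkStream#createProcessingContext}:
+   *
+   * <pre>{@code
+   * Work.ProcessingContext context =
+   *     Work.createProcessingContext(
+   *         computationId, getDataStream.get()::requestKeyedData, workCommitter.get()::commit);
+   * }</pre>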
+ */ + public abstract Consumer workCommitter(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java index d56b56c184c9..f8f8d1901914 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkId.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.streaming; import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; /** * A composite key used to identify a unit of {@link Work}. If multiple units of {@link Work} have @@ -33,6 +34,13 @@ public static Builder builder() { return new AutoValue_WorkId.Builder(); } + public static WorkId of(Windmill.WorkItem workItem) { + return WorkId.builder() + .setCacheToken(workItem.getCacheToken()) + .setWorkToken(workItem.getWorkToken()) + .build(); + } + abstract long cacheToken(); abstract long workToken(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java index 2d0885f9f690..7fd2487575c2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java @@ -46,6 +46,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.slf4j.Logger; @@ -69,6 +70,7 @@ public SideInputStateFetcher( this(fetchGlobalDataFn, SideInputCache.create(options)); } + @VisibleForTesting SideInputStateFetcher( Function fetchGlobalDataFn, SideInputCache sideInputCache) { this.fetchGlobalDataFn = fetchGlobalDataFn; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java index e364b0039574..877e3198e91d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java @@ -130,7 +130,7 @@ public NativeReader.DynamicSplitResult requestDynamicSplit( return getReadOperation().requestDynamicSplit(splitRequest); } - public ReadOperation getReadOperation() throws Exception { + public ReadOperation getReadOperation() { if (operations == null || operations.isEmpty()) { throw new IllegalStateException("Map task has no operation."); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index a56097cc8136..6f4b5b7b33fb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -28,6 +28,8 @@ import java.util.function.Supplier; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationWorkItemMetadata; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; @@ -36,10 +38,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItemClient; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; @@ -55,7 +57,7 @@ * Implementation of {@link GetWorkStream} that passes along a specific {@link * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream} and {@link * org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream} to the - * processing context {@link ProcessWorkItemClient}. During the work item processing lifecycle, + * processing context {@link Work.ProcessingContext}. During the work item processing lifecycle, * these direct streams are used to facilitate these RPC calls to specific backend workers. */ @Internal @@ -76,10 +78,11 @@ public final class GrpcDirectGetWorkStream private final AtomicReference nextBudgetAdjustment; private final AtomicReference pendingResponseBudget; private final GetWorkRequest request; - private final WorkItemProcessor workItemProcessorFn; + private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; private final Supplier getDataStream; - private final Supplier commitWorkStream; + private final Supplier workCommitter; + /** * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they * come in. 
Once all chunks for a response has been received, the chunk is processed and the @@ -99,18 +102,18 @@ private GrpcDirectGetWorkStream( int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, Supplier getDataStream, - Supplier commitWorkStream, - WorkItemProcessor workItemProcessorFn) { + Supplier workCommitter, + WorkItemScheduler workItemScheduler) { super( startGetWorkRpcFn, backoff, streamObserverFactory, streamRegistry, logEveryNStreamFailures); this.request = request; this.getWorkThrottleTimer = getWorkThrottleTimer; - this.workItemProcessorFn = workItemProcessorFn; + this.workItemScheduler = workItemScheduler; this.workItemBuffers = new ConcurrentHashMap<>(); // Use the same GetDataStream and CommitWorkStream instances to process all the work in this // stream. this.getDataStream = Suppliers.memoize(getDataStream::get); - this.commitWorkStream = Suppliers.memoize(commitWorkStream::get); + this.workCommitter = Suppliers.memoize(workCommitter::get); this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); @@ -128,8 +131,8 @@ public static GrpcDirectGetWorkStream create( int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, Supplier getDataStream, - Supplier commitWorkStream, - WorkItemProcessor workItemProcessorFn) { + Supplier workCommitter, + WorkItemScheduler workItemScheduler) { GrpcDirectGetWorkStream getWorkStream = new GrpcDirectGetWorkStream( startGetWorkRpcFn, @@ -140,12 +143,20 @@ public static GrpcDirectGetWorkStream create( logEveryNStreamFailures, getWorkThrottleTimer, getDataStream, - commitWorkStream, - workItemProcessorFn); + workCommitter, + workItemScheduler); getWorkStream.startStream(); return getWorkStream; } + private static Watermarks createWatermarks(WorkItem workItem, ComputationMetadata metadata) { + return Watermarks.builder() + .setInputDataWatermark(metadata.inputDataWatermark()) + .setOutputDataWatermark(workItem.getOutputDataWatermark()) + .setSynchronizedProcessingTime(metadata.synchronizedProcessingTime()) + .build(); + } + private synchronized GetWorkBudget getThenResetBudgetAdjustment() { return nextBudgetAdjustment.getAndUpdate(unused -> GetWorkBudget.noBudget()); } @@ -299,13 +310,10 @@ private void runAndReset() { try { WorkItem workItem = WorkItem.parseFrom(data.newInput()); updatePendingResponseBudget(1, workItem.getSerializedSize()); - Preconditions.checkNotNull(metadata); - workItemProcessorFn.processWork( - metadata.computationId(), - metadata.inputDataWatermark(), - metadata.synchronizedProcessingTime(), - ProcessWorkItemClient.create( - WorkItem.parseFrom(data.newInput()), getDataStream.get(), commitWorkStream.get()), + workItemScheduler.scheduleWork( + workItem, + createWatermarks(workItem, Preconditions.checkNotNull(metadata)), + createProcessingContext(Preconditions.checkNotNull(metadata.computationId())), // After the work item is successfully queued or dropped by ActiveWorkState, remove it // from the pendingResponseBudget. 
queuedWorkItem -> updatePendingResponseBudget(-1, -workItem.getSerializedSize()), @@ -316,5 +324,10 @@ private void runAndReset() { workTimingInfosTracker.reset(); data = ByteString.EMPTY; } + + private Work.ProcessingContext createProcessingContext(String computationId) { + return Work.createProcessingContext( + computationId, getDataStream.get()::requestKeyedData, workCommitter.get()::commit); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index 824b40fa42b2..033990017b24 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -24,6 +24,8 @@ import java.util.List; import java.util.Random; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -46,6 +48,7 @@ public class GrpcDispatcherClient { private static final Logger LOG = LoggerFactory.getLogger(GrpcDispatcherClient.class); private final WindmillStubFactory windmillStubFactory; + private final CountDownLatch onInitializedEndpoints; /** * Current dispatcher endpoints and stubs used to communicate with Windmill Dispatcher. @@ -64,6 +67,7 @@ private GrpcDispatcherClient( this.windmillStubFactory = windmillStubFactory; this.rand = rand; this.dispatcherStubs = new AtomicReference<>(initialDispatcherStubs); + this.onInitializedEndpoints = new CountDownLatch(1); } public static GrpcDispatcherClient create(WindmillStubFactory windmillStubFactory) { @@ -86,7 +90,7 @@ static GrpcDispatcherClient forTesting( new Random()); } - CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() { + public CloudWindmillServiceV1Alpha1Stub getWindmillServiceStub() { ImmutableList windmillServiceStubs = dispatcherStubs.get().windmillServiceStubs(); Preconditions.checkState( @@ -101,11 +105,28 @@ ImmutableSet getDispatcherEndpoints() { return dispatcherStubs.get().dispatcherEndpoints(); } - CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStub() { + /** Will block the calling thread until the initial endpoints are present. */ + CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStubBlocking() { + boolean initialized = false; + long secondsWaited = 0; + while (!initialized) { + LOG.info( + "Blocking until Windmill Service endpoint has been set. " + + "Currently waited for [{}] seconds.", + secondsWaited); + try { + initialized = onInitializedEndpoints.await(10, TimeUnit.SECONDS); + secondsWaited += 10; + } catch (InterruptedException e) { + LOG.error( + "Interrupted while waiting for initial Windmill Service endpoints. " + + "These endpoints are required to do any pipeline processing.", + e); + } + } + ImmutableList windmillMetadataServiceStubs = dispatcherStubs.get().windmillMetadataServiceStubs(); - Preconditions.checkState( - !windmillMetadataServiceStubs.isEmpty(), "windmillServiceEndpoint has not been set"); return (windmillMetadataServiceStubs.size() == 1 ? 
windmillMetadataServiceStubs.get(0) @@ -121,7 +142,7 @@ private synchronized T randomlySelectNextStub(List stubs) { * #dispatcherStubs} will always have a value as empty updates will trigger an {@link * IllegalStateException}. */ - boolean hasInitializedEndpoints() { + public boolean hasInitializedEndpoints() { return dispatcherStubs.get().hasInitializedEndpoints(); } @@ -144,6 +165,7 @@ public synchronized void consumeWindmillDispatcherEndpoints( LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", dispatcherEndpoints); dispatcherStubs.set(DispatcherStubs.create(dispatcherEndpoints, windmillStubFactory)); + onInitializedEndpoints.countDown(); } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 8696c464a0ff..c652e98e5568 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -43,10 +43,11 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; @@ -139,8 +140,8 @@ public GetWorkStream createDirectGetWorkStream( GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, Supplier getDataStream, - Supplier commitWorkStream, - WorkItemProcessor workItemProcessor) { + Supplier workCommitter, + WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver), request, @@ -150,8 +151,8 @@ public GetWorkStream createDirectGetWorkStream( logEveryNStreamFailures, getWorkThrottleTimer, getDataStream, - commitWorkStream, - workItemProcessor); + workCommitter, + workItemScheduler); } public GetDataStream createGetDataStream( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java index d7573a55c161..a9ca749ff1cd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; import java.util.Collection; import java.util.List; @@ -30,22 +30,26 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; @@ -54,9 +58,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Instant; import org.slf4j.Logger; @@ -73,11 +78,9 @@ public final class 
StreamingEngineClient { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineClient.class); private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; - - private final AtomicBoolean started; private final JobHeader jobHeader; private final GrpcWindmillStreamFactory streamFactory; - private final WorkItemProcessor workItemProcessor; + private final WorkItemScheduler workItemScheduler; private final ChannelCachingStubFactory channelCachingStubFactory; private final GrpcDispatcherClient dispatcherClient; private final AtomicBoolean isBudgetRefreshPaused; @@ -89,26 +92,31 @@ public final class StreamingEngineClient { private final long clientId; private final Supplier getWorkerMetadataStream; private final Queue newWindmillEndpoints; + private final Function workCommitterFactory; + private final Consumer> heartbeatResponseProcessor; /** Writes are guarded by synchronization, reads are lock free. */ private final AtomicReference connections; + private volatile boolean started; + @SuppressWarnings("FutureReturnValueIgnored") private StreamingEngineClient( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, - AtomicReference connections, GrpcWindmillStreamFactory streamFactory, - WorkItemProcessor workItemProcessor, + WorkItemScheduler workItemScheduler, ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId) { + long clientId, + Function workCommitterFactory, + Consumer> heartbeatResponseProcessor) { this.jobHeader = jobHeader; - this.started = new AtomicBoolean(); + this.started = false; this.streamFactory = streamFactory; - this.workItemProcessor = workItemProcessor; - this.connections = connections; + this.workItemScheduler = workItemScheduler; + this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); this.channelCachingStubFactory = channelCachingStubFactory; this.dispatcherClient = dispatcherClient; this.isBudgetRefreshPaused = new AtomicBoolean(false); @@ -132,12 +140,14 @@ private StreamingEngineClient( Suppliers.memoize( () -> streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStub(), + dispatcherClient.getWindmillMetadataServiceStubBlocking(), getWorkerMetadataThrottleTimer, endpoints -> // Run this on a separate thread than the grpc stream thread. newWorkerMetadataPublisher.submit( () -> newWindmillEndpoints.add(endpoints)))); + this.workCommitterFactory = workCommitterFactory; + this.heartbeatResponseProcessor = heartbeatResponseProcessor; } private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { @@ -154,65 +164,98 @@ private static ExecutorService singleThreadedExecutorServiceOf(String threadName } /** - * Creates an instance of {@link StreamingEngineClient} and starts the {@link - * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. {@link - * GetWorkerMetadataStream} will populate {@link #connections} when a response is received. + * Creates an instance of {@link StreamingEngineClient} in a non-started state. * - * @implNote Does not block the calling thread. + * @implNote Does not block the calling thread. Callers must explicitly call {@link #start()}. 
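+   *
+   * <p>Illustrative usage sketch only; the collaborator variables below are assumed to be
+   * constructed elsewhere and are not part of this change:
+   *
+   * <pre>{@code
+   * StreamingEngineClient client =
+   *     StreamingEngineClient.create(
+   *         jobHeader,
+   *         totalGetWorkBudget,
+   *         streamFactory,
+   *         workItemScheduler,
+   *         channelCachingStubFactory,
+   *         getWorkBudgetDistributor,
+   *         dispatcherClient,
+   *         workCommitterFactory,
+   *         heartbeatResponseProcessor);
+   * // create(...) no longer starts the client; start() must be called explicitly.
+   * client.start();
+   * }</pre>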
*/ public static StreamingEngineClient create( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, GrpcWindmillStreamFactory streamingEngineStreamFactory, - WorkItemProcessor processWorkItem, + WorkItemScheduler processWorkItem, ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, - GrpcDispatcherClient dispatcherClient) { - StreamingEngineClient streamingEngineClient = - new StreamingEngineClient( - jobHeader, - totalGetWorkBudget, - new AtomicReference<>(StreamingEngineConnectionState.EMPTY), - streamingEngineStreamFactory, - processWorkItem, - channelCachingStubFactory, - getWorkBudgetDistributor, - dispatcherClient, - new Random().nextLong()); - streamingEngineClient.start(); - return streamingEngineClient; + GrpcDispatcherClient dispatcherClient, + Function workCommitterFactory, + Consumer> heartbeatProcessor) { + return new StreamingEngineClient( + jobHeader, + totalGetWorkBudget, + streamingEngineStreamFactory, + processWorkItem, + channelCachingStubFactory, + getWorkBudgetDistributor, + dispatcherClient, + /* clientId= */ new Random().nextLong(), + workCommitterFactory, + heartbeatProcessor); } @VisibleForTesting static StreamingEngineClient forTesting( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, - AtomicReference connections, GrpcWindmillStreamFactory streamFactory, - WorkItemProcessor processWorkItem, + WorkItemScheduler processWorkItem, ChannelCachingStubFactory stubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId) { + long clientId, + Function workCommitterFactory, + Consumer> heartbeatResponseProcessor) { StreamingEngineClient streamingEngineClient = new StreamingEngineClient( jobHeader, totalGetWorkBudget, - connections, streamFactory, processWorkItem, stubFactory, getWorkBudgetDistributor, dispatcherClient, - clientId); + clientId, + workCommitterFactory, + heartbeatResponseProcessor); streamingEngineClient.start(); return streamingEngineClient; } - private void start() { - startGetWorkerMetadataStream(); + @SuppressWarnings("ReturnValueIgnored") + public synchronized void start() { + Preconditions.checkState(!started, "StreamingEngineClient cannot start twice."); + // Starts the stream, this value is memoized. + getWorkerMetadataStream.get(); startWorkerMetadataConsumer(); getWorkBudgetRefresher.start(); + started = true; + } + + public ImmutableSet currentWindmillEndpoints() { + return connections.get().windmillConnections().keySet().stream() + .map(Endpoint::directEndpoint) + .filter(Optional::isPresent) + .map(Optional::get) + .filter( + windmillServiceAddress -> + windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6) + .map( + windmillServiceAddress -> + windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS + ? windmillServiceAddress.gcpServiceAddress() + : windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress()) + .collect(toImmutableSet()); + } + + /** + * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link + * GetDataStream} pointing to dispatcher. 
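+   *
+   * <p>For example (sketch; the global data tag below is hypothetical):
+   *
+   * <pre>{@code
+   * // Falls back to a dispatcher-backed stream when no per-worker stream is mapped to the key.
+   * GetDataStream globalDataStream = streamingEngineClient.getGlobalDataStream("myGlobalDataTag");
+   * }</pre>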
+ */ + public GetDataStream getGlobalDataStream(String globalDataKey) { + return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) + .map(Supplier::get) + .orElseGet( + () -> + streamFactory.createGetDataStream( + dispatcherClient.getWindmillServiceStub(), new ThrottleTimer())); } @SuppressWarnings("FutureReturnValueIgnored") @@ -227,10 +270,8 @@ private void startWorkerMetadataConsumer() { } @VisibleForTesting - void finish() { - if (!started.compareAndSet(true, false)) { - return; - } + public synchronized void finish() { + Preconditions.checkState(started, "StreamingEngineClient never started."); getWorkerMetadataStream.get().close(); getWorkBudgetRefresher.stop(); newWorkerMetadataPublisher.shutdownNow(); @@ -266,27 +307,23 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi getWorkBudgetRefresher.requestBudgetRefresh(); } - public ImmutableList getAndResetThrottleTimes() { - StreamingEngineConnectionState currentConnections = connections.get(); - - ImmutableList keyedWorkStreamThrottleTimes = - currentConnections.windmillStreams().values().stream() + /** Add up all the throttle times of all streams including GetWorkerMetadataStream. */ + public long getAndResetThrottleTimes() { + return connections.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) - .collect(toImmutableList()); + .reduce(0L, Long::sum) + + getWorkerMetadataThrottleTimer.getAndResetThrottleTime(); + } - return ImmutableList.builder() - .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime()) - .addAll(keyedWorkStreamThrottleTimes) - .build(); + public long currentActiveCommitBytes() { + return connections.get().windmillStreams().values().stream() + .map(WindmillStreamSender::getCurrentActiveCommitBytes) + .reduce(0L, Long::sum); } - /** Starts {@link GetWorkerMetadataStream}. */ - @SuppressWarnings({ - "ReturnValueIgnored", // starts the stream, this value is memoized. 
- }) - private void startGetWorkerMetadataStream() { - started.set(true); - getWorkerMetadataStream.get(); + @VisibleForTesting + StreamingEngineConnectionState getCurrentConnections() { + return connections.get(); } private synchronized ImmutableMap createNewWindmillConnections( @@ -371,7 +408,9 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor( .build(), GetWorkBudget.noBudget(), streamFactory, - workItemProcessor); + workItemScheduler, + workCommitterFactory, + heartbeatResponseProcessor); windmillStreamSender.startStreams(); return windmillStreamSender; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java index bef710329ffa..ff9ddc00c3f0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java @@ -17,17 +17,22 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; @@ -59,6 +64,7 @@ public class WindmillStreamSender { private final Supplier getWorkStream; private final Supplier getDataStream; private final Supplier commitWorkStream; + private final Supplier workCommitter; private final StreamingEngineThrottleTimers streamingEngineThrottleTimers; private WindmillStreamSender( @@ -66,7 +72,9 @@ private WindmillStreamSender( GetWorkRequest getWorkRequest, AtomicReference getWorkBudget, GrpcWindmillStreamFactory streamingEngineStreamFactory, - WorkItemProcessor workItemProcessor) { + WorkItemScheduler workItemScheduler, + Function workCommitterFactory, + Consumer> heartbeatResponseProcessor) { this.started = new AtomicBoolean(false); this.getWorkBudget = getWorkBudget; this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create(); @@ -79,12 +87,17 @@ private 
WindmillStreamSender( Suppliers.memoize( () -> streamingEngineStreamFactory.createGetDataStream( - stub, streamingEngineThrottleTimers.getDataThrottleTimer())); + stub, + streamingEngineThrottleTimers.getDataThrottleTimer(), + false, + heartbeatResponseProcessor)); this.commitWorkStream = Suppliers.memoize( () -> streamingEngineStreamFactory.createCommitWorkStream( stub, streamingEngineThrottleTimers.commitWorkThrottleTimer())); + this.workCommitter = + Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get())); this.getWorkStream = Suppliers.memoize( () -> @@ -93,8 +106,8 @@ private WindmillStreamSender( withRequestBudget(getWorkRequest, getWorkBudget.get()), streamingEngineThrottleTimers.getWorkThrottleTimer(), getDataStream, - commitWorkStream, - workItemProcessor)); + workCommitter, + workItemScheduler)); } public static WindmillStreamSender create( @@ -102,13 +115,17 @@ public static WindmillStreamSender create( GetWorkRequest getWorkRequest, GetWorkBudget getWorkBudget, GrpcWindmillStreamFactory streamingEngineStreamFactory, - WorkItemProcessor workItemProcessor) { + WorkItemScheduler workItemScheduler, + Function workCommitterFactory, + Consumer> heartbeatResponseProcessor) { return new WindmillStreamSender( stub, getWorkRequest, new AtomicReference<>(getWorkBudget), streamingEngineStreamFactory, - workItemProcessor); + workItemScheduler, + workCommitterFactory, + heartbeatResponseProcessor); } private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkBudget budget) { @@ -120,6 +137,7 @@ void startStreams() { getWorkStream.get(); getDataStream.get(); commitWorkStream.get(); + workCommitter.get().start(); // *stream.get() is all memoized in a threadsafe manner. started.set(true); } @@ -131,6 +149,7 @@ void closeAllStreams() { if (started.get()) { getWorkStream.get().close(); getDataStream.get().close(); + workCommitter.get().stop(); commitWorkStream.get().close(); } } @@ -153,4 +172,8 @@ public GetWorkBudget remainingGetWorkBudget() { public long getAndResetThrottleTime() { return streamingEngineThrottleTimers.getAndResetThrottleTime(); } + + public long getCurrentActiveCommitBytes() { + return started.get() ? workCommitter.get().currentActiveCommitBytes() : 0; + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java index 36d8254f8442..d11a71807e8e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java @@ -39,6 +39,7 @@ import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; @@ -81,37 +82,37 @@ public class WindmillStateReader { * Ideal maximum bytes in a TagBag response. 
However, Windmill will always return at least one * value if possible irrespective of this limit. */ - public static final long INITIAL_MAX_BAG_BYTES = 8L << 20; // 8MB + @VisibleForTesting static final long INITIAL_MAX_BAG_BYTES = 8L << 20; // 8MB - public static final long CONTINUATION_MAX_BAG_BYTES = 32L << 20; // 32MB + @VisibleForTesting static final long CONTINUATION_MAX_BAG_BYTES = 32L << 20; // 32MB /** * Ideal maximum bytes in a TagMultimapFetchResponse response. However, Windmill will always * return at least one value if possible irrespective of this limit. */ - public static final long INITIAL_MAX_MULTIMAP_BYTES = 8L << 20; // 8MB + @VisibleForTesting static final long INITIAL_MAX_MULTIMAP_BYTES = 8L << 20; // 8MB - public static final long CONTINUATION_MAX_MULTIMAP_BYTES = 32L << 20; // 32MB + @VisibleForTesting static final long CONTINUATION_MAX_MULTIMAP_BYTES = 32L << 20; // 32MB /** * Ideal maximum bytes in a TagSortedList response. However, Windmill will always return at least * one value if possible irrespective of this limit. */ - public static final long MAX_ORDERED_LIST_BYTES = 8L << 20; // 8MB + @VisibleForTesting static final long MAX_ORDERED_LIST_BYTES = 8L << 20; // 8MB /** * Ideal maximum bytes in a tag-value prefix response. However, Windmill will always return at * least one value if possible irrespective of this limit. */ - public static final long MAX_TAG_VALUE_PREFIX_BYTES = 8L << 20; // 8MB + @VisibleForTesting static final long MAX_TAG_VALUE_PREFIX_BYTES = 8L << 20; // 8MB /** * Ideal maximum bytes in a KeyedGetDataResponse. However, Windmill will always return at least * one value if possible irrespective of this limit. */ - public static final long MAX_KEY_BYTES = 16L << 20; // 16MB + @VisibleForTesting static final long MAX_KEY_BYTES = 16L << 20; // 16MB - public static final long MAX_CONTINUATION_KEY_BYTES = 72L << 20; // 72MB + @VisibleForTesting static final long MAX_CONTINUATION_KEY_BYTES = 72L << 20; // 72MB @VisibleForTesting final ConcurrentLinkedQueue> pendingLookups; private final ByteString key; private final long shardingKey; @@ -125,7 +126,7 @@ public class WindmillStateReader { private long bytesRead = 0L; private final Supplier workItemIsFailed; - public WindmillStateReader( + private WindmillStateReader( Function> fetchStateFromWindmillFn, ByteString key, long shardingKey, @@ -152,6 +153,19 @@ static WindmillStateReader forTesting( fetchStateFromWindmillFn, key, shardingKey, workToken, () -> null, () -> Boolean.FALSE); } + public static WindmillStateReader forWork(Work work) { + return new WindmillStateReader( + work::fetchKeyedState, + work.getWorkItem().getKey(), + work.getWorkItem().getShardingKey(), + work.getWorkItem().getWorkToken(), + () -> { + work.setState(Work.State.READING); + return () -> work.setState(Work.State.PROCESSING); + }, + work::isFailed); + } + private Future stateFuture(StateTag stateTag, @Nullable Coder coder) { CoderAndFuture coderAndFuture = new CoderAndFuture<>(coder, SettableFuture.create()); CoderAndFuture existingCoderAndFutureWildcard = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java deleted file mode 100644 index 1adfe02f45fc..000000000000 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/ProcessWorkItemClient.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker.windmill.work; - -import com.google.auto.value.AutoValue; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; -import org.apache.beam.sdk.annotations.Internal; - -/** - * A client context to process {@link WorkItem} and route all subsequent Windmill WorkItem API calls - * to the same backend worker. Wraps the {@link WorkItem}. - */ -@AutoValue -@Internal -public abstract class ProcessWorkItemClient { - public static ProcessWorkItemClient create( - WorkItem workItem, GetDataStream getDataStream, CommitWorkStream commitWorkStream) { - return new AutoValue_ProcessWorkItemClient(workItem, getDataStream, commitWorkStream); - } - - /** {@link WorkItem} being processed. */ - public abstract WorkItem workItem(); - - /** - * {@link GetDataStream} that connects to the backend Windmill worker handling the {@link - * WorkItem}. - */ - public abstract GetDataStream getDataStream(); - - /** - * {@link CommitWorkStream} that connects to backend Windmill worker handling the {@link - * WorkItem}. 
- */ - public abstract CommitWorkStream commitWorkStream(); -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemScheduler.java similarity index 61% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemScheduler.java index 4ebc77775fcd..17c9f7d80d5d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemProcessor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/WorkItemScheduler.java @@ -20,38 +20,31 @@ import java.util.Collection; import java.util.function.Consumer; import javax.annotation.CheckReturnValue; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.sdk.annotations.Internal; -import org.checkerframework.checker.nullness.qual.Nullable; -import org.joda.time.Instant; @FunctionalInterface @CheckReturnValue @Internal -public interface WorkItemProcessor { +public interface WorkItemScheduler { /** - * Receives and processes {@link WorkItem}(s) wrapped in its {@link ProcessWorkItemClient} - * processing context. + * Schedule {@link WorkItem}(s). * - * @param computation the Computation that the Work belongs to. - * @param inputDataWatermark Watermark of when the input data was received by the computation. - * @param synchronizedProcessingTime Aggregate system watermark that also depends on each - * computation's received dependent system watermark value to propagate the system watermark - * downstream. - * @param wrappedWorkItem A workItem and it's processing context, used to route subsequent - * WorkItem API (GetData, CommitWork) RPC calls to the same backend worker, where the WorkItem - * was returned from GetWork. + * @param workItem {@link WorkItem} to be processed. + * @param watermarks processing watermarks for the workItem. + * @param processingContext for processing the workItem. * @param ackWorkItemQueued Called after an attempt to queue the work item for processing. Used to * free up pending budget. * @param getWorkStreamLatencies Latencies per processing stage for the WorkItem for reporting * back to Streaming Engine backend. 
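+   *
+   * <p>Sketch of a typical call site from a GetWork stream (mirrors the usage in
+   * GrpcDirectGetWorkStream above; the surrounding variables are assumed to be in scope):
+   *
+   * <pre>{@code
+   * workItemScheduler.scheduleWork(
+   *     workItem,
+   *     watermarks,
+   *     processingContext,
+   *     queuedWorkItem -> updatePendingResponseBudget(-1, -workItem.getSerializedSize()),
+   *     getWorkStreamLatencies);
+   * }</pre>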
*/ - void processWork( - String computation, - @Nullable Instant inputDataWatermark, - @Nullable Instant synchronizedProcessingTime, - ProcessWorkItemClient wrappedWorkItem, + void scheduleWork( + WorkItem workItem, + Watermarks watermarks, + Work.ProcessingContext processingContext, Consumer ackWorkItemQueued, Collection getWorkStreamLatencies); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java new file mode 100644 index 000000000000..f9b1b45be6c1 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/ComputationWorkExecutorFactory.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.processing; + +import static org.apache.beam.runners.dataflow.DataflowRunner.hasExperiment; + +import com.google.api.services.dataflow.model.MapTask; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.internal.CustomSources; +import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; +import org.apache.beam.runners.dataflow.util.CloudObject; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext; +import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; +import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.IntrinsicMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.ReaderCache; +import org.apache.beam.runners.dataflow.worker.ReaderRegistry; +import org.apache.beam.runners.dataflow.worker.SinkRegistry; +import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.apache.beam.runners.dataflow.worker.WindmillKeyedWorkItem; +import org.apache.beam.runners.dataflow.worker.counters.CounterSet; +import org.apache.beam.runners.dataflow.worker.counters.NameContext; +import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge; +import org.apache.beam.runners.dataflow.worker.graph.MapTaskToNetworkFunction; +import org.apache.beam.runners.dataflow.worker.graph.Networks; +import org.apache.beam.runners.dataflow.worker.graph.Nodes; +import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node; +import 
org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationWorkExecutor; +import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; +import org.apache.beam.runners.dataflow.worker.util.common.worker.MapTaskExecutor; +import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputObjectAndByteCounter; +import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.fn.IdGenerator; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.graph.MutableNetwork; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Factory class for generating {@link ComputationWorkExecutor} instances. */ +final class ComputationWorkExecutorFactory { + + private static final Logger LOG = LoggerFactory.getLogger(ComputationWorkExecutorFactory.class); + private static final String DISABLE_SINK_BYTE_LIMIT_EXPERIMENT = + "disable_limiting_bundle_sink_bytes"; + + private final DataflowWorkerHarnessOptions options; + private final DataflowMapTaskExecutorFactory mapTaskExecutorFactory; + private final ReaderCache readerCache; + private final Function stateCacheFactory; + private final ReaderRegistry readerRegistry; + private final SinkRegistry sinkRegistry; + private final DataflowExecutionStateSampler sampler; + private final CounterSet pendingDeltaCounters; + + /** + * Function which converts map tasks to their network representation for execution. + * + *

+   * <ul>
+   *   <li>Translate the map task to a network representation.
+   *   <li>Remove flatten instructions by rewiring edges.
+   * </ul>
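+   *
+   * <p>For example (sketch, as used by {@link #createComputationWorkExecutor} below):
+   *
+   * <pre>{@code
+   * MutableNetwork<Node, Edge> mapTaskNetwork =
+   *     mapTaskToNetwork.apply(computationState.getMapTask());
+   * }</pre>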
+ */ + private final Function> mapTaskToNetwork; + + private final long maxSinkBytes; + private final IdGenerator idGenerator; + + ComputationWorkExecutorFactory( + DataflowWorkerHarnessOptions options, + DataflowMapTaskExecutorFactory mapTaskExecutorFactory, + ReaderCache readerCache, + Function stateCacheFactory, + DataflowExecutionStateSampler sampler, + CounterSet pendingDeltaCounters, + IdGenerator idGenerator) { + this.options = options; + this.mapTaskExecutorFactory = mapTaskExecutorFactory; + this.readerCache = readerCache; + this.stateCacheFactory = stateCacheFactory; + this.idGenerator = idGenerator; + this.readerRegistry = ReaderRegistry.defaultRegistry(); + this.sinkRegistry = SinkRegistry.defaultRegistry(); + this.sampler = sampler; + this.pendingDeltaCounters = pendingDeltaCounters; + this.mapTaskToNetwork = new MapTaskToNetworkFunction(idGenerator); + this.maxSinkBytes = + hasExperiment(options, DISABLE_SINK_BYTE_LIMIT_EXPERIMENT) + ? Long.MAX_VALUE + : StreamingDataflowWorker.MAX_SINK_BYTES; + } + + private static Nodes.ParallelInstructionNode extractReadNode( + MutableNetwork mapTaskNetwork) { + return (Nodes.ParallelInstructionNode) + Iterables.find( + mapTaskNetwork.nodes(), + node -> + node instanceof Nodes.ParallelInstructionNode + && ((Nodes.ParallelInstructionNode) node).getParallelInstruction().getRead() + != null); + } + + private static boolean isCustomSource(Nodes.ParallelInstructionNode readNode) { + return CustomSources.class + .getName() + .equals(readNode.getParallelInstruction().getRead().getSource().getSpec().get("@type")); + } + + private static void trackAutoscalingBytesRead( + MapTask mapTask, + Nodes.ParallelInstructionNode readNode, + Coder readCoder, + ReadOperation readOperation, + MapTaskExecutor mapTaskExecutor, + String counterName) { + NameContext nameContext = + NameContext.create( + mapTask.getStageName(), + readNode.getParallelInstruction().getOriginalName(), + readNode.getParallelInstruction().getSystemName(), + readNode.getParallelInstruction().getName()); + readOperation.receivers[0].addOutputCounter( + counterName, + new OutputObjectAndByteCounter( + new IntrinsicMapTaskExecutorFactory.ElementByteSizeObservableCoder<>(readCoder), + mapTaskExecutor.getOutputCounters(), + nameContext) + .setSamplingPeriod(100) + .countBytes(counterName)); + } + + private static ReadOperation getValidatedReadOperation(MapTaskExecutor mapTaskExecutor) { + ReadOperation readOperation = mapTaskExecutor.getReadOperation(); + // Disable progress updates since its results are unused for streaming + // and involves starting a thread. 
+ readOperation.setProgressUpdatePeriodMs(ReadOperation.DONT_UPDATE_PERIODICALLY); + Preconditions.checkState( + mapTaskExecutor.supportsRestart(), + "Streaming runner requires all operations support restart."); + return readOperation; + } + + ComputationWorkExecutor createComputationWorkExecutor( + StageInfo stageInfo, ComputationState computationState, String workLatencyTrackingId) { + MapTask mapTask = computationState.getMapTask(); + MutableNetwork mapTaskNetwork = mapTaskToNetwork.apply(mapTask); + if (LOG.isDebugEnabled()) { + LOG.debug("Network as Graphviz .dot: {}", Networks.toDot(mapTaskNetwork)); + } + + Nodes.ParallelInstructionNode readNode = extractReadNode(mapTaskNetwork); + Nodes.InstructionOutputNode readOutputNode = + (Nodes.InstructionOutputNode) Iterables.getOnlyElement(mapTaskNetwork.successors(readNode)); + + DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = + createExecutionStateTracker(stageInfo, mapTask, workLatencyTrackingId); + StreamingModeExecutionContext context = + createExecutionContext(computationState, stageInfo, executionStateTracker); + DataflowMapTaskExecutor mapTaskExecutor = + createMapTaskExecutor(context, mapTask, mapTaskNetwork); + ReadOperation readOperation = getValidatedReadOperation(mapTaskExecutor); + + Coder readCoder = + CloudObjects.coderFromCloudObject( + CloudObject.fromSpec(readOutputNode.getInstructionOutput().getCodec())); + Coder keyCoder = extractKeyCoder(readCoder); + + // If using a custom source, count bytes read for autoscaling. + if (isCustomSource(readNode)) { + trackAutoscalingBytesRead( + mapTask, + readNode, + readCoder, + readOperation, + mapTaskExecutor, + computationState.sourceBytesProcessCounterName()); + } + + ComputationWorkExecutor.Builder executionStateBuilder = + ComputationWorkExecutor.builder() + .setWorkExecutor(mapTaskExecutor) + .setContext(context) + .setExecutionStateTracker(executionStateTracker); + + if (keyCoder != null) { + executionStateBuilder.setKeyCoder(keyCoder); + } + + return executionStateBuilder.build(); + } + + /** + * Extracts the userland key coder, if any, from the coder used in the initial read step of a + * stage. This encodes many assumptions about how the streaming execution context works. 
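+   *
+   * <p>A concrete example of the unwrapping this performs (sketch; the coders are chosen for
+   * illustration only):
+   *
+   * <pre>{@code
+   * Coder<?> readCoder =
+   *     WindowedValue.getFullCoder(
+   *         KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of()), GlobalWindow.Coder.INSTANCE);
+   * // extractKeyCoder(readCoder) unwraps the WindowedValueCoder and KvCoder and returns
+   * // StringUtf8Coder.of(); non-KV, non-KeyedWorkItem value coders yield null.
+   * }</pre>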
+ */ + private @Nullable Coder extractKeyCoder(Coder readCoder) { + if (!(readCoder instanceof WindowedValue.WindowedValueCoder)) { + throw new RuntimeException( + String.format( + "Expected coder for streaming read to be %s, but received %s", + WindowedValue.WindowedValueCoder.class.getSimpleName(), readCoder)); + } + + // Note that TimerOrElementCoder is a backwards-compatibility class + // that is really a FakeKeyedWorkItemCoder + Coder valueCoder = ((WindowedValue.WindowedValueCoder) readCoder).getValueCoder(); + + if (valueCoder instanceof KvCoder) { + return ((KvCoder) valueCoder).getKeyCoder(); + } + if (!(valueCoder instanceof WindmillKeyedWorkItem.FakeKeyedWorkItemCoder)) { + return null; + } + + return ((WindmillKeyedWorkItem.FakeKeyedWorkItemCoder) valueCoder).getKeyCoder(); + } + + private StreamingModeExecutionContext createExecutionContext( + ComputationState computationState, + StageInfo stageInfo, + DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker) { + String computationId = computationState.getComputationId(); + return new StreamingModeExecutionContext( + pendingDeltaCounters, + computationId, + readerCache, + computationState.getTransformUserNameToStateFamily(), + stateCacheFactory.apply(computationId), + stageInfo.metricsContainerRegistry(), + executionStateTracker, + stageInfo.executionStateRegistry(), + maxSinkBytes); + } + + private DataflowMapTaskExecutor createMapTaskExecutor( + StreamingModeExecutionContext context, + MapTask mapTask, + MutableNetwork mapTaskNetwork) { + return mapTaskExecutorFactory.create( + mapTaskNetwork, + options, + mapTask.getStageName(), + readerRegistry, + sinkRegistry, + context, + pendingDeltaCounters, + idGenerator); + } + + private DataflowExecutionContext.DataflowExecutionStateTracker createExecutionStateTracker( + StageInfo stageInfo, MapTask mapTask, String workLatencyTrackingId) { + return new DataflowExecutionContext.DataflowExecutionStateTracker( + sampler, + stageInfo + .executionStateRegistry() + .getState( + NameContext.forStage(mapTask.getStageName()), + "other", + null, + ScopedProfiler.INSTANCE.emptyScope()), + stageInfo.deltaCounters(), + options, + workLatencyTrackingId); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizer.java new file mode 100644 index 000000000000..d663b4fca27a --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizer.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.processing; + +import java.time.Duration; +import java.util.Map; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@ThreadSafe +@Internal +final class StreamingCommitFinalizer { + private static final Logger LOG = LoggerFactory.getLogger(StreamingCommitFinalizer.class); + private static final Duration DEFAULT_CACHE_ENTRY_EXPIRY = Duration.ofMinutes(5L); + private final Cache commitFinalizerCache; + private final BoundedQueueExecutor finalizationExecutor; + + private StreamingCommitFinalizer( + Cache commitFinalizerCache, BoundedQueueExecutor finalizationExecutor) { + this.commitFinalizerCache = commitFinalizerCache; + this.finalizationExecutor = finalizationExecutor; + } + + static StreamingCommitFinalizer create(BoundedQueueExecutor workExecutor) { + return new StreamingCommitFinalizer( + CacheBuilder.newBuilder().expireAfterWrite(DEFAULT_CACHE_ENTRY_EXPIRY).build(), + workExecutor); + } + + /** + * Stores a map of user worker generated finalization ids and callbacks to execute once a commit + * has been successfully committed to the backing state store. + */ + void cacheCommitFinalizers(Map commitCallbacks) { + commitFinalizerCache.putAll(commitCallbacks); + } + + /** + * When this method is called, the commits associated with the provided finalizeIds have been + * successfully persisted in the backing state store. If the commitCallback for the finalizationId + * is still cached it is invoked. + */ + void finalizeCommits(Iterable finalizeIds) { + for (long finalizeId : finalizeIds) { + @Nullable Runnable finalizeCommit = commitFinalizerCache.getIfPresent(finalizeId); + // NOTE: It is possible the same callback id may be removed twice if + // windmill restarts. + // TODO: It is also possible for an earlier finalized id to be lost. + // We should automatically discard all older callbacks for the same computation and key. + if (finalizeCommit != null) { + commitFinalizerCache.invalidate(finalizeId); + finalizationExecutor.forceExecute( + () -> { + try { + finalizeCommit.run(); + } catch (Throwable t) { + LOG.error("Source checkpoint finalization failed:", t); + } + }, + 0); + } + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java new file mode 100644 index 000000000000..334ab8efeae2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.work.processing; + +import com.google.api.services.dataflow.model.MapTask; +import com.google.auto.value.AutoValue; +import java.util.Collection; +import java.util.Optional; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; +import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; +import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; +import org.apache.beam.runners.dataflow.worker.HotKeyLogger; +import org.apache.beam.runners.dataflow.worker.ReaderCache; +import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; +import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationWorkExecutor; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; +import org.apache.beam.runners.dataflow.worker.streaming.KeyCommitTooLargeException; +import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.WorkFailureProcessor; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.fn.IdGenerator; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Schedules execution of user code to process a {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} then commits the work item + * back to streaming execution backend. 
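+ *
+ * <p>Illustrative call (the surrounding objects are assumed to be wired up by the worker harness):
+ *
+ * <pre>{@code
+ * streamingWorkScheduler.scheduleWork(
+ *     computationState, workItem, watermarks, processingContext, getWorkStreamLatencies);
+ * // Work is activated on the computation's per-key queue; at most one work item per key is
+ * // executed at a time, and results are queued as a Commit through the processing context.
+ * }</pre>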
+ */ +@Internal +@ThreadSafe +public final class StreamingWorkScheduler { + private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); + + private final DataflowWorkerHarnessOptions options; + private final Supplier clock; + private final ComputationWorkExecutorFactory computationWorkExecutorFactory; + private final SideInputStateFetcher sideInputStateFetcher; + private final FailureTracker failureTracker; + private final WorkFailureProcessor workFailureProcessor; + private final StreamingCommitFinalizer commitFinalizer; + private final StreamingCounters streamingCounters; + private final HotKeyLogger hotKeyLogger; + private final ConcurrentMap stageInfoMap; + private final DataflowExecutionStateSampler sampler; + private final AtomicInteger maxWorkItemCommitBytes; + + public StreamingWorkScheduler( + DataflowWorkerHarnessOptions options, + Supplier clock, + ComputationWorkExecutorFactory computationWorkExecutorFactory, + SideInputStateFetcher sideInputStateFetcher, + FailureTracker failureTracker, + WorkFailureProcessor workFailureProcessor, + StreamingCommitFinalizer commitFinalizer, + StreamingCounters streamingCounters, + HotKeyLogger hotKeyLogger, + ConcurrentMap stageInfoMap, + DataflowExecutionStateSampler sampler, + AtomicInteger maxWorkItemCommitBytes) { + this.options = options; + this.clock = clock; + this.computationWorkExecutorFactory = computationWorkExecutorFactory; + this.sideInputStateFetcher = sideInputStateFetcher; + this.failureTracker = failureTracker; + this.workFailureProcessor = workFailureProcessor; + this.commitFinalizer = commitFinalizer; + this.streamingCounters = streamingCounters; + this.hotKeyLogger = hotKeyLogger; + this.stageInfoMap = stageInfoMap; + this.sampler = sampler; + this.maxWorkItemCommitBytes = maxWorkItemCommitBytes; + } + + public static StreamingWorkScheduler create( + DataflowWorkerHarnessOptions options, + Supplier clock, + ReaderCache readerCache, + DataflowMapTaskExecutorFactory mapTaskExecutorFactory, + BoundedQueueExecutor workExecutor, + Function stateCacheFactory, + Function fetchGlobalDataFn, + FailureTracker failureTracker, + WorkFailureProcessor workFailureProcessor, + StreamingCounters streamingCounters, + HotKeyLogger hotKeyLogger, + DataflowExecutionStateSampler sampler, + AtomicInteger maxWorkItemCommitBytes, + IdGenerator idGenerator, + ConcurrentMap stageInfoMap) { + ComputationWorkExecutorFactory computationWorkExecutorFactory = + new ComputationWorkExecutorFactory( + options, + mapTaskExecutorFactory, + readerCache, + stateCacheFactory, + sampler, + streamingCounters.pendingDeltaCounters(), + idGenerator); + + return new StreamingWorkScheduler( + options, + clock, + computationWorkExecutorFactory, + new SideInputStateFetcher(fetchGlobalDataFn, options), + failureTracker, + workFailureProcessor, + StreamingCommitFinalizer.create(workExecutor), + streamingCounters, + hotKeyLogger, + stageInfoMap, + sampler, + maxWorkItemCommitBytes); + } + + private static long computeShuffleBytesRead(Windmill.WorkItem workItem) { + return workItem.getMessageBundlesList().stream() + .flatMap(bundle -> bundle.getMessagesList().stream()) + .map(Windmill.Message::getSerializedSize) + .map(size -> (long) size) + .reduce(0L, Long::sum); + } + + private static Windmill.WorkItemCommitRequest.Builder initializeOutputBuilder( + ByteString key, Windmill.WorkItem workItem) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(key) + .setShardingKey(workItem.getShardingKey()) + 
.setWorkToken(workItem.getWorkToken()) + .setCacheToken(workItem.getCacheToken()); + } + + private static Windmill.WorkItemCommitRequest buildWorkItemTruncationRequest( + ByteString key, Windmill.WorkItem workItem, int estimatedCommitSize) { + Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + outputBuilder.setExceedsMaxWorkItemCommitBytes(true); + outputBuilder.setEstimatedWorkItemCommitBytes(estimatedCommitSize); + return outputBuilder.build(); + } + + /** Sets the stage name and workId of the Thread executing the {@link Work} for logging. */ + private static void setUpWorkLoggingContext(String workLatencyTrackingId, String computationId) { + DataflowWorkerLoggingMDC.setWorkId(workLatencyTrackingId); + DataflowWorkerLoggingMDC.setStageName(computationId); + } + + private static String getShuffleTaskStepName(MapTask mapTask) { + // The MapTask instruction is ordered by dependencies, such that the first element is + // always going to be the shuffle task. + return mapTask.getInstructions().get(0).getName(); + } + + /** Resets logging context of the Thread executing the {@link Work} for logging. */ + private void resetWorkLoggingContext(String workLatencyTrackingId) { + sampler.resetForWorkId(workLatencyTrackingId); + DataflowWorkerLoggingMDC.setWorkId(null); + DataflowWorkerLoggingMDC.setStageName(null); + } + + /** + * Schedule work for execution. Work may be executed immediately, or queued and executed in the + * future. Only one work may be "active" (currently executing) per key at a time. + */ + public void scheduleWork( + ComputationState computationState, + Windmill.WorkItem workItem, + Watermarks watermarks, + Work.ProcessingContext processingContext, + Collection getWorkStreamLatencies) { + computationState.activateWork( + ExecutableWork.create( + Work.create(workItem, watermarks, processingContext, clock, getWorkStreamLatencies), + work -> processWork(computationState, work))); + } + + /** + * Executes the user DoFns processing {@link Work} then queues the {@link Commit}(s) to be sent to + * backing persistent store to mark that the {@link Work} has finished processing. May retry + * internally if processing fails due to uncaught {@link Exception}(s). + * + * @implNote This will block the calling thread during execution of user DoFns. + */ + private void processWork(ComputationState computationState, Work work) { + Windmill.WorkItem workItem = work.getWorkItem(); + String computationId = computationState.getComputationId(); + ByteString key = workItem.getKey(); + work.setState(Work.State.PROCESSING); + setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); + LOG.debug("Starting processing for {}:\n{}", computationId, work); + + // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires + // cleanup should be done before this, since we might exit early here. 
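+    // If the work item is marked onlyFinalize, it exists solely to finalize previously committed
+    // source state; acknowledge it with a minimal commit below and return without executing any
+    // user code.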
+ commitFinalizer.finalizeCommits(workItem.getSourceState().getFinalizeIdsList()); + if (workItem.getSourceState().getOnlyFinalize()) { + Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true)); + work.setState(Work.State.COMMIT_QUEUED); + work.queueCommit(outputBuilder.build(), computationState); + return; + } + + long processingStartTimeNanos = System.nanoTime(); + MapTask mapTask = computationState.getMapTask(); + StageInfo stageInfo = + stageInfoMap.computeIfAbsent( + mapTask.getStageName(), s -> StageInfo.create(s, mapTask.getSystemName())); + + try { + if (work.isFailed()) { + throw new WorkItemCancelledException(workItem.getShardingKey()); + } + + // Execute the user code for the Work. + ExecuteWorkResult executeWorkResult = executeWork(work, stageInfo, computationState); + Windmill.WorkItemCommitRequest.Builder commitRequest = executeWorkResult.commitWorkRequest(); + + // Validate the commit request, possibly requesting truncation if the commitSize is too large. + Windmill.WorkItemCommitRequest validatedCommitRequest = + validateCommitRequestSize(commitRequest.build(), computationId, workItem); + + // Queue the commit. + work.queueCommit(validatedCommitRequest, computationState); + recordProcessingStats(commitRequest, workItem, executeWorkResult); + LOG.debug("Processing done for work token: {}", workItem.getWorkToken()); + } catch (Throwable t) { + workFailureProcessor.logAndProcessFailure( + computationId, + ExecutableWork.create(work, retry -> processWork(computationState, retry)), + t, + invalidWork -> + computationState.completeWorkAndScheduleNextWorkForKey( + invalidWork.getShardedKey(), invalidWork.id())); + } finally { + // Update total processing time counters. Updating in finally clause ensures that + // work items causing exceptions are also accounted in time spent. + long processingTimeMsecs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos); + stageInfo.totalProcessingMsecs().addValue(processingTimeMsecs); + + // Attribute all the processing to timers if the work item contains any timers. + // Tests show that work items rarely contain both timers and message bundles. It should + // be a fairly close approximation. + // Another option: Derive time split between messages and timers based on recent totals. + // either here or in DFE. + if (work.getWorkItem().hasTimers()) { + stageInfo.timerProcessingMsecs().addValue(processingTimeMsecs); + } + + resetWorkLoggingContext(work.getLatencyTrackingId()); + } + } + + private Windmill.WorkItemCommitRequest validateCommitRequestSize( + Windmill.WorkItemCommitRequest commitRequest, + String computationId, + Windmill.WorkItem workItem) { + int byteLimit = maxWorkItemCommitBytes.get(); + int commitSize = commitRequest.getSerializedSize(); + int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize; + + // Detect overflow of integer serialized size or if the byte limit was exceeded. + // Commit is too large if overflow has occurred or the commitSize has exceeded the allowed + // commit byte limit. 
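+    // A negative serialized size indicates the int size computation overflowed; it is mapped to
+    // Integer.MAX_VALUE above and will fail the limit check below, forcing a truncation request.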
+ streamingCounters.windmillMaxObservedWorkItemCommitBytes().addValue(estimatedCommitSize); + if (commitSize >= 0 && commitSize < byteLimit) { + return commitRequest; + } + + KeyCommitTooLargeException e = + KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest); + failureTracker.trackFailure(computationId, workItem, e); + LOG.error(e.toString()); + + // Drop the current request in favor of a new, minimal one requesting truncation. + // Messages, timers, counters, and other commit content will not be used by the service + // so, we're purposefully dropping them here + return buildWorkItemTruncationRequest(workItem.getKey(), workItem, estimatedCommitSize); + } + + private void recordProcessingStats( + Windmill.WorkItemCommitRequest.Builder outputBuilder, + Windmill.WorkItem workItem, + ExecuteWorkResult executeWorkResult) { + // Compute shuffle and state byte statistics these will be flushed asynchronously. + long stateBytesWritten = + outputBuilder + .clearOutputMessages() + .clearPerWorkItemLatencyAttributions() + .build() + .getSerializedSize(); + + streamingCounters.windmillShuffleBytesRead().addValue(computeShuffleBytesRead(workItem)); + streamingCounters.windmillStateBytesRead().addValue(executeWorkResult.stateBytesRead()); + streamingCounters.windmillStateBytesWritten().addValue(stateBytesWritten); + } + + private ExecuteWorkResult executeWork( + Work work, StageInfo stageInfo, ComputationState computationState) throws Exception { + Windmill.WorkItem workItem = work.getWorkItem(); + ByteString key = workItem.getKey(); + Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem); + ComputationWorkExecutor computationWorkExecutor = + computationState + .acquireComputationWorkExecutor() + .orElseGet( + () -> + computationWorkExecutorFactory.createComputationWorkExecutor( + stageInfo, computationState, work.getLatencyTrackingId())); + + try { + WindmillStateReader stateReader = work.createWindmillStateReader(); + SideInputStateFetcher localSideInputStateFetcher = sideInputStateFetcher.byteTrackingView(); + + // If the read output KVs, then we can decode Windmill's byte key into userland + // key object and provide it to the execution context for use with per-key state. + // Otherwise, we pass null. + // + // The coder type that will be present is: + // WindowedValueCoder(TimerOrElementCoder(KvCoder)) + Optional> keyCoder = computationWorkExecutor.keyCoder(); + @SuppressWarnings("deprecation") + @Nullable + Object executionKey = + !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER); + + if (workItem.hasHotKeyInfo()) { + Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo(); + Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000); + + String stepName = getShuffleTaskStepName(computationState.getMapTask()); + if (options.isHotKeyLoggingEnabled() && keyCoder.isPresent()) { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey); + } else { + hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge); + } + } + + // Blocks while executing work. + computationWorkExecutor.executeWork( + executionKey, work, stateReader, localSideInputStateFetcher, outputBuilder); + + if (work.isFailed()) { + throw new WorkItemCancelledException(workItem.getShardingKey()); + } + + // Reports source bytes processed to WorkItemCommitRequest if available. 
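+      // Computing this counter is best-effort: on failure the error is logged and the commit is
+      // built without setSourceBytesProcessed.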
+ try { + long sourceBytesProcessed = + computationWorkExecutor.computeSourceBytesProcessed( + computationState.sourceBytesProcessCounterName()); + outputBuilder.setSourceBytesProcessed(sourceBytesProcessed); + } catch (Exception e) { + LOG.error(e.toString()); + } + + commitFinalizer.cacheCommitFinalizers(computationWorkExecutor.context().flushState()); + + // Release the execution state for another thread to use. + computationState.releaseComputationWorkExecutor(computationWorkExecutor); + + work.setState(Work.State.COMMIT_QUEUED); + outputBuilder.addAllPerWorkItemLatencyAttributions( + work.getLatencyAttributions(false, sampler)); + + return ExecuteWorkResult.create( + outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead()); + } catch (Throwable t) { + // If processing failed due to a thrown exception, close the executionState. Do not + // return/release the executionState back to computationState as that will lead to this + // executionState instance being reused. + computationWorkExecutor.invalidate(); + + // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream. + throw t; + } + } + + @AutoValue + abstract static class ExecuteWorkResult { + private static ExecuteWorkResult create( + Windmill.WorkItemCommitRequest.Builder commitWorkRequest, long stateBytesRead) { + return new AutoValue_StreamingWorkScheduler_ExecuteWorkResult( + commitWorkRequest, stateBytesRead); + } + + abstract Windmill.WorkItemCommitRequest.Builder commitWorkRequest(); + + abstract long stateBytesRead(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java index 594c29e0ad25..2786f287d225 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessor.java @@ -24,6 +24,7 @@ import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.sdk.annotations.Internal; @@ -102,14 +103,17 @@ private static boolean isOutOfMemoryError(Throwable t) { * attempt to retry execution of the {@link Work} or drop it if it is invalid. */ public void logAndProcessFailure( - String computationId, Work work, Throwable t, Consumer onInvalidWork) { - if (shouldRetryLocally(computationId, work, t)) { + String computationId, + ExecutableWork executableWork, + Throwable t, + Consumer onInvalidWork) { + if (shouldRetryLocally(computationId, executableWork.work(), t)) { // Try again after some delay and at the end of the queue to avoid a tight loop. - executeWithDelay(retryLocallyDelayMs, work); + executeWithDelay(retryLocallyDelayMs, executableWork); } else { // Consider the item invalid. 
It will eventually be retried by Windmill if it still needs to // be processed. - onInvalidWork.accept(work); + onInvalidWork.accept(executableWork.work()); } } @@ -120,9 +124,9 @@ private String tryToDumpHeap() { .orElseGet(() -> "not written"); } - private void executeWithDelay(long delayMs, Work work) { + private void executeWithDelay(long delayMs, ExecutableWork executableWork) { Uninterruptibles.sleepUninterruptibly(delayMs, TimeUnit.MILLISECONDS); - workUnitExecutor.forceExecute(work, work.getWorkItem().getSerializedSize()); + workUnitExecutor.forceExecute(executableWork, executableWork.getWorkItem().getSerializedSize()); } private boolean shouldRetryLocally(String computationId, Work work, Throwable t) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PubsubReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PubsubReaderTest.java index f77166c41583..f65b4af75d5f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PubsubReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PubsubReaderTest.java @@ -51,7 +51,7 @@ public void setUp() throws Exception { } private void testReadWith(String parseFn) throws Exception { - when(mockContext.getWork()) + when(mockContext.getWorkItem()) .thenReturn( Windmill.WorkItem.newBuilder() .setKey(ByteString.copyFromUtf8("key")) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 148adedbafc9..e3aee23e511e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -98,7 +98,9 @@ import org.apache.beam.runners.dataflow.util.Structs; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ComputationStateCache; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.testing.RestoreDataflowLoggingMDC; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; @@ -307,15 +309,32 @@ public void cleanUp() { .ifPresent(ComputationStateCache::closeAndInvalidateAll); } - static Work createMockWork(long workToken) { - return createMockWork(workToken, work -> {}); + private static ExecutableWork createMockWork( + ShardedKey shardedKey, long workToken, String computationId) { + return createMockWork(shardedKey, workToken, computationId, ignored -> {}); } - static Work createMockWork(long workToken, Consumer processWorkFn) { - return Work.create( - Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), - Instant::now, - Collections.emptyList(), + private static ExecutableWork createMockWork( + ShardedKey shardedKey, long workToken, Consumer processWorkFn) { + return createMockWork(shardedKey, workToken, 
"computationId", processWorkFn); + } + + private static ExecutableWork createMockWork( + ShardedKey shardedKey, long workToken, String computationId, Consumer processWorkFn) { + return ExecutableWork.create( + Work.create( + Windmill.WorkItem.newBuilder() + .setKey(shardedKey.key()) + .setShardingKey(shardedKey.shardingKey()) + .setWorkToken(workToken) + .build(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + computationId, + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), + Instant::now, + Collections.emptyList()), processWorkFn); } @@ -324,6 +343,12 @@ private byte[] intervalWindowBytes(IntervalWindow window) throws Exception { DEFAULT_WINDOW_COLLECTION_CODER, Collections.singletonList(window)); } + /** + * Add options with the following format: "--{OPTION_NAME}={OPTION_VALUE}" with 1 option flag per + * String. + * + *

Example: {@code defaultWorkerParams("--option_a=1", "--option_b=foo", "--option_c=bar");} + */ private StreamingDataflowWorkerTestParams.Builder defaultWorkerParams(String... options) { return StreamingDataflowWorkerTestParams.builder() .setOptions(createTestingPipelineOptions(options)); @@ -2707,36 +2732,37 @@ public void testUnboundedSourceWorkRetry() throws Exception { } @Test - public void testActiveWork() throws Exception { + public void testActiveWork() { BoundedQueueExecutor mockExecutor = Mockito.mock(BoundedQueueExecutor.class); + String computationId = "computation"; ComputationState computationState = new ComputationState( - "computation", + computationId, defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), mockExecutor, ImmutableMap.of(), null); ShardedKey key1 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 1); - ShardedKey key2 = ShardedKey.create(ByteString.copyFromUtf8("key2"), 2); - Work m1 = createMockWork(1); - assertTrue(computationState.activateWork(key1, m1)); + ExecutableWork m1 = createMockWork(key1, 1, computationId); + assertTrue(computationState.activateWork(m1)); Mockito.verify(mockExecutor).execute(m1, m1.getWorkItem().getSerializedSize()); computationState.completeWorkAndScheduleNextWorkForKey(key1, m1.id()); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify work queues. - Work m2 = createMockWork(2); - assertTrue(computationState.activateWork(key1, m2)); + ExecutableWork m2 = createMockWork(key1, 2, computationId); + assertTrue(computationState.activateWork(m2)); Mockito.verify(mockExecutor).execute(m2, m2.getWorkItem().getSerializedSize()); - Work m3 = createMockWork(3); - assertTrue(computationState.activateWork(key1, m3)); + ExecutableWork m3 = createMockWork(key1, 3, computationId); + assertTrue(computationState.activateWork(m3)); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify another key is a separate queue. - Work m4 = createMockWork(4); - assertTrue(computationState.activateWork(key2, m4)); + ShardedKey key2 = ShardedKey.create(ByteString.copyFromUtf8("key2"), 2); + ExecutableWork m4 = createMockWork(key2, 4, computationId); + assertTrue(computationState.activateWork(m4)); Mockito.verify(mockExecutor).execute(m4, m4.getWorkItem().getSerializedSize()); computationState.completeWorkAndScheduleNextWorkForKey(key2, m4.id()); Mockito.verifyNoMoreInteractions(mockExecutor); @@ -2747,21 +2773,22 @@ public void testActiveWork() throws Exception { Mockito.verifyNoMoreInteractions(mockExecutor); // Verify duplicate work dropped. 
- Work m5 = createMockWork(5); - computationState.activateWork(key1, m5); + ExecutableWork m5 = createMockWork(key1, 5, computationId); + computationState.activateWork(m5); Mockito.verify(mockExecutor).execute(m5, m5.getWorkItem().getSerializedSize()); - assertFalse(computationState.activateWork(key1, m5)); + assertFalse(computationState.activateWork(m5)); Mockito.verifyNoMoreInteractions(mockExecutor); computationState.completeWorkAndScheduleNextWorkForKey(key1, m5.id()); Mockito.verifyNoMoreInteractions(mockExecutor); } @Test - public void testActiveWorkForShardedKeys() throws Exception { + public void testActiveWorkForShardedKeys() { BoundedQueueExecutor mockExecutor = Mockito.mock(BoundedQueueExecutor.class); + String computationId = "computation"; ComputationState computationState = new ComputationState( - "computation", + computationId, defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), mockExecutor, ImmutableMap.of(), @@ -2770,29 +2797,30 @@ public void testActiveWorkForShardedKeys() throws Exception { ShardedKey key1Shard1 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 1); ShardedKey key1Shard2 = ShardedKey.create(ByteString.copyFromUtf8("key1"), 2); - Work m1 = createMockWork(1); - assertTrue(computationState.activateWork(key1Shard1, m1)); + ExecutableWork m1 = createMockWork(key1Shard1, 1, computationId); + assertTrue(computationState.activateWork(m1)); Mockito.verify(mockExecutor).execute(m1, m1.getWorkItem().getSerializedSize()); computationState.completeWorkAndScheduleNextWorkForKey(key1Shard1, m1.id()); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify work queues. - Work m2 = createMockWork(2); - assertTrue(computationState.activateWork(key1Shard1, m2)); + ExecutableWork m2 = createMockWork(key1Shard1, 2, computationId); + assertTrue(computationState.activateWork(m2)); Mockito.verify(mockExecutor).execute(m2, m2.getWorkItem().getSerializedSize()); - Work m3 = createMockWork(3); - assertTrue(computationState.activateWork(key1Shard1, m3)); + ExecutableWork m3 = createMockWork(key1Shard1, 3, computationId); + assertTrue(computationState.activateWork(m3)); Mockito.verifyNoMoreInteractions(mockExecutor); // Verify a different shard of key is a separate queue. 
- Work m4 = createMockWork(3); - assertFalse(computationState.activateWork(key1Shard1, m4)); + ExecutableWork m4 = createMockWork(key1Shard1, 3, computationId); + assertFalse(computationState.activateWork(m4)); Mockito.verifyNoMoreInteractions(mockExecutor); - assertTrue(computationState.activateWork(key1Shard2, m4)); - Mockito.verify(mockExecutor).execute(m4, m4.getWorkItem().getSerializedSize()); + ExecutableWork m4Shard2 = createMockWork(key1Shard2, 3, computationId); + assertTrue(computationState.activateWork(m4Shard2)); + Mockito.verify(mockExecutor).execute(m4Shard2, m4Shard2.getWorkItem().getSerializedSize()); // Verify duplicate work dropped - assertFalse(computationState.activateWork(key1Shard2, m4)); + assertFalse(computationState.activateWork(m4Shard2)); computationState.completeWorkAndScheduleNextWorkForKey(key1Shard2, m4.id()); Mockito.verifyNoMoreInteractions(mockExecutor); } @@ -2838,11 +2866,11 @@ public void testMaxThreadMetric() throws Exception { } }; - Work m2 = createMockWork(2, sleepProcessWorkFn); - Work m3 = createMockWork(3, sleepProcessWorkFn); + ExecutableWork m2 = createMockWork(key1Shard1, 2, sleepProcessWorkFn); + ExecutableWork m3 = createMockWork(key1Shard1, 3, sleepProcessWorkFn); - assertTrue(computationState.activateWork(key1Shard1, m2)); - assertTrue(computationState.activateWork(key1Shard1, m3)); + assertTrue(computationState.activateWork(m2)); + assertTrue(computationState.activateWork(m3)); executor.execute(m2, m2.getWorkItem().getSerializedSize()); executor.execute(m3, m3.getWorkItem().getSerializedSize()); @@ -2879,7 +2907,7 @@ public void testActiveThreadMetric() throws Exception { ComputationState computationState = new ComputationState( "computation", - defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), executor, ImmutableMap.of(), null); @@ -2897,21 +2925,21 @@ public void testActiveThreadMetric() throws Exception { } }; - Work m2 = createMockWork(2, sleepProcessWorkFn); + ExecutableWork m2 = createMockWork(key1Shard1, 2, sleepProcessWorkFn); - Work m3 = createMockWork(3, sleepProcessWorkFn); + ExecutableWork m3 = createMockWork(key1Shard1, 3, sleepProcessWorkFn); - Work m4 = createMockWork(4, sleepProcessWorkFn); + ExecutableWork m4 = createMockWork(key1Shard1, 4, sleepProcessWorkFn); assertEquals(0, executor.activeCount()); - assertTrue(computationState.activateWork(key1Shard1, m2)); + assertTrue(computationState.activateWork(m2)); // activate work starts executing work if no other work is queued for that shard executor.execute(m2, m2.getWorkItem().getSerializedSize()); processStart1.await(); assertEquals(2, executor.activeCount()); - assertTrue(computationState.activateWork(key1Shard1, m3)); - assertTrue(computationState.activateWork(key1Shard1, m4)); + assertTrue(computationState.activateWork(m3)); + assertTrue(computationState.activateWork(m4)); executor.execute(m3, m3.getWorkItem().getSerializedSize()); processStart2.await(); @@ -2948,7 +2976,7 @@ public void testOutstandingBytesMetric() throws Exception { ComputationState computationState = new ComputationState( "computation", - defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), executor, ImmutableMap.of(), null); @@ -2965,23 +2993,23 @@ public void testOutstandingBytesMetric() throws Exception { } }; - Work m2 = createMockWork(2, sleepProcessWorkFn); + 
ExecutableWork m2 = createMockWork(key1Shard1, 2, sleepProcessWorkFn); - Work m3 = createMockWork(3, sleepProcessWorkFn); + ExecutableWork m3 = createMockWork(key1Shard1, 3, sleepProcessWorkFn); - Work m4 = createMockWork(4, sleepProcessWorkFn); + ExecutableWork m4 = createMockWork(key1Shard1, 4, sleepProcessWorkFn); assertEquals(0, executor.bytesOutstanding()); long bytes = m2.getWorkItem().getSerializedSize(); - assertTrue(computationState.activateWork(key1Shard1, m2)); + assertTrue(computationState.activateWork(m2)); // activate work starts executing work if no other work is queued for that shard bytes += m2.getWorkItem().getSerializedSize(); executor.execute(m2, m2.getWorkItem().getSerializedSize()); processStart1.await(); assertEquals(bytes, executor.bytesOutstanding()); - assertTrue(computationState.activateWork(key1Shard1, m3)); - assertTrue(computationState.activateWork(key1Shard1, m4)); + assertTrue(computationState.activateWork(m3)); + assertTrue(computationState.activateWork(m4)); bytes += m3.getWorkItem().getSerializedSize(); executor.execute(m3, m3.getWorkItem().getSerializedSize()); @@ -3021,7 +3049,7 @@ public void testOutstandingBundlesMetric() throws Exception { ComputationState computationState = new ComputationState( "computation", - defaultMapTask(Arrays.asList(makeSourceInstruction(StringUtf8Coder.of()))), + defaultMapTask(Collections.singletonList(makeSourceInstruction(StringUtf8Coder.of()))), executor, ImmutableMap.of(), null); @@ -3038,21 +3066,21 @@ public void testOutstandingBundlesMetric() throws Exception { } }; - Work m2 = createMockWork(2, sleepProcessWorkFn); + ExecutableWork m2 = createMockWork(key1Shard1, 2, sleepProcessWorkFn); - Work m3 = createMockWork(3, sleepProcessWorkFn); + ExecutableWork m3 = createMockWork(key1Shard1, 3, sleepProcessWorkFn); - Work m4 = createMockWork(4, sleepProcessWorkFn); + ExecutableWork m4 = createMockWork(key1Shard1, 4, sleepProcessWorkFn); assertEquals(0, executor.elementsOutstanding()); - assertTrue(computationState.activateWork(key1Shard1, m2)); + assertTrue(computationState.activateWork(m2)); // activate work starts executing work if no other work is queued for that shard executor.execute(m2, m2.getWorkItem().getSerializedSize()); processStart1.await(); assertEquals(2, executor.elementsOutstanding()); - assertTrue(computationState.activateWork(key1Shard1, m3)); - assertTrue(computationState.activateWork(key1Shard1, m4)); + assertTrue(computationState.activateWork(m3)); + assertTrue(computationState.activateWork(m4)); executor.execute(m3, m3.getWorkItem().getSerializedSize()); processStart2.await(); @@ -3378,14 +3406,14 @@ public void testLatencyAttributionProtobufsPopulated() { FakeClock clock = new FakeClock(); Work work = Work.create( - Windmill.WorkItem.newBuilder() - .setKey(ByteString.EMPTY) - .setWorkToken(1L) - .setCacheToken(1L) - .build(), + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(1L).build(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), clock, - Collections.emptyList(), - unused -> {}); + Collections.emptyList()); clock.sleep(Duration.millis(10)); work.setState(Work.State.PROCESSING); @@ -3400,7 +3428,7 @@ public void testLatencyAttributionProtobufsPopulated() { clock.sleep(Duration.millis(60)); Iterator it = - work.getLatencyAttributions(false, "", DataflowExecutionStateSampler.instance()).iterator(); + 
work.getLatencyAttributions(false, DataflowExecutionStateSampler.instance()).iterator(); assertTrue(it.hasNext()); LatencyAttribution lat = it.next(); assertSame(State.QUEUED, lat.getState()); @@ -4058,16 +4086,6 @@ public void processElement(ProcessContext c) { } } - private static class MockWork { - Work create(long workToken) { - return Work.create( - Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(workToken).build(), - Instant::now, - Collections.emptyList(), - work -> {}); - } - } - static class TestExceptionInvalidatesCacheFn extends DoFn>, String> { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 158fbee37533..2193f20f3fe3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -30,6 +30,7 @@ import com.google.api.services.dataflow.model.CounterStructuredNameAndMetadata; import com.google.api.services.dataflow.model.CounterUpdate; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -55,6 +56,8 @@ import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; @@ -89,6 +92,8 @@ public class StreamingModeExecutionContextTest { @Mock private SideInputStateFetcher sideInputStateFetcher; @Mock private WindmillStateReader stateReader; + private static final String COMPUTATION_ID = "computationId"; + private final StreamingModeExecutionStateRegistry executionStateRegistry = new StreamingModeExecutionStateRegistry(); private StreamingModeExecutionContext executionContext; @@ -104,7 +109,7 @@ public void setUp() { executionContext = new StreamingModeExecutionContext( counterSet, - "computationId", + COMPUTATION_ID, new ReaderCache(Duration.standardMinutes(1), Executors.newCachedThreadPool()), stateNameMap, WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation("comp"), @@ -120,6 +125,18 @@ public void setUp() { Long.MAX_VALUE); } + private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { + return Work.create( + workItem, + watermarks, + Work.createProcessingContext( + COMPUTATION_ID, + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), + Instant::now, + Collections.emptyList()); + } + @Test public void testTimerInternalsSetTimer() { Windmill.WorkItemCommitRequest.Builder outputBuilder = @@ -129,16 +146,15 @@ public void testTimerInternalsSetTimer() { executionContext.createOperationContext(nameContext); StreamingModeExecutionContext.StepContext stepContext = 
executionContext.getStepContext(operationContext); + executionContext.start( "key", - Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), - new Instant(1000), // input watermark - null, // output watermark - null, // synchronized processing time + createMockWork( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, - outputBuilder, - null); + outputBuilder); TimerInternals timerInternals = stepContext.timerInternals(); @@ -182,14 +198,12 @@ public void testTimerInternalsProcessingTimeSkew() { executionContext.start( "key", - workItemBuilder.build(), - new Instant(1000), // input watermark - null, // output watermark - null, // synchronized processing time + createMockWork( + workItemBuilder.build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, - outputBuilder, - null); + outputBuilder); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index cc9e6da4a735..9f97c9835ddc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -49,6 +49,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage; +import static org.mockito.Mockito.mock; import com.google.api.services.dataflow.model.ApproximateReportedProgress; import com.google.api.services.dataflow.model.DataflowPackage; @@ -87,12 +88,15 @@ import org.apache.beam.runners.dataflow.worker.counters.CounterSet; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.NoopProfileScope; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; @@ -138,6 +142,8 @@ public class WorkerCustomSourcesTest { @Rule public ExpectedException expectedException = ExpectedException.none(); @Rule public ExpectedLogs logged = ExpectedLogs.none(WorkerCustomSources.class); + private static final String COMPUTATION_ID = "computationId"; + private DataflowPipelineOptions options; @Before @@ -186,6 +192,18 @@ public 
void testSplitAndReadBundlesBack() throws Exception { } } + private static Work createMockWork(Windmill.WorkItem workItem, Watermarks watermarks) { + return Work.create( + workItem, + watermarks, + Work.createProcessingContext( + COMPUTATION_ID, + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), + Instant::now, + Collections.emptyList()); + } + private static class SourceProducingSubSourcesInSplit extends MockSource { int numDesiredBundle; int sourceObjectSize; @@ -579,7 +597,7 @@ public void testReadUnboundedReader() throws Exception { StreamingModeExecutionContext context = new StreamingModeExecutionContext( counterSet, - "computationId", + COMPUTATION_ID, readerCache, /*stateNameMap=*/ ImmutableMap.of(), /*stateCache=*/ null, @@ -605,20 +623,18 @@ public void testReadUnboundedReader() throws Exception { // Initialize streaming context with state from previous iteration. context.start( "key", - Windmill.WorkItem.newBuilder() - .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. - .setWorkToken(i) // Must be increasing across activations for cache to be used. - .setCacheToken(1) - .setSourceState( - Windmill.SourceState.newBuilder().setState(state).build()) // Source state. - .build(), - new Instant(0), // input watermark - null, // output watermark - null, // synchronized processing time - null, // StateReader - null, // StateFetcher - Windmill.WorkItemCommitRequest.newBuilder(), - null); + createMockWork( + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("0000000000000001")) // key is zero-padded index. + .setWorkToken(i) // Must be increasing across activations for cache to be used. + .setCacheToken(1) + .setSourceState( + Windmill.SourceState.newBuilder().setState(state).build()) // Source state. 
+ .build(), + Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), + mock(WindmillStateReader.class), + mock(SideInputStateFetcher.class), + Windmill.WorkItemCommitRequest.newBuilder()); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -645,10 +661,10 @@ public void testReadUnboundedReader() throws Exception { numReadOnThisIteration++; } Instant afterReading = Instant.now(); - long maxReadMs = debugOptions.getUnboundedReaderMaxReadTimeMs(); + long maxReadSec = debugOptions.getUnboundedReaderMaxReadTimeSec(); assertThat( - new Duration(beforeReading, afterReading).getMillis(), - lessThanOrEqualTo(maxReadMs + 1000L)); + new Duration(beforeReading, afterReading).getStandardSeconds(), + lessThanOrEqualTo(maxReadSec + 1)); assertThat( numReadOnThisIteration, lessThanOrEqualTo(debugOptions.getUnboundedReaderMaxElements())); @@ -665,7 +681,7 @@ public void testReadUnboundedReader() throws Exception { assertNotNull( readerCache.acquireReader( context.getComputationKey(), - context.getWork().getCacheToken(), + context.getWorkItem().getCacheToken(), context.getWorkToken() + 1)); assertEquals(7L, context.getBacklogBytes()); } @@ -945,11 +961,10 @@ public void testFailedWorkItemsAbort() throws Exception { StreamingModeExecutionContext context = new StreamingModeExecutionContext( counterSet, - "computationId", + COMPUTATION_ID, new ReaderCache(Duration.standardMinutes(1), Runnable::run), /*stateNameMap=*/ ImmutableMap.of(), - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()) - .forComputation("computationId"), + WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation(COMPUTATION_ID), StreamingStepMetricsContainer.createRegistry(), new DataflowExecutionStateTracker( ExecutionStateSampler.newForTest(), @@ -975,18 +990,22 @@ public void testFailedWorkItemsAbort() throws Exception { .setSourceState( Windmill.SourceState.newBuilder().setState(state).build()) // Source state. 
.build(); - Work dummyWork = Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {}); - + Work dummyWork = + Work.create( + workItem, + Watermarks.builder().setInputDataWatermark(new Instant(0)).build(), + Work.createProcessingContext( + COMPUTATION_ID, + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + gnored -> {}), + Instant::now, + Collections.emptyList()); context.start( "key", - workItem, - new Instant(0), // input watermark - null, // output watermark - null, // synchronized processing time - null, // StateReader - null, // StateFetcher - Windmill.WorkItemCommitRequest.newBuilder(), - dummyWork::isFailed); + dummyWork, + mock(WindmillStateReader.class), + mock(SideInputStateFetcher.class), + Windmill.WorkItemCommitRequest.newBuilder()); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index c581638d98bf..3a3e0a34c217 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -53,7 +53,7 @@ public class ActiveWorkStateTest { private final WindmillStateCache.ForComputation computationStateCache = mock(WindmillStateCache.ForComputation.class); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private Map> readOnlyActiveWork; + private Map> readOnlyActiveWork; private ActiveWorkState activeWorkState; @@ -61,22 +61,44 @@ private static ShardedKey shardedKey(String str, long shardKey) { return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey); } - private static Work createWork(Windmill.WorkItem workItem) { - return Work.create(workItem, Instant::now, Collections.emptyList(), unused -> {}); + private static ExecutableWork createWork(Windmill.WorkItem workItem) { + return ExecutableWork.create( + Work.create( + workItem, + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + createWorkProcessingContext(), + Instant::now, + Collections.emptyList()), + ignored -> {}); } - private static Work expiredWork(Windmill.WorkItem workItem) { - return Work.create(workItem, () -> Instant.EPOCH, Collections.emptyList(), unused -> {}); + private static ExecutableWork expiredWork(Windmill.WorkItem workItem) { + return ExecutableWork.create( + Work.create( + workItem, + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + createWorkProcessingContext(), + () -> Instant.EPOCH, + Collections.emptyList()), + ignored -> {}); + } + + private static Work.ProcessingContext createWorkProcessingContext() { + return Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}); } private static WorkId workId(long workToken, long cacheToken) { return WorkId.builder().setCacheToken(cacheToken).setWorkToken(workToken).build(); } - private static Windmill.WorkItem createWorkItem(long workToken, long cacheToken) { + private static Windmill.WorkItem createWorkItem( + long workToken, long cacheToken, ShardedKey shardedKey) { return Windmill.WorkItem.newBuilder() - .setKey(ByteString.copyFromUtf8("")) - .setShardingKey(1) + 
.setShardingKey(shardedKey.shardingKey()) + .setKey(shardedKey.key()) .setWorkToken(workToken) .setCacheToken(cacheToken) .build(); @@ -84,7 +106,7 @@ private static Windmill.WorkItem createWorkItem(long workToken, long cacheToken) @Before public void setup() { - Map> readWriteActiveWorkMap = new HashMap<>(); + Map> readWriteActiveWorkMap = new HashMap<>(); // Only use readOnlyActiveWork to verify internal behavior in reaction to exposed API calls. readOnlyActiveWork = Collections.unmodifiableMap(readWriteActiveWorkMap); activeWorkState = ActiveWorkState.forTesting(readWriteActiveWorkMap, computationStateCache); @@ -94,7 +116,7 @@ public void setup() { public void testActivateWorkForKey_EXECUTE_unknownKey() { ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey( - shardedKey("someKey", 1L), createWork(createWorkItem(1L, 1L))); + createWork(createWorkItem(1L, 1L, shardedKey("someKey", 1L)))); assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult); } @@ -107,9 +129,9 @@ public void testActivateWorkForKey_EXECUTE_emptyWorkQueueForKey() { ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey( - shardedKey, createWork(createWorkItem(workToken, cacheToken))); + createWork(createWorkItem(workToken, cacheToken, shardedKey))); - Optional nextWorkForKey = + Optional nextWorkForKey = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workId(workToken, cacheToken)); assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult); @@ -123,9 +145,9 @@ public void testActivateWorkForKey_DUPLICATE() { ShardedKey shardedKey = shardedKey("someKey", 1L); // ActivateWork with the same shardedKey, and the same workTokens. - activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken, 1L))); + activeWorkState.activateWorkForKey(createWork(createWorkItem(workToken, 1L, shardedKey))); ActivateWorkResult activateWorkResult = - activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(workToken, 1L))); + activeWorkState.activateWorkForKey(createWork(createWorkItem(workToken, 1L, shardedKey))); assertEquals(ActivateWorkResult.DUPLICATE, activateWorkResult); } @@ -135,9 +157,9 @@ public void testActivateWorkForKey_QUEUED() { ShardedKey shardedKey = shardedKey("someKey", 1L); // ActivateWork with the same shardedKey, but different workTokens. 
- activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(1L, 1L))); + activeWorkState.activateWorkForKey(createWork(createWorkItem(1L, 1L, shardedKey))); ActivateWorkResult activateWorkResult = - activeWorkState.activateWorkForKey(shardedKey, createWork(createWorkItem(2L, 1L))); + activeWorkState.activateWorkForKey(createWork(createWorkItem(2L, 1L, shardedKey))); assertEquals(ActivateWorkResult.QUEUED, activateWorkResult); } @@ -156,10 +178,12 @@ public void testCompleteWorkAndGetNextWorkForKey_noWorkQueueForKey() { long workTokenInQueue = 2L; long otherWorkToken = 1L; long cacheToken = 1L; - Work workInQueue = createWork(createWorkItem(workTokenInQueue, cacheToken)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, workInQueue); + ExecutableWork workInQueue = + createWork(createWorkItem(workTokenInQueue, cacheToken, shardedKey)); + + activeWorkState.activateWorkForKey(workInQueue); activeWorkState.completeWorkAndGetNextWorkForKey( shardedKey, workId(otherWorkToken, cacheToken)); @@ -169,12 +193,13 @@ public void testCompleteWorkAndGetNextWorkForKey_noWorkQueueForKey() { @Test public void testCompleteWorkAndGetNextWorkForKey_removesWorkFromQueueWhenComplete() { - Work activeWork = createWork(createWorkItem(1L, 1L)); - Work nextWork = createWork(createWorkItem(2L, 2L)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, activeWork); - activeWorkState.activateWorkForKey(shardedKey, nextWork); + ExecutableWork activeWork = createWork(createWorkItem(1L, 1L, shardedKey)); + ExecutableWork nextWork = createWork(createWorkItem(2L, 2L, shardedKey)); + + activeWorkState.activateWorkForKey(activeWork); + activeWorkState.activateWorkForKey(nextWork); activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, activeWork.id()); assertEquals(nextWork, readOnlyActiveWork.get(shardedKey).peek()); @@ -184,10 +209,11 @@ public void testCompleteWorkAndGetNextWorkForKey_removesWorkFromQueueWhenComplet @Test public void testCompleteWorkAndGetNextWorkForKey_removesQueueIfNoWorkPresent() { - Work workInQueue = createWork(createWorkItem(1L, 1L)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, workInQueue); + ExecutableWork workInQueue = createWork(createWorkItem(1L, 1L, shardedKey)); + + activeWorkState.activateWorkForKey(workInQueue); activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workInQueue.id()); assertFalse(readOnlyActiveWork.containsKey(shardedKey)); @@ -195,21 +221,22 @@ public void testCompleteWorkAndGetNextWorkForKey_removesQueueIfNoWorkPresent() { @Test public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() { - Work workToBeCompleted = createWork(createWorkItem(1L, 1L)); - Work nextWork = createWork(createWorkItem(2L, 2L)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, workToBeCompleted); - activeWorkState.activateWorkForKey(shardedKey, nextWork); + ExecutableWork workToBeCompleted = createWork(createWorkItem(1L, 1L, shardedKey)); + ExecutableWork nextWork = createWork(createWorkItem(2L, 2L, shardedKey)); + + activeWorkState.activateWorkForKey(workToBeCompleted); + activeWorkState.activateWorkForKey(nextWork); activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workToBeCompleted.id()); - Optional nextWorkOpt = + Optional nextWorkOpt = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workToBeCompleted.id()); 
assertTrue(nextWorkOpt.isPresent()); assertSame(nextWork, nextWorkOpt.get()); - Optional endOfWorkQueue = + Optional endOfWorkQueue = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, nextWork.id()); assertFalse(endOfWorkQueue.isPresent()); @@ -219,11 +246,11 @@ public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() { @Test public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_oneShardKey() { ShardedKey shardedKey = shardedKey("someKey", 1L); - Work work1 = createWork(createWorkItem(1L, 1L)); - Work work2 = createWork(createWorkItem(2L, 2L)); + ExecutableWork work1 = createWork(createWorkItem(1L, 1L, shardedKey)); + ExecutableWork work2 = createWork(createWorkItem(2L, 2L, shardedKey)); - activeWorkState.activateWorkForKey(shardedKey, work1); - activeWorkState.activateWorkForKey(shardedKey, work2); + activeWorkState.activateWorkForKey(work1); + activeWorkState.activateWorkForKey(work2); GetWorkBudget expectedActiveBudget1 = GetWorkBudget.builder() @@ -248,11 +275,11 @@ public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_oneS @Test public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_whenWorkCompleted() { ShardedKey shardedKey = shardedKey("someKey", 1L); - Work work1 = createWork(createWorkItem(1L, 1L)); - Work work2 = createWork(createWorkItem(2L, 2L)); + ExecutableWork work1 = createWork(createWorkItem(1L, 1L, shardedKey)); + ExecutableWork work2 = createWork(createWorkItem(2L, 2L, shardedKey)); - activeWorkState.activateWorkForKey(shardedKey, work1); - activeWorkState.activateWorkForKey(shardedKey, work2); + activeWorkState.activateWorkForKey(work1); + activeWorkState.activateWorkForKey(work2); activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, work1.id()); GetWorkBudget expectedActiveBudget = @@ -268,11 +295,11 @@ public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_when public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_multipleShardKeys() { ShardedKey shardedKey1 = shardedKey("someKey", 1L); ShardedKey shardedKey2 = shardedKey("someKey", 2L); - Work work1 = createWork(createWorkItem(1L, 1L)); - Work work2 = createWork(createWorkItem(2L, 2L)); + ExecutableWork work1 = createWork(createWorkItem(1L, 1L, shardedKey1)); + ExecutableWork work2 = createWork(createWorkItem(2L, 2L, shardedKey2)); - activeWorkState.activateWorkForKey(shardedKey1, work1); - activeWorkState.activateWorkForKey(shardedKey2, work2); + activeWorkState.activateWorkForKey(work1); + activeWorkState.activateWorkForKey(work2); GetWorkBudget expectedActiveBudget = GetWorkBudget.builder() @@ -287,16 +314,16 @@ public void testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_mult @Test public void testInvalidateStuckCommits() { Map invalidatedCommits = new HashMap<>(); - - Work stuckWork1 = expiredWork(createWorkItem(1L, 1L)); - stuckWork1.setState(Work.State.COMMITTING); - Work stuckWork2 = expiredWork(createWorkItem(2L, 1L)); - stuckWork2.setState(Work.State.COMMITTING); ShardedKey shardedKey1 = shardedKey("someKey", 1L); ShardedKey shardedKey2 = shardedKey("anotherKey", 2L); - activeWorkState.activateWorkForKey(shardedKey1, stuckWork1); - activeWorkState.activateWorkForKey(shardedKey2, stuckWork2); + ExecutableWork stuckWork1 = expiredWork(createWorkItem(1L, 1L, shardedKey1)); + stuckWork1.work().setState(Work.State.COMMITTING); + ExecutableWork stuckWork2 = expiredWork(createWorkItem(2L, 1L, shardedKey2)); + stuckWork2.work().setState(Work.State.COMMITTING); 
+ + activeWorkState.activateWorkForKey(stuckWork1); + activeWorkState.activateWorkForKey(stuckWork2); activeWorkState.invalidateStuckCommits(Instant.now(), invalidatedCommits::put); @@ -312,22 +339,21 @@ public void testInvalidateStuckCommits() { long workToken = 10L; long cacheToken1 = 5L; long cacheToken2 = cacheToken1 + 2L; - - Work firstWork = createWork(createWorkItem(workToken, cacheToken1)); - Work secondWork = createWork(createWorkItem(workToken, cacheToken2)); - Work differentWorkTokenWork = createWork(createWorkItem(1L, 1L)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, differentWorkTokenWork); + ExecutableWork firstWork = createWork(createWorkItem(workToken, cacheToken1, shardedKey)); + ExecutableWork secondWork = createWork(createWorkItem(workToken, cacheToken2, shardedKey)); + ExecutableWork differentWorkTokenWork = createWork(createWorkItem(1L, 1L, shardedKey)); + + activeWorkState.activateWorkForKey(differentWorkTokenWork); // ActivateWork with the same shardedKey, and the same workTokens, but different cacheTokens. - activeWorkState.activateWorkForKey(shardedKey, firstWork); - ActivateWorkResult activateWorkResult = - activeWorkState.activateWorkForKey(shardedKey, secondWork); + activeWorkState.activateWorkForKey(firstWork); + ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(secondWork); assertEquals(ActivateWorkResult.QUEUED, activateWorkResult); assertTrue(readOnlyActiveWork.get(shardedKey).contains(secondWork)); - Optional nextWork = + Optional nextWork = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, differentWorkTokenWork.id()); assertTrue(nextWork.isPresent()); assertSame(firstWork, nextWork.get()); @@ -342,20 +368,19 @@ public void testInvalidateStuckCommits() { long workToken = 10L; long cacheToken1 = 5L; long cacheToken2 = 7L; - - Work firstWork = createWork(createWorkItem(workToken, cacheToken1)); - Work secondWork = createWork(createWorkItem(workToken, cacheToken2)); ShardedKey shardedKey = shardedKey("someKey", 1L); + ExecutableWork firstWork = createWork(createWorkItem(workToken, cacheToken1, shardedKey)); + ExecutableWork secondWork = createWork(createWorkItem(workToken, cacheToken2, shardedKey)); + // ActivateWork with the same shardedKey, and the same workTokens, but different cacheTokens. 
- activeWorkState.activateWorkForKey(shardedKey, firstWork); - ActivateWorkResult activateWorkResult = - activeWorkState.activateWorkForKey(shardedKey, secondWork); + activeWorkState.activateWorkForKey(firstWork); + ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(secondWork); assertEquals(ActivateWorkResult.QUEUED, activateWorkResult); assertEquals(firstWork, readOnlyActiveWork.get(shardedKey).peek()); assertTrue(readOnlyActiveWork.get(shardedKey).contains(secondWork)); - Optional nextWork = + Optional nextWork = activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, firstWork.id()); assertTrue(nextWork.isPresent()); assertSame(secondWork, nextWork.get()); @@ -367,13 +392,13 @@ public void testInvalidateStuckCommits() { long cacheToken = 1L; long newWorkToken = 10L; long queuedWorkToken = newWorkToken / 2; - - Work queuedWork = createWork(createWorkItem(queuedWorkToken, cacheToken)); - Work newWork = createWork(createWorkItem(newWorkToken, cacheToken)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, queuedWork); - ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(shardedKey, newWork); + ExecutableWork queuedWork = createWork(createWorkItem(queuedWorkToken, cacheToken, shardedKey)); + ExecutableWork newWork = createWork(createWorkItem(newWorkToken, cacheToken, shardedKey)); + + activeWorkState.activateWorkForKey(queuedWork); + ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(newWork); // newWork should be queued and queuedWork should not be removed since it is currently active. assertEquals(ActivateWorkResult.QUEUED, activateWorkResult); @@ -388,14 +413,16 @@ public void testInvalidateStuckCommits() { long newWorkToken = 10L; long queuedWorkToken = newWorkToken / 2; - Work differentWorkTokenWork = createWork(createWorkItem(100L, 100L)); - Work queuedWork = createWork(createWorkItem(queuedWorkToken, matchingCacheToken)); - Work newWork = createWork(createWorkItem(newWorkToken, matchingCacheToken)); ShardedKey shardedKey = shardedKey("someKey", 1L); + ExecutableWork differentWorkTokenWork = createWork(createWorkItem(100L, 100L, shardedKey)); + ExecutableWork queuedWork = + createWork(createWorkItem(queuedWorkToken, matchingCacheToken, shardedKey)); + ExecutableWork newWork = + createWork(createWorkItem(newWorkToken, matchingCacheToken, shardedKey)); - activeWorkState.activateWorkForKey(shardedKey, differentWorkTokenWork); - activeWorkState.activateWorkForKey(shardedKey, queuedWork); - ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(shardedKey, newWork); + activeWorkState.activateWorkForKey(differentWorkTokenWork); + activeWorkState.activateWorkForKey(queuedWork); + ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(newWork); assertEquals(ActivateWorkResult.QUEUED, activateWorkResult); assertTrue(readOnlyActiveWork.get(shardedKey).contains(newWork)); @@ -408,13 +435,13 @@ public void testActivateWorkForKey_matchingCacheTokens_newWorkTokenLesser_STALE( long cacheToken = 1L; long queuedWorkToken = 10L; long newWorkToken = queuedWorkToken / 2; - - Work queuedWork = createWork(createWorkItem(queuedWorkToken, cacheToken)); - Work newWork = createWork(createWorkItem(newWorkToken, cacheToken)); ShardedKey shardedKey = shardedKey("someKey", 1L); - activeWorkState.activateWorkForKey(shardedKey, queuedWork); - ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(shardedKey, newWork); + 
ExecutableWork queuedWork = createWork(createWorkItem(queuedWorkToken, cacheToken, shardedKey)); + ExecutableWork newWork = createWork(createWorkItem(newWorkToken, cacheToken, shardedKey)); + + activeWorkState.activateWorkForKey(queuedWork); + ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(newWork); assertEquals(ActivateWorkResult.STALE, activateWorkResult); assertFalse(readOnlyActiveWork.get(shardedKey).contains(newWork)); @@ -424,26 +451,28 @@ public void testActivateWorkForKey_matchingCacheTokens_newWorkTokenLesser_STALE( @Test public void testGetKeyHeartbeats() { Instant refreshDeadline = Instant.now(); - - Work freshWork = createWork(createWorkItem(3L, 3L)); - Work refreshableWork1 = expiredWork(createWorkItem(1L, 1L)); - refreshableWork1.setState(Work.State.COMMITTING); - Work refreshableWork2 = expiredWork(createWorkItem(2L, 2L)); - refreshableWork2.setState(Work.State.COMMITTING); ShardedKey shardedKey1 = shardedKey("someKey", 1L); ShardedKey shardedKey2 = shardedKey("anotherKey", 2L); - activeWorkState.activateWorkForKey(shardedKey1, refreshableWork1); - activeWorkState.activateWorkForKey(shardedKey1, freshWork); - activeWorkState.activateWorkForKey(shardedKey2, refreshableWork2); + ExecutableWork freshWork = createWork(createWorkItem(3L, 3L, shardedKey1)); + ExecutableWork refreshableWork1 = expiredWork(createWorkItem(1L, 1L, shardedKey1)); + refreshableWork1.work().setState(Work.State.COMMITTING); + ExecutableWork refreshableWork2 = expiredWork(createWorkItem(2L, 2L, shardedKey2)); + refreshableWork2.work().setState(Work.State.COMMITTING); + + activeWorkState.activateWorkForKey(refreshableWork1); + activeWorkState.activateWorkForKey(freshWork); + activeWorkState.activateWorkForKey(refreshableWork2); ImmutableList requests = activeWorkState.getKeyHeartbeats(refreshDeadline, DataflowExecutionStateSampler.instance()); ImmutableList expected = ImmutableList.of( - HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey1, refreshableWork1), - HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from(shardedKey2, refreshableWork2)); + HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from( + shardedKey1, refreshableWork1.work()), + HeartbeatRequestShardingKeyWorkTokenAndCacheToken.from( + shardedKey2, refreshableWork2.work())); ImmutableList actual = requests.stream() diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java index 3eb434c39038..3c1683ecf436 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java @@ -56,17 +56,23 @@ public class ComputationStateCacheTest { private final ComputationConfig.Fetcher configFetcher = mock(ComputationConfig.Fetcher.class); private ComputationStateCache computationStateCache; - private static Work createWork(long workToken, long cacheToken) { - return Work.create( - Windmill.WorkItem.newBuilder() - .setKey(ByteString.copyFromUtf8("")) - .setShardingKey(1) - .setWorkToken(workToken) - .setCacheToken(cacheToken) - .build(), - Instant::now, - Collections.emptyList(), - unused -> {}); + private static ExecutableWork createWork(ShardedKey shardedKey, long 
workToken, long cacheToken) { + return ExecutableWork.create( + Work.create( + Windmill.WorkItem.newBuilder() + .setKey(shardedKey.key()) + .setShardingKey(shardedKey.shardingKey()) + .setWorkToken(workToken) + .setCacheToken(cacheToken) + .build(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), + Instant::now, + Collections.emptyList()), + ignored -> {}); } @Before @@ -249,24 +255,25 @@ public void testTotalCurrentActiveGetWorkBudget() { ComputationConfig.create(mapTask, userTransformToStateFamilyName, ImmutableMap.of()); when(configFetcher.fetchConfig(eq(computationId))).thenReturn(Optional.of(computationConfig)); when(configFetcher.fetchConfig(eq(computationId2))).thenReturn(Optional.of(computationConfig)); - Work work1 = createWork(1, 1); - Work work2 = createWork(2, 2); - Work work3 = createWork(3, 3); + ShardedKey shardedKey = ShardedKey.create(ByteString.EMPTY, 1); + ShardedKey shardedKey2 = ShardedKey.create(ByteString.EMPTY, 2); + + ExecutableWork work1 = createWork(shardedKey, 1, 1); + ExecutableWork work2 = createWork(shardedKey2, 2, 2); + ExecutableWork work3 = createWork(shardedKey2, 3, 3); // Activate 1 Work for computationId Optional maybeComputationState = computationStateCache.get(computationId); assertTrue(maybeComputationState.isPresent()); ComputationState computationState = maybeComputationState.get(); - ShardedKey shardedKey = ShardedKey.create(ByteString.EMPTY, 1); - computationState.activateWork(shardedKey, work1); + computationState.activateWork(work1); // Activate 2 Work(s) for computationId2 Optional maybeComputationState2 = computationStateCache.get(computationId); assertTrue(maybeComputationState2.isPresent()); ComputationState computationState2 = maybeComputationState2.get(); - ShardedKey shardedKey2 = ShardedKey.create(ByteString.EMPTY, 2); - computationState2.activateWork(shardedKey2, work2); - computationState2.activateWork(shardedKey2, work3); + computationState2.activateWork(work2); + computationState2.activateWork(work3); // GetWorkBudget should have 3 items. 1 from computationId, 2 from computationId2. 
assertThat(computationStateCache.totalCurrentActiveGetWorkBudget()) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index cfad61385476..85e07c3bd797 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -31,6 +31,7 @@ import org.apache.beam.runners.dataflow.worker.FakeWindmillServer; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -52,17 +53,23 @@ public class StreamingApplianceWorkCommitterTest { private FakeWindmillServer fakeWindmillServer; private StreamingApplianceWorkCommitter workCommitter; - private static Work createMockWork(long workToken, Consumer processWorkFn) { + private static Work createMockWork(long workToken) { return Work.create( Windmill.WorkItem.newBuilder() .setKey(ByteString.EMPTY) .setWorkToken(workToken) - .setShardingKey(workToken) - .setCacheToken(workToken) + .setCacheToken(1L) + .setShardingKey(2L) .build(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> { + throw new UnsupportedOperationException(); + }), Instant::now, - Collections.emptyList(), - processWorkFn); + Collections.emptyList()); } private static ComputationState createComputationState(String computationId) { @@ -97,7 +104,7 @@ public void testCommit() { workCommitter = createWorkCommitter(completeCommits::add); List commits = new ArrayList<>(); for (int i = 1; i <= 5; i++) { - Work work = createMockWork(i, ignored -> {}); + Work work = createMockWork(i); Windmill.WorkItemCommitRequest commitRequest = Windmill.WorkItemCommitRequest.newBuilder() .setKey(work.getWorkItem().getKey()) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 49c61b9b8ab2..d53690938aef 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -40,6 +40,7 @@ import java.util.function.Supplier; import org.apache.beam.runners.dataflow.worker.FakeWindmillServer; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import 
org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.WorkId; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; @@ -69,17 +70,23 @@ public class StreamingEngineWorkCommitterTest { private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; - private static Work createMockWork(long workToken, Consumer processWorkFn) { + private static Work createMockWork(long workToken) { return Work.create( Windmill.WorkItem.newBuilder() .setKey(ByteString.EMPTY) .setWorkToken(workToken) - .setShardingKey(workToken) - .setCacheToken(workToken) + .setCacheToken(1L) + .setShardingKey(2L) .build(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> { + throw new UnsupportedOperationException(); + }), Instant::now, - Collections.emptyList(), - processWorkFn); + Collections.emptyList()); } private static ComputationState createComputationState(String computationId) { @@ -126,7 +133,7 @@ public void testCommit_sendsCommitsToStreamingEngine() { workCommitter = createWorkCommitter(completeCommits::add); List commits = new ArrayList<>(); for (int i = 1; i <= 5; i++) { - Work work = createMockWork(i, ignored -> {}); + Work work = createMockWork(i); WorkItemCommitRequest commitRequest = WorkItemCommitRequest.newBuilder() .setKey(work.getWorkItem().getKey()) @@ -157,7 +164,7 @@ public void testCommit_handlesFailedCommits() { workCommitter = createWorkCommitter(completeCommits::add); List commits = new ArrayList<>(); for (int i = 1; i <= 10; i++) { - Work work = createMockWork(i, ignored -> {}); + Work work = createMockWork(i); // Fail half of the work. 
if (i % 2 == 0) { work.setFailed(); @@ -213,7 +220,7 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { List commits = new ArrayList<>(); for (int i = 1; i <= 10; i++) { - Work work = createMockWork(i, ignored -> {}); + Work work = createMockWork(i); WorkItemCommitRequest commitRequest = WorkItemCommitRequest.newBuilder() .setKey(work.getWorkItem().getKey()) @@ -278,6 +285,7 @@ public Instant startTime() { return Instant.now(); } }; + commitWorkStreamFactory = WindmillStreamPool.create(1, Duration.standardMinutes(1), fakeCommitWorkStream) ::getCloseableStream; @@ -287,7 +295,7 @@ public Instant startTime() { List commits = new ArrayList<>(); for (int i = 1; i <= 10; i++) { - Work work = createMockWork(i, ignored -> {}); + Work work = createMockWork(i); WorkItemCommitRequest commitRequest = WorkItemCommitRequest.newBuilder() .setKey(work.getWorkItem().getKey()) @@ -323,7 +331,7 @@ public void testMultipleCommitSendersSingleStream() { StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 5, completeCommits::add); List commits = new ArrayList<>(); for (int i = 1; i <= 500; i++) { - Work work = createMockWork(i, ignored -> {}); + Work work = createMockWork(i); WorkItemCommitRequest commitRequest = WorkItemCommitRequest.newBuilder() .setKey(work.getWorkItem().getKey()) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java index 2efd1054822e..9822daa91567 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java @@ -23,6 +23,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -37,7 +38,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; @@ -47,10 +47,11 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; @@ -108,20 +109,17 @@ public class StreamingEngineClientTest { private final GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( stubFactory, new ArrayList<>(), new ArrayList<>(), new HashSet<>()); - private final AtomicReference connections = - new AtomicReference<>(StreamingEngineConnectionState.EMPTY); private Server fakeStreamingEngineServer; private CountDownLatch getWorkerMetadataReady; private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub; private StreamingEngineClient streamingEngineClient; - private static WorkItemProcessor noOpProcessWorkItemFn() { - return (computation, - inputDataWatermark, - synchronizedProcessingTime, - workItem, - ackQueuedWorkItem, + private static WorkItemScheduler noOpProcessWorkItemFn() { + return (workItem, + watermarks, + processingContext, + ackWorkItemQueued, getWorkStreamLatencies) -> {}; } @@ -173,17 +171,18 @@ public void cleanUp() { private StreamingEngineClient newStreamingEngineClient( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, - WorkItemProcessor workItemProcessor) { + WorkItemScheduler workItemScheduler) { return StreamingEngineClient.forTesting( JOB_HEADER, getWorkBudget, - connections, streamFactory, - workItemProcessor, + workItemScheduler, stubFactory, getWorkBudgetDistributor, dispatcherClient, - CLIENT_ID); + CLIENT_ID, + ignored -> mock(WorkCommitter.class), + ignored -> {}); } @Test @@ -216,7 +215,8 @@ public void testStreamsStartCorrectly() throws InterruptedException { fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = connections.get(); + StreamingEngineConnectionState currentConnections = + streamingEngineClient.getCurrentConnections(); assertEquals(2, currentConnections.windmillConnections().size()); assertEquals(2, currentConnections.windmillStreams().size()); @@ -238,7 +238,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { .createDirectGetWorkStream( any(), eq(getWorkRequest(0, 0)), any(), any(), any(), eq(noOpProcessWorkItemFn())); - verify(streamFactory, times(2)).createGetDataStream(any(), any()); + verify(streamFactory, times(2)).createGetDataStream(any(), any(), eq(false), any()); verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); } @@ -306,11 +306,12 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = connections.get(); + StreamingEngineConnectionState currentConnections = + streamingEngineClient.getCurrentConnections(); assertEquals(1, currentConnections.windmillConnections().size()); assertEquals(1, currentConnections.windmillStreams().size()); Set workerTokens = - connections.get().windmillConnections().values().stream() + streamingEngineClient.getCurrentConnections().windmillConnections().values().stream() .map(WindmillConnection::backendWorkerToken) .filter(Optional::isPresent) .map(Optional::get) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java index 2532fca51549..496f69dc52d8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSenderTest.java @@ -33,8 +33,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; @@ -49,11 +50,9 @@ @RunWith(JUnit4.class) public class WindmillStreamSenderTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final GetWorkRequest GET_WORK_REQUEST = GetWorkRequest.newBuilder().setClientId(1L).setJobId("job").setProjectId("project").build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private final GrpcWindmillStreamFactory streamFactory = spy( GrpcWindmillStreamFactory.of( @@ -63,13 +62,9 @@ public class WindmillStreamSenderTest { .setWorkerId("worker") .build()) .build()); - private final WorkItemProcessor workItemProcessor = - (computation, - inputDataWatermark, - synchronizedProcessingTime, - workItem, - ackQueuedWorkItem, - getWorkStreamLatencies) -> {}; + private final WorkItemScheduler workItemScheduler = + (workItem, watermarks, processingContext, ackWorkItemQueued, getWorkStreamLatencies) -> {}; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private CloudWindmillServiceV1Alpha1Stub stub; @@ -110,9 +105,9 @@ public void testStartStream_startsAllStreams() { any(ThrottleTimer.class), any(), any(), - eq(workItemProcessor)); + eq(workItemScheduler)); - verify(streamFactory).createGetDataStream(eq(stub), any(ThrottleTimer.class)); + verify(streamFactory).createGetDataStream(eq(stub), any(ThrottleTimer.class), eq(false), any()); verify(streamFactory).createCommitWorkStream(eq(stub), any(ThrottleTimer.class)); } @@ -141,9 +136,10 @@ public void testStartStream_onlyStartsStreamsOnce() { any(ThrottleTimer.class), any(), any(), - eq(workItemProcessor)); + eq(workItemScheduler)); - verify(streamFactory, times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class)); + verify(streamFactory, times(1)) + .createGetDataStream(eq(stub), any(ThrottleTimer.class), eq(false), any()); verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class)); } @@ -175,9 +171,10 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted any(ThrottleTimer.class), any(), any(), - eq(workItemProcessor)); + eq(workItemScheduler)); - verify(streamFactory, 
times(1)).createGetDataStream(eq(stub), any(ThrottleTimer.class)); + verify(streamFactory, times(1)) + .createGetDataStream(eq(stub), any(ThrottleTimer.class), eq(false), any()); verify(streamFactory, times(1)).createCommitWorkStream(eq(stub), any(ThrottleTimer.class)); } @@ -208,10 +205,11 @@ public void testCloseAllStreams_closesAllStreams() { any(ThrottleTimer.class), any(), any(), - eq(workItemProcessor))) + eq(workItemScheduler))) .thenReturn(mockGetWorkStream); - when(mockStreamFactory.createGetDataStream(eq(stub), any(ThrottleTimer.class))) + when(mockStreamFactory.createGetDataStream( + eq(stub), any(ThrottleTimer.class), eq(false), any())) .thenReturn(mockGetDataStream); when(mockStreamFactory.createCommitWorkStream(eq(stub), any(ThrottleTimer.class))) .thenReturn(mockCommitWorkStream); @@ -236,6 +234,12 @@ private WindmillStreamSender newWindmillStreamSender(GetWorkBudget budget) { private WindmillStreamSender newWindmillStreamSender( GetWorkBudget budget, GrpcWindmillStreamFactory streamFactory) { return WindmillStreamSender.create( - stub, GET_WORK_REQUEST, budget, streamFactory, workItemProcessor); + stub, + GET_WORK_REQUEST, + budget, + streamFactory, + workItemScheduler, + ignored -> mock(WorkCommitter.class), + ignored -> {}); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index 249642aa6d1e..4fa424412ee0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -30,6 +31,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; @@ -46,9 +48,8 @@ @RunWith(JUnit4.class) public class EvenGetWorkBudgetDistributorTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub; @@ -257,11 +258,8 @@ private WindmillStreamSender createWindmillStreamSender(GetWorkBudget getWorkBud .setWorkerId("worker") .build()) .build(), - (computation, - inputDataWatermark, - synchronizedProcessingTime, - workItem, - ackQueuedWorkItem, - getWorkStreamLatencies) -> {}); + (workItem, watermarks, 
processingContext, ackWorkItemQueued, getWorkStreamLatencies) -> {}, + ignored -> mock(WorkCommitter.class), + ignored -> {}); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index 05b92e73f0ca..bd55595da135 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -18,9 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import java.util.ArrayList; import java.util.HashSet; @@ -32,6 +29,8 @@ import java.util.function.Supplier; import org.apache.beam.runners.dataflow.worker.KeyTokenInvalidException; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -80,43 +79,42 @@ private static FailureTracker streamingApplianceFailureReporter(boolean isWorkFa ignored -> Windmill.ReportStatsResponse.newBuilder().setFailed(isWorkFailed).build()); } - private static Work createWork(Supplier clock, Consumer processWorkFn) { - return Work.create( - Windmill.WorkItem.newBuilder() - .setKey(ByteString.EMPTY) - .setWorkToken(1L) - .setCacheToken(1L) - .setShardingKey(1L) - .build(), - clock, - new ArrayList<>(), + private static ExecutableWork createWork(Supplier clock, Consumer processWorkFn) { + return ExecutableWork.create( + Work.create( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(1L).build(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), + clock, + new ArrayList<>()), processWorkFn); } - private static Work createWork() { - return createWork(Instant::now, ignored -> {}); - } - - private static Work createWork(Consumer processWorkFn) { + private static ExecutableWork createWork(Consumer processWorkFn) { return createWork(Instant::now, processWorkFn); } @Test public void logAndProcessFailure_doesNotRetryKeyTokenInvalidException() { - Work work = spy(createWork()); + Set executedWork = new HashSet<>(); + ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); workFailureProcessor.logAndProcessFailure( DEFAULT_COMPUTATION_ID, work, new KeyTokenInvalidException("key"), invalidWork::add); - verify(work, times(0)).run(); - assertThat(invalidWork).containsExactly(work); + assertThat(executedWork).isEmpty(); + 
assertThat(invalidWork).containsExactly(work.work()); } @Test public void logAndProcessFailure_doesNotRetryWhenWorkItemCancelled() { - Work work = spy(createWork()); + Set executedWork = new HashSet<>(); + ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); @@ -126,55 +124,58 @@ public void logAndProcessFailure_doesNotRetryWhenWorkItemCancelled() { new WorkItemCancelledException(work.getWorkItem().getShardingKey()), invalidWork::add); - verify(work, times(0)).run(); - assertThat(invalidWork).containsExactly(work); + assertThat(executedWork).isEmpty(); + assertThat(invalidWork).containsExactly(work.work()); } @Test public void logAndProcessFailure_doesNotRetryOOM() { - Work work = spy(createWork()); + Set executedWork = new HashSet<>(); + ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); workFailureProcessor.logAndProcessFailure( DEFAULT_COMPUTATION_ID, work, new OutOfMemoryError(), invalidWork::add); - verify(work, times(0)).run(); - assertThat(invalidWork).containsExactly(work); + assertThat(executedWork).isEmpty(); + assertThat(invalidWork).containsExactly(work.work()); } @Test public void logAndProcessFailure_doesNotRetryWhenFailureReporterMarksAsNonRetryable() { - Work work = spy(createWork()); + Set executedWork = new HashSet<>(); + ExecutableWork work = createWork(executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(true)); Set invalidWork = new HashSet<>(); workFailureProcessor.logAndProcessFailure( DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); - verify(work, times(0)).run(); - assertThat(invalidWork).containsExactly(work); + assertThat(executedWork).isEmpty(); + assertThat(invalidWork).containsExactly(work.work()); } @Test public void logAndProcessFailure_doesNotRetryAfterLocalRetryTimeout() { - Work veryOldWork = - spy(createWork(() -> Instant.now().minus(Duration.standardDays(30)), ignored -> {})); + Set executedWork = new HashSet<>(); + ExecutableWork veryOldWork = + createWork(() -> Instant.now().minus(Duration.standardDays(30)), executedWork::add); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); workFailureProcessor.logAndProcessFailure( DEFAULT_COMPUTATION_ID, veryOldWork, new RuntimeException(), invalidWork::add); - verify(veryOldWork, times(0)).run(); - assertThat(invalidWork).contains(veryOldWork); + assertThat(executedWork).isEmpty(); + assertThat(invalidWork).contains(veryOldWork.work()); } @Test public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingEngine() throws InterruptedException { CountDownLatch runWork = new CountDownLatch(1); - Work work = spy(createWork(ignored -> runWork.countDown())); + ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingEngineFailureReporter()); Set invalidWork = new HashSet<>(); @@ -182,7 +183,6 @@ public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingEn DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); runWork.await(); - verify(work, times(1)).run(); assertThat(invalidWork).isEmpty(); } @@ -190,7 +190,7 @@ 
public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingEn public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingAppliance() throws InterruptedException { CountDownLatch runWork = new CountDownLatch(1); - Work work = spy(createWork(ignored -> runWork.countDown())); + ExecutableWork work = createWork(ignored -> runWork.countDown()); WorkFailureProcessor workFailureProcessor = createWorkFailureProcessor(streamingApplianceFailureReporter(false)); Set invalidWork = new HashSet<>(); @@ -198,7 +198,6 @@ public void logAndProcessFailure_retriesOnUncaughtUnhandledException_streamingAp DEFAULT_COMPUTATION_ID, work, new RuntimeException(), invalidWork::add); runWork.await(); - verify(work, times(1)).run(); assertThat(invalidWork).isEmpty(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java index 31e354042584..175c8421ff8e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java @@ -40,7 +40,9 @@ import java.util.function.Supplier; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -106,16 +108,23 @@ private ActiveWorkRefresher createActiveWorkRefresher( Executors.newSingleThreadScheduledExecutor()); } - private Work createOldWork(int workIds, Consumer processWork) { - return Work.create( - Windmill.WorkItem.newBuilder() - .setWorkToken(workIds) - .setCacheToken(workIds) - .setKey(ByteString.EMPTY) - .setShardingKey(workIds) - .build(), - DispatchedActiveWorkRefresherTest.A_LONG_TIME_AGO, - ImmutableList.of(), + private ExecutableWork createOldWork( + ShardedKey shardedKey, int workIds, Consumer processWork) { + return ExecutableWork.create( + Work.create( + Windmill.WorkItem.newBuilder() + .setKey(shardedKey.key()) + .setShardingKey(shardedKey.shardingKey()) + .setWorkToken(workIds) + .setCacheToken(workIds) + .build(), + Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), + Work.createProcessingContext( + "computationId", + (a, b) -> Windmill.KeyedGetDataResponse.getDefaultInstance(), + ignored -> {}), + DispatchedActiveWorkRefresherTest.A_LONG_TIME_AGO, + ImmutableList.of()), processWork); } @@ -135,14 +144,15 @@ public void testActiveWorkRefresh() throws InterruptedException { }; List computations = new ArrayList<>(); - Map> computationsAndWork = new HashMap<>(); + Map> computationsAndWork = new HashMap<>(); for (int i = 0; i < 5; i++) { ComputationState computationState = createComputationState(i); - Work fakeWork = createOldWork(i, processWork); - 
computationState.activateWork(ShardedKey.create(ByteString.EMPTY, i), fakeWork); + ExecutableWork fakeWork = + createOldWork(ShardedKey.create(ByteString.EMPTY, i), i, processWork); + computationState.activateWork(fakeWork); computations.add(computationState); - List activeWorkForComputation = + List activeWorkForComputation = computationsAndWork.computeIfAbsent( computationState.getComputationId(), ignored -> new ArrayList<>()); activeWorkForComputation.add(fakeWork); @@ -173,13 +183,13 @@ public void testActiveWorkRefresh() throws InterruptedException { expectedHeartbeats.entrySet()) { String computationId = expectedHeartbeat.getKey(); List heartbeatRequests = expectedHeartbeat.getValue(); - List work = computationsAndWork.get(computationId); + List work = computationsAndWork.get(computationId); // Compare the heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys. assertThat(heartbeatRequests) .comparingElementsUsing( Correspondence.from( - (HeartbeatRequest h, Work w) -> + (HeartbeatRequest h, ExecutableWork w) -> h.getWorkToken() == w.getWorkItem().getWorkToken() && h.getCacheToken() == w.getWorkItem().getWorkToken() && h.getShardingKey() == w.getWorkItem().getShardingKey(), @@ -195,7 +205,7 @@ public void testActiveWorkRefresh() throws InterruptedException { @Test public void testInvalidateStuckCommits() throws InterruptedException { int stuckCommitDurationMillis = 100; - Table computations = + Table computations = HashBasedTable.create(); WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(100); ByteString key = ByteString.EMPTY; @@ -203,9 +213,9 @@ public void testInvalidateStuckCommits() throws InterruptedException { WindmillStateCache.ForComputation perComputationStateCache = spy(stateCache.forComputation(COMPUTATION_ID_PREFIX + i)); ComputationState computationState = spy(createComputationState(i, perComputationStateCache)); - Work fakeWork = createOldWork(i, ignored -> {}); - fakeWork.setState(Work.State.COMMITTING); - computationState.activateWork(ShardedKey.create(key, i), fakeWork); + ExecutableWork fakeWork = createOldWork(ShardedKey.create(key, i), i, ignored -> {}); + fakeWork.work().setState(Work.State.COMMITTING); + computationState.activateWork(fakeWork); computations.put(computationState, fakeWork, perComputationStateCache); } @@ -237,10 +247,10 @@ public void testInvalidateStuckCommits() throws InterruptedException { invalidateStuckCommitRan.await(); activeWorkRefresher.stop(); - for (Table.Cell cell : + for (Table.Cell cell : computations.cellSet()) { ComputationState computation = cell.getRowKey(); - Work work = cell.getColumnKey(); + ExecutableWork work = cell.getColumnKey(); WindmillStateCache.ForComputation perComputationStateCache = cell.getValue(); verify(perComputationStateCache, times(1)) .invalidate(eq(key), eq(work.getWorkItem().getShardingKey()));