diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json index 03d86a8d023e..dd2bf3aeb361 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json @@ -3,5 +3,6 @@ "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test", "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json index 9b023f630c36..74f4220571e5 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json @@ -2,5 +2,6 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", - "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test" + "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json index 03d86a8d023e..dd2bf3aeb361 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json @@ -3,5 +3,6 @@ "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test", "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test" } diff --git a/CHANGES.md b/CHANGES.md index 7707e252961b..06b92953c662 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -71,11 +71,12 @@ ## New Features / Improvements +* Added support for stateful processing in Spark Runner for streaming pipelines. Timer functionality is not yet supported and will be implemented in a future release ([#33237](https://github.com/apache/beam/issues/33237)). * Improved batch performance of SparkRunner's GroupByKey ([#20943](https://github.com/apache/beam/pull/20943)). -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)). * This enables initial Java GroupIntoBatches support. * Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)). +* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle index f4e6bf740189..297facd4bc0d 100644 --- a/runners/spark/spark_runner.gradle +++ b/runners/spark/spark_runner.gradle @@ -345,7 +345,7 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test) excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' // State and Timers - excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo' + excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages' excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo' excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer' diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java index 619d2d16173d..44f8d6df683b 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java @@ -21,7 +21,7 @@ import java.util.ArrayList; import java.util.LinkedHashMap; import org.apache.beam.runners.spark.io.MicrobatchSource; -import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet.StateAndTimers; +import org.apache.beam.runners.spark.stateful.StateAndTimers; import org.apache.beam.runners.spark.translation.ValueAndCoderKryoSerializer; import org.apache.beam.runners.spark.translation.ValueAndCoderLazySerializable; import org.apache.beam.runners.spark.util.ByteArray; diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index c24841c7dd31..b18b31a67463 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -17,6 +17,9 @@ */ package org.apache.beam.runners.spark.stateful; +import static org.apache.beam.runners.spark.translation.TranslationUtils.checkpointIfNeeded; +import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration; + import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -35,7 +38,6 @@ import org.apache.beam.runners.core.metrics.MetricsContainerImpl; import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachines; -import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.translation.ReifyTimestampsAndWindowsFunction; import org.apache.beam.runners.spark.translation.TranslationUtils; @@ -60,10 +62,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table; import org.apache.spark.api.java.JavaSparkContext$; import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.dstream.DStream; @@ -100,27 +100,6 @@ public class SparkGroupAlsoByWindowViaWindowSet implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class); - /** State and Timers wrapper. */ - public static class StateAndTimers implements Serializable { - // Serializable state for internals (namespace to state tag to coded value). - private final Table state; - private final Collection serTimers; - - private StateAndTimers( - final Table state, final Collection timers) { - this.state = state; - this.serTimers = timers; - } - - Table getState() { - return state; - } - - Collection getTimers() { - return serTimers; - } - } - private static class OutputWindowedValueHolder implements OutputWindowedValue>> { private final List>>> windowedValues = new ArrayList<>(); @@ -348,7 +327,7 @@ private Collection filterTimersEligibleForProcessing( // empty outputs are filtered later using DStream filtering final StateAndTimers updated = - new StateAndTimers( + StateAndTimers.of( stateInternals.getState(), SparkTimerInternals.serializeTimers( timerInternals.getTimers(), timerDataCoder)); @@ -466,21 +445,6 @@ private static TimerInternals.TimerDataCoderV2 timerDa return TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder()); } - private static void checkpointIfNeeded( - final DStream>>> firedStream, - final SerializablePipelineOptions options) { - - final Long checkpointDurationMillis = getBatchDuration(options); - - if (checkpointDurationMillis > 0) { - firedStream.checkpoint(new Duration(checkpointDurationMillis)); - } - } - - private static Long getBatchDuration(final SerializablePipelineOptions options) { - return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis(); - } - private static JavaDStream>>> stripStateValues( final DStream>>> firedStream, final Coder keyCoder, diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java index 5890662307fb..77ae042d81fa 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java @@ -63,7 +63,7 @@ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -class SparkStateInternals implements StateInternals { +public class SparkStateInternals implements StateInternals { private final K key; // Serializable state for internals (namespace to state tag to coded value). @@ -79,11 +79,11 @@ private SparkStateInternals(K key, Table stateTable) { this.stateTable = stateTable; } - static SparkStateInternals forKey(K key) { + public static SparkStateInternals forKey(K key) { return new SparkStateInternals<>(key); } - static SparkStateInternals forKeyAndState( + public static SparkStateInternals forKeyAndState( K key, Table stateTable) { return new SparkStateInternals<>(key, stateTable); } @@ -412,7 +412,7 @@ public void put(MapKeyT key, MapValueT value) { @Override public ReadableState computeIfAbsent( MapKeyT key, Function mappingFunction) { - Map sparkMapState = readValue(); + Map sparkMapState = readAsMap(); MapValueT current = sparkMapState.get(key); if (current == null) { put(key, mappingFunction.apply(key)); @@ -420,9 +420,17 @@ public ReadableState computeIfAbsent( return ReadableStates.immediate(current); } + private Map readAsMap() { + Map mapState = readValue(); + if (mapState == null) { + mapState = new HashMap<>(); + } + return mapState; + } + @Override public void remove(MapKeyT key) { - Map sparkMapState = readValue(); + Map sparkMapState = readAsMap(); sparkMapState.remove(key); writeValue(sparkMapState); } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java index de9820e1255c..8b647c42dd7e 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java @@ -107,7 +107,7 @@ public Collection getTimers() { return timers; } - void addTimers(Iterator timers) { + public void addTimers(Iterator timers) { while (timers.hasNext()) { TimerData timer = timers.next(); this.timers.add(timer); @@ -163,7 +163,8 @@ public void setTimer( Instant target, Instant outputTimestamp, TimeDomain timeDomain) { - throw new UnsupportedOperationException("Setting a timer by ID not yet supported."); + this.setTimer( + TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain)); } @Override diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateAndTimers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateAndTimers.java new file mode 100644 index 000000000000..83eaddde5532 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateAndTimers.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.stateful; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import java.util.Collection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table; + +/** State and Timers wrapper. */ +@AutoValue +public abstract class StateAndTimers implements Serializable { + public abstract Table getState(); + + public abstract Collection getTimers(); + + public static StateAndTimers of( + final Table state, final Collection timers) { + return new AutoValue_StateAndTimers.Builder().setState(state).setTimers(timers).build(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setState(Table state); + + abstract Builder setTimers(Collection timers); + + abstract StateAndTimers build(); + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java index 8bbcb1f2941a..34836cd6e7ae 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java @@ -31,12 +31,12 @@ import org.joda.time.Instant; /** DoFnRunner decorator which registers {@link MetricsContainerImpl}. */ -class DoFnRunnerWithMetrics implements DoFnRunner { +public class DoFnRunnerWithMetrics implements DoFnRunner { private final DoFnRunner delegate; private final String stepName; private final MetricsContainerStepMapAccumulator metricsAccum; - DoFnRunnerWithMetrics( + public DoFnRunnerWithMetrics( String stepName, DoFnRunner delegate, MetricsContainerStepMapAccumulator metricsAccum) { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java index 0af480a2ff02..4b4d23b0c47c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkInputDataProcessor.java @@ -47,7 +47,7 @@ * Processes Spark's input data iterators using Beam's {@link * org.apache.beam.runners.core.DoFnRunner}. */ -interface SparkInputDataProcessor { +public interface SparkInputDataProcessor { /** * @return {@link OutputManager} to be used by {@link org.apache.beam.runners.core.DoFnRunner} for diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java index 5487bb1be73c..bbcd74dc408b 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java @@ -23,14 +23,14 @@ import org.apache.beam.sdk.transforms.DoFn; /** Holds current processing context for {@link SparkInputDataProcessor}. */ -class SparkProcessContext { +public class SparkProcessContext { private final String stepName; private final DoFn doFn; private final DoFnRunner doFnRunner; private final Iterator timerDataIterator; private final K key; - SparkProcessContext( + public SparkProcessContext( String stepName, DoFn doFn, DoFnRunner doFnRunner, diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java index 23af6f71b938..f2455e64b956 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -26,6 +26,8 @@ import org.apache.beam.runners.core.InMemoryStateInternals; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.SparkRunner; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.util.ByteArray; @@ -54,8 +56,10 @@ import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.dstream.DStream; import scala.Tuple2; /** A set of utilities to help translating Beam transformations into Spark transformations. */ @@ -258,6 +262,52 @@ public Boolean call(Tuple2, WindowedValue> input) { } } + /** + * Retrieves the batch duration in milliseconds from Spark pipeline options. + * + * @param options The serializable pipeline options containing Spark-specific settings + * @return The checkpoint duration in milliseconds as specified in SparkPipelineOptions + */ + public static Long getBatchDuration(final SerializablePipelineOptions options) { + return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis(); + } + + /** + * Reject timers {@link DoFn}. + * + * @param doFn the {@link DoFn} to possibly reject. + */ + public static void rejectTimers(DoFn doFn) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + if (signature.timerDeclarations().size() > 0 + || signature.timerFamilyDeclarations().size() > 0) { + throw new UnsupportedOperationException( + String.format( + "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.", + DoFn.TimerId.class.getSimpleName(), + doFn.getClass().getName(), + DoFn.class.getSimpleName(), + SparkRunner.class.getSimpleName())); + } + } + + /** + * Checkpoints the given DStream if checkpointing is enabled in the pipeline options. + * + * @param dStream The DStream to be checkpointed + * @param options The SerializablePipelineOptions containing configuration settings including + * batch duration + */ + public static void checkpointIfNeeded( + final DStream dStream, final SerializablePipelineOptions options) { + + final Long checkpointDurationMillis = getBatchDuration(options); + + if (checkpointDurationMillis > 0) { + dStream.checkpoint(new Duration(checkpointDurationMillis)); + } + } + /** * Reject state and timers {@link DoFn}. * diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java new file mode 100644 index 000000000000..82557c3b972b --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StatefulDoFnRunner; +import org.apache.beam.runners.core.StepContext; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator; +import org.apache.beam.runners.spark.stateful.SparkStateInternals; +import org.apache.beam.runners.spark.stateful.SparkTimerInternals; +import org.apache.beam.runners.spark.stateful.StateAndTimers; +import org.apache.beam.runners.spark.translation.DoFnRunnerWithMetrics; +import org.apache.beam.runners.spark.translation.SparkInputDataProcessor; +import org.apache.beam.runners.spark.translation.SparkProcessContext; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.beam.runners.spark.util.CachedSideInputReader; +import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; +import org.apache.beam.runners.spark.util.SideInputBroadcast; +import org.apache.beam.runners.spark.util.SparkSideInputReader; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.spark.streaming.State; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; +import scala.Tuple2; +import scala.runtime.AbstractFunction3; + +/** + * A function to handle stateful processing in Apache Beam's SparkRunner. This class processes + * stateful DoFn operations by managing state updates in a Spark streaming context. + * + *

Current Implementation Status: + * + *

    + *
  • State: Fully implemented and supported through {@link SparkStateInternals} + *
  • Timers: Not supported. While {@link SparkTimerInternals} is present in the code, timer + * functionality is not yet fully implemented and operational + *
+ * + * @param The type of the key in the input KV pairs + * @param The type of the value in the input KV pairs + * @param The input type, must be a KV of KeyT and ValueT + * @param The output type produced by the DoFn + */ +@SuppressWarnings({"rawtypes", "unchecked"}) +public class ParDoStateUpdateFn, OutputT> + extends AbstractFunction3< + /*Serialized KeyT*/ ByteArray, + Option*/ byte[]>, + /*State*/ State, + List, /*Serialized WindowedValue*/ byte[]>>> + implements Serializable { + + @SuppressWarnings("unused") + private static final Logger LOG = LoggerFactory.getLogger(ParDoStateUpdateFn.class); + + private final MetricsContainerStepMapAccumulator metricsAccum; + private final String stepName; + private final DoFn doFn; + private final Coder keyCoder; + private final WindowedValue.FullWindowedValueCoder wvCoder; + private transient boolean wasSetupCalled; + private final SerializablePipelineOptions options; + private final TupleTag mainOutputTag; + private final List> additionalOutputTags; + private final Coder inputCoder; + private final Map, Coder> outputCoders; + private final Map, KV, SideInputBroadcast>> sideInputs; + private final WindowingStrategy windowingStrategy; + private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; + // for timer + private final Map watermarks; + private final List sourceIds; + private final TimerInternals.TimerDataCoderV2 timerDataCoder; + + public ParDoStateUpdateFn( + MetricsContainerStepMapAccumulator metricsAccum, + String stepName, + DoFn doFn, + Coder keyCoder, + WindowedValue.FullWindowedValueCoder wvCoder, + SerializablePipelineOptions options, + TupleTag mainOutputTag, + List> additionalOutputTags, + Coder inputCoder, + Map, Coder> outputCoders, + Map, KV, SideInputBroadcast>> sideInputs, + WindowingStrategy windowingStrategy, + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, + Map watermarks, + List sourceIds) { + this.metricsAccum = metricsAccum; + this.stepName = stepName; + this.doFn = SerializableUtils.clone(doFn); + this.options = options; + this.mainOutputTag = mainOutputTag; + this.additionalOutputTags = additionalOutputTags; + this.keyCoder = keyCoder; + this.inputCoder = inputCoder; + this.outputCoders = outputCoders; + this.wvCoder = wvCoder; + this.sideInputs = sideInputs; + this.windowingStrategy = windowingStrategy; + this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; + this.watermarks = watermarks; + this.sourceIds = sourceIds; + this.timerDataCoder = + TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder()); + } + + @Override + public List, /*Serialized WindowedValue*/ byte[]>> + apply(ByteArray serializedKey, Option serializedValue, State state) { + if (serializedValue.isEmpty()) { + return Lists.newArrayList(); + } + + SparkStateInternals stateInternals; + final SparkTimerInternals timerInternals = + SparkTimerInternals.forStreamFromSources(sourceIds, watermarks); + final KeyT key = CoderHelpers.fromByteArray(serializedKey.getValue(), this.keyCoder); + + if (state.exists()) { + final StateAndTimers stateAndTimers = state.get(); + stateInternals = SparkStateInternals.forKeyAndState(key, stateAndTimers.getState()); + timerInternals.addTimers( + SparkTimerInternals.deserializeTimers(stateAndTimers.getTimers(), timerDataCoder)); + } else { + stateInternals = SparkStateInternals.forKey(key); + } + + final byte[] byteValue = serializedValue.get(); + final WindowedValue windowedValue = CoderHelpers.fromByteArray(byteValue, this.wvCoder); + + final WindowedValue> keyedWindowedValue = + windowedValue.withValue(KV.of(key, windowedValue.getValue())); + + if (!wasSetupCalled) { + DoFnInvokers.tryInvokeSetupFor(this.doFn, this.options.get()); + this.wasSetupCalled = true; + } + + SparkInputDataProcessor, WindowedValue>> processor = + SparkInputDataProcessor.createUnbounded(); + + final StepContext context = + new StepContext() { + @Override + public StateInternals stateInternals() { + return stateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + }; + + DoFnRunner doFnRunner = + DoFnRunners.simpleRunner( + options.get(), + doFn, + CachedSideInputReader.of(new SparkSideInputReader(sideInputs)), + processor.getOutputManager(), + (TupleTag) mainOutputTag, + additionalOutputTags, + context, + inputCoder, + outputCoders, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping); + + final Coder windowCoder = + windowingStrategy.getWindowFn().windowCoder(); + + final StatefulDoFnRunner.CleanupTimer cleanUpTimer = + new StatefulDoFnRunner.TimeInternalsCleanupTimer<>(timerInternals, windowingStrategy); + + final StatefulDoFnRunner.StateCleaner stateCleaner = + new StatefulDoFnRunner.StateInternalsStateCleaner<>(doFn, stateInternals, windowCoder); + + doFnRunner = + DoFnRunners.defaultStatefulDoFnRunner( + doFn, inputCoder, doFnRunner, context, windowingStrategy, cleanUpTimer, stateCleaner); + + DoFnRunnerWithMetrics doFnRunnerWithMetrics = + new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum); + + SparkProcessContext ctx = + new SparkProcessContext<>( + stepName, doFn, doFnRunnerWithMetrics, key, timerInternals.getTimers().iterator()); + + final Iterator>> iterator = + Lists.newArrayList(keyedWindowedValue).iterator(); + + final Iterator, WindowedValue>> outputIterator = + processor.createOutputIterator((Iterator) iterator, ctx); + + state.update( + StateAndTimers.of( + stateInternals.getState(), + SparkTimerInternals.serializeTimers(timerInternals.getTimers(), timerDataCoder))); + + final List, WindowedValue>> resultList = + Lists.newArrayList(outputIterator); + + return (List, byte[]>>) + (List) + resultList.stream() + .map( + (Tuple2, WindowedValue> e) -> { + final TupleTag tupleTag = (TupleTag) e._1(); + final Coder outputCoder = + (Coder) outputCoders.get(tupleTag); + + @SuppressWarnings("nullness") + final WindowedValue.FullWindowedValueCoder outputWindowCoder = + WindowedValue.FullWindowedValueCoder.of(outputCoder, windowCoder); + + return Tuple2.apply( + tupleTag, + CoderHelpers.toByteArray((WindowedValue) e._2(), outputWindowCoder)); + }) + .collect(Collectors.toList()); + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java new file mode 100644 index 000000000000..23bcfcb129ce --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration; +import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectTimers; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator; +import org.apache.beam.runners.spark.stateful.StateAndTimers; +import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.SparkPCollectionView; +import org.apache.beam.runners.spark.translation.TransformEvaluator; +import org.apache.beam.runners.spark.translation.TranslationUtils; +import org.apache.beam.runners.spark.util.ByteArray; +import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; +import org.apache.beam.runners.spark.util.SideInputBroadcast; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.construction.ParDoTranslation; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import scala.Option; +import scala.Tuple2; + +/** + * A specialized evaluator for ParDo operations in Spark Streaming context that is invoked when + * stateful streaming is detected in the DoFn. + * + *

This class is used by {@link StreamingTransformTranslator}'s ParDo evaluator to handle + * stateful streaming operations. When a DoFn contains stateful processing logic, the translation + * process routes the execution through this evaluator instead of the standard ParDo evaluator. + * + *

The evaluator manages state handling and ensures proper processing semantics for streaming + * stateful operations in the Spark runner context. + * + *

Important: This evaluator includes validation logic that rejects DoFn implementations + * containing {@code @Timer} annotations, as timer functionality is not currently supported in the + * Spark streaming context. + */ +public class StatefulStreamingParDoEvaluator + implements TransformEvaluator, OutputT>> { + + @Override + public void evaluate( + ParDo.MultiOutput, OutputT> transform, EvaluationContext context) { + final DoFn, OutputT> doFn = transform.getFn(); + final DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn); + + rejectTimers(doFn); + checkArgument( + !signature.processElement().isSplittable(), + "Splittable DoFn not yet supported in streaming mode: %s", + doFn); + checkState( + signature.onWindowExpiration() == null, "onWindowExpiration is not supported: %s", doFn); + + // options, PCollectionView, WindowingStrategy + final SerializablePipelineOptions options = context.getSerializableOptions(); + final SparkPCollectionView pviews = context.getPViews(); + final WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + + final KvCoder inputCoder = + (KvCoder) context.getInput(transform).getCoder(); + Map, Coder> outputCoders = context.getOutputCoders(); + JavaPairDStream, WindowedValue> all; + + final UnboundedDataset> unboundedDataset = + (UnboundedDataset>) context.borrowDataset(transform); + + final JavaDStream>> dStream = unboundedDataset.getDStream(); + + final DoFnSchemaInformation doFnSchemaInformation = + ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + + final Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + + final String stepName = context.getCurrentTransform().getFullName(); + + final WindowFn windowFn = windowingStrategy.getWindowFn(); + + final List sourceIds = unboundedDataset.getStreamSources(); + + // key, value coder + final Coder keyCoder = inputCoder.getKeyCoder(); + final Coder valueCoder = inputCoder.getValueCoder(); + + final WindowedValue.FullWindowedValueCoder wvCoder = + WindowedValue.FullWindowedValueCoder.of(valueCoder, windowFn.windowCoder()); + + final MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance(); + final Map, KV, SideInputBroadcast>> sideInputs = + TranslationUtils.getSideInputs( + transform.getSideInputs().values(), context.getSparkContext(), pviews); + + // Original code used multiple map operations (.map -> .mapToPair -> .mapToPair) + // which created intermediate RDDs for each transformation. + // Changed to use mapPartitionsToPair to: + // 1. Reduce the number of RDD creations by combining multiple operations + // 2. Process data in batches (partitions) rather than element by element + // 3. Improve performance by reducing serialization/deserialization overhead + // 4. Minimize the number of function objects created during execution + final JavaPairDStream< + /*Serialized KeyT*/ ByteArray, /*Serialized WindowedValue*/ byte[]> + serializedDStream = + dStream.mapPartitionsToPair( + (Iterator>> iter) -> + Iterators.transform( + iter, + (WindowedValue> windowedKV) -> { + final KeyT key = windowedKV.getValue().getKey(); + final WindowedValue windowedValue = + windowedKV.withValue(windowedKV.getValue().getValue()); + final ByteArray keyBytes = + new ByteArray(CoderHelpers.toByteArray(key, keyCoder)); + final byte[] valueBytes = + CoderHelpers.toByteArray(windowedValue, wvCoder); + return Tuple2.apply(keyBytes, valueBytes); + })); + + final Map watermarks = + GlobalWatermarkHolder.get(getBatchDuration(options)); + + @SuppressWarnings({"rawtypes", "unchecked"}) + final JavaMapWithStateDStream< + ByteArray, Option, State, List, byte[]>>> + processedPairDStream = + serializedDStream.mapWithState( + StateSpec.function( + new ParDoStateUpdateFn<>( + metricsAccum, + stepName, + doFn, + keyCoder, + (WindowedValue.FullWindowedValueCoder) wvCoder, + options, + transform.getMainOutputTag(), + transform.getAdditionalOutputTags().getAll(), + inputCoder, + outputCoders, + sideInputs, + windowingStrategy, + doFnSchemaInformation, + sideInputMapping, + watermarks, + sourceIds))); + + all = + processedPairDStream.flatMapToPair( + (List, byte[]>> list) -> + Iterators.transform( + list.iterator(), + (Tuple2, byte[]> tuple) -> { + final Coder outputCoder = outputCoders.get(tuple._1()); + @SuppressWarnings("nullness") + final WindowedValue windowedValue = + CoderHelpers.fromByteArray( + tuple._2(), + WindowedValue.FullWindowedValueCoder.of( + outputCoder, windowFn.windowCoder())); + return Tuple2.apply(tuple._1(), windowedValue); + })); + + Map, PCollection> outputs = context.getOutputs(transform); + if (hasMultipleOutputs(outputs)) { + // Caching can cause Serialization, we need to code to bytes + // more details in https://issues.apache.org/jira/browse/BEAM-2669 + Map, Coder>> coderMap = + TranslationUtils.getTupleTagCoders(outputs); + all = + all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap)) + .cache() + .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap)); + + for (Map.Entry, PCollection> output : outputs.entrySet()) { + @SuppressWarnings({"unchecked", "rawtypes"}) + JavaPairDStream, WindowedValue> filtered = + all.filter(new TranslationUtils.TupleTagFilter(output.getKey())); + @SuppressWarnings("unchecked") + // Object is the best we can do since different outputs can have different tags + JavaDStream> values = + (JavaDStream>) + (JavaDStream) TranslationUtils.dStreamValues(filtered); + context.putDataset(output.getValue(), new UnboundedDataset<>(values, sourceIds)); + } + } else { + @SuppressWarnings("unchecked") + final JavaDStream> values = + (JavaDStream>) (JavaDStream) TranslationUtils.dStreamValues(all); + + context.putDataset( + Iterables.getOnlyElement(outputs.entrySet()).getValue(), + new UnboundedDataset<>(values, sourceIds)); + } + } + + @Override + public String toNativeString() { + return "mapPartitions(new ())"; + } + + private boolean hasMultipleOutputs(Map, PCollection> outputs) { + return outputs.size() > 1; + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 5be8e718dec6..539f8ff3efe6 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.spark.translation.streaming; -import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -65,6 +64,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -434,11 +434,27 @@ private static TransformEvaluator transform, final EvaluationContext context) { final DoFn doFn = transform.getFn(); + final DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn); checkArgument( - !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(), + !signature.processElement().isSplittable(), "Splittable DoFn not yet supported in streaming mode: %s", doFn); - rejectStateAndTimers(doFn); + checkState( + signature.onWindowExpiration() == null, + "onWindowExpiration is not supported: %s", + doFn); + + boolean stateful = + signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0; + + if (stateful) { + final StatefulStreamingParDoEvaluator delegate = + new StatefulStreamingParDoEvaluator<>(); + + delegate.evaluate((ParDo.MultiOutput) transform, context); + return; + } + final SerializablePipelineOptions options = context.getSerializableOptions(); final SparkPCollectionView pviews = context.getPViews(); final WindowingStrategy windowingStrategy = diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java index a3d7724e4363..243f3a3e533f 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java @@ -527,7 +527,7 @@ public void process(ProcessContext context) { } } - private static PipelineOptions streamingOptions() { + static PipelineOptions streamingOptions() { PipelineOptions options = TestPipeline.testingPipelineOptions(); options.as(TestSparkPipelineOptions.class).setStreaming(true); return options; diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java new file mode 100644 index 000000000000..e1f000d16675 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import static org.apache.beam.runners.spark.translation.streaming.CreateStreamTest.streamingOptions; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.StreamingTest; +import org.apache.beam.runners.spark.io.CreateStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +@SuppressWarnings({"unchecked", "unused"}) +public class StatefulStreamingParDoEvaluatorTest { + + @Rule public final transient TestPipeline p = TestPipeline.fromOptions(streamingOptions()); + + private PTransform>> createStreamingSource( + Pipeline pipeline) { + Instant instant = new Instant(0); + final KvCoder coder = KvCoder.of(VarIntCoder.of(), VarIntCoder.of()); + final Duration batchDuration = batchDuration(pipeline); + return CreateStream.of(coder, batchDuration) + .emptyBatch() + .advanceWatermarkForNextBatch(instant) + .nextBatch( + TimestampedValue.of(KV.of(1, 1), instant), + TimestampedValue.of(KV.of(1, 2), instant), + TimestampedValue.of(KV.of(1, 3), instant)) + .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L))) + .nextBatch( + TimestampedValue.of(KV.of(2, 4), instant.plus(Duration.standardSeconds(1L))), + TimestampedValue.of(KV.of(2, 5), instant.plus(Duration.standardSeconds(1L))), + TimestampedValue.of(KV.of(2, 6), instant.plus(Duration.standardSeconds(1L)))) + .advanceNextBatchWatermarkToInfinity(); + } + + private PTransform>> createStreamingSource( + Pipeline pipeline, int iterCount) { + Instant instant = new Instant(0); + final KvCoder coder = KvCoder.of(VarIntCoder.of(), VarIntCoder.of()); + final Duration batchDuration = batchDuration(pipeline); + + CreateStream> createStream = + CreateStream.of(coder, batchDuration).emptyBatch().advanceWatermarkForNextBatch(instant); + + int value = 1; + for (int i = 0; i < iterCount; i++) { + createStream = + createStream.nextBatch( + TimestampedValue.of(KV.of(1, value++), instant), + TimestampedValue.of(KV.of(1, value++), instant), + TimestampedValue.of(KV.of(1, value++), instant)); + + instant = instant.plus(Duration.standardSeconds(1L)); + createStream = createStream.advanceWatermarkForNextBatch(instant); + + createStream = + createStream.nextBatch( + TimestampedValue.of(KV.of(2, value++), instant), + TimestampedValue.of(KV.of(2, value++), instant), + TimestampedValue.of(KV.of(2, value++), instant)); + + instant = instant.plus(Duration.standardSeconds(1L)); + createStream = createStream.advanceWatermarkForNextBatch(instant); + } + + return createStream.advanceNextBatchWatermarkToInfinity(); + } + + private static class StatefulWithTimerDoFn extends DoFn { + @StateId("some-state") + private final StateSpec> someStringStateSpec = + StateSpecs.value(StringUtf8Coder.of()); + + @TimerId("some-timer") + private final TimerSpec someTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + @ProcessElement + public void process( + @Element InputT element, @StateId("some-state") ValueState someStringStage) { + // ignore... + } + + @OnTimer("some-timer") + public void onTimer() { + // ignore... + } + } + + private static class StatefulDoFn extends DoFn, KV> { + + @StateId("test-state") + private final StateSpec> testState = StateSpecs.value(); + + @ProcessElement + public void process( + @Element KV element, + @StateId("test-state") ValueState testState, + OutputReceiver> output) { + final Integer value = element.getValue(); + final Integer currentState = firstNonNull(testState.read(), 0); + final Integer newState = currentState + value; + testState.write(newState); + + final KV result = KV.of(element.getKey(), newState); + output.output(result); + } + } + + @Category(StreamingTest.class) + @Test + public void shouldRejectTimer() { + p.apply(createStreamingSource(p)).apply(ParDo.of(new StatefulWithTimerDoFn<>())); + + final UnsupportedOperationException exception = + assertThrows(UnsupportedOperationException.class, p::run); + + assertEquals( + "Found TimerId annotations on " + + StatefulWithTimerDoFn.class.getName() + + ", but DoFn cannot yet be used with timers in the SparkRunner.", + exception.getMessage()); + } + + @Category(StreamingTest.class) + @Test + public void shouldProcessGlobalWidowStatefulParDo() { + final PCollection> result = + p.apply(createStreamingSource(p)).apply(ParDo.of(new StatefulDoFn())); + + PAssert.that(result) + .containsInAnyOrder( + // key 1 + KV.of(1, 1), // 1 + KV.of(1, 3), // 1 + 2 + KV.of(1, 6), // 3 + 3 + // key 2 + KV.of(2, 4), // 4 + KV.of(2, 9), // 4 + 5 + KV.of(2, 15)); // 9 + 6 + + p.run().waitUntilFinish(); + } + + @Category(StreamingTest.class) + @Test + public void shouldProcessWindowedStatefulParDo() { + final PCollection> result = + p.apply(createStreamingSource(p, 2)) + .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1L)))) + .apply(ParDo.of(new StatefulDoFn())); + + PAssert.that(result) + .containsInAnyOrder( + // Windowed Key 1 + KV.of(1, 1), // 1 + KV.of(1, 3), // 1 + 2 + KV.of(1, 6), // 3 + 3 + + // Windowed Key 2 + KV.of(2, 4), // 4 + KV.of(2, 9), // 4 + 5 + KV.of(2, 15), // 9 + 6 + + // Windowed Key 1 + KV.of(1, 7), // 7 + KV.of(1, 15), // 7 + 8 + KV.of(1, 24), // 15 + 9 + + // Windowed Key 2 + KV.of(2, 10), // 10 + KV.of(2, 21), // 10 + 11 + KV.of(2, 33) // 21 + 12 + ); + + p.run().waitUntilFinish(); + } + + private Duration batchDuration(Pipeline pipeline) { + return Duration.millis( + pipeline.getOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis()); + } +}