outputWindowCoder =
+ WindowedValue.FullWindowedValueCoder.of(outputCoder, windowCoder);
+
+ return Tuple2.apply(
+ tupleTag,
+ CoderHelpers.toByteArray((WindowedValue) e._2(), outputWindowCoder));
+ })
+ .collect(Collectors.toList());
+ }
+}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java
new file mode 100644
index 000000000000..23bcfcb129ce
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluator.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration;
+import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectTimers;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
+import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator;
+import org.apache.beam.runners.spark.stateful.StateAndTimers;
+import org.apache.beam.runners.spark.translation.EvaluationContext;
+import org.apache.beam.runners.spark.translation.SparkPCollectionView;
+import org.apache.beam.runners.spark.translation.TransformEvaluator;
+import org.apache.beam.runners.spark.translation.TranslationUtils;
+import org.apache.beam.runners.spark.util.ByteArray;
+import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
+import org.apache.beam.runners.spark.util.SideInputBroadcast;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.construction.ParDoTranslation;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
+import org.apache.spark.streaming.State;
+import org.apache.spark.streaming.StateSpec;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import scala.Option;
+import scala.Tuple2;
+
+/**
+ * A specialized evaluator for ParDo operations in the Spark Streaming context, invoked when
+ * stateful processing is detected in the DoFn.
+ *
+ * <p>This class is used by {@link StreamingTransformTranslator}'s ParDo evaluator to handle
+ * stateful streaming operations. When a DoFn contains stateful processing logic, the translation
+ * process routes the execution through this evaluator instead of the standard ParDo evaluator.
+ *
+ * <p>The evaluator manages state handling and ensures proper processing semantics for streaming
+ * stateful operations in the Spark runner context.
+ *
+ * <p>Important: This evaluator includes validation logic that rejects DoFn implementations
+ * containing {@code @TimerId} annotations, as timer functionality is not currently supported in the
+ * Spark streaming context.
+ */
+public class StatefulStreamingParDoEvaluator<KeyT, ValueT, OutputT>
+ implements TransformEvaluator<ParDo.MultiOutput<KV<KeyT, ValueT>, OutputT>> {
+
+ @Override
+ public void evaluate(
+ ParDo.MultiOutput<KV<KeyT, ValueT>, OutputT> transform, EvaluationContext context) {
+ final DoFn<KV<KeyT, ValueT>, OutputT> doFn = transform.getFn();
+ final DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn);
+
+ rejectTimers(doFn);
+ checkArgument(
+ !signature.processElement().isSplittable(),
+ "Splittable DoFn not yet supported in streaming mode: %s",
+ doFn);
+ checkState(
+ signature.onWindowExpiration() == null, "onWindowExpiration is not supported: %s", doFn);
+
+ // options, PCollectionView, WindowingStrategy
+ final SerializablePipelineOptions options = context.getSerializableOptions();
+ final SparkPCollectionView pviews = context.getPViews();
+ final WindowingStrategy<?, ?> windowingStrategy =
+ context.getInput(transform).getWindowingStrategy();
+
+ final KvCoder<KeyT, ValueT> inputCoder =
+ (KvCoder<KeyT, ValueT>) context.getInput(transform).getCoder();
+ Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();
+ JavaPairDStream<TupleTag<?>, WindowedValue<?>> all;
+
+ final UnboundedDataset<KV<KeyT, ValueT>> unboundedDataset =
+ (UnboundedDataset<KV<KeyT, ValueT>>) context.borrowDataset(transform);
+
+ final JavaDStream<WindowedValue<KV<KeyT, ValueT>>> dStream = unboundedDataset.getDStream();
+
+ final DoFnSchemaInformation doFnSchemaInformation =
+ ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
+
+ final Map<String, PCollectionView<?>> sideInputMapping =
+ ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
+
+ final String stepName = context.getCurrentTransform().getFullName();
+
+ final WindowFn<?, ?> windowFn = windowingStrategy.getWindowFn();
+
+ final List<Integer> sourceIds = unboundedDataset.getStreamSources();
+
+ // key, value coder
+ final Coder<KeyT> keyCoder = inputCoder.getKeyCoder();
+ final Coder<ValueT> valueCoder = inputCoder.getValueCoder();
+
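+ // The key is serialized on its own so mapWithState can partition by the raw key bytes, while
+ // the value is wrapped in a FullWindowedValueCoder so window and timestamp metadata survive.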
+ final WindowedValue.FullWindowedValueCoder<ValueT> wvCoder =
+ WindowedValue.FullWindowedValueCoder.of(valueCoder, windowFn.windowCoder());
+
+ final MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
+ final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
+ TranslationUtils.getSideInputs(
+ transform.getSideInputs().values(), context.getSparkContext(), pviews);
+
+ // Original code used multiple map operations (.map -> .mapToPair -> .mapToPair)
+ // which created intermediate RDDs for each transformation.
+ // Changed to use mapPartitionsToPair to:
+ // 1. Reduce the number of RDD creations by combining multiple operations
+ // 2. Process data in batches (partitions) rather than element by element
+ // 3. Improve performance by reducing serialization/deserialization overhead
+ // 4. Minimize the number of function objects created during execution
+ final JavaPairDStream<
+ /*Serialized KeyT*/ ByteArray, /*Serialized WindowedValue<ValueT>*/ byte[]>
+ serializedDStream =
+ dStream.mapPartitionsToPair(
+ (Iterator<WindowedValue<KV<KeyT, ValueT>>> iter) ->
+ Iterators.transform(
+ iter,
+ (WindowedValue<KV<KeyT, ValueT>> windowedKV) -> {
+ final KeyT key = windowedKV.getValue().getKey();
+ final WindowedValue<ValueT> windowedValue =
+ windowedKV.withValue(windowedKV.getValue().getValue());
+ final ByteArray keyBytes =
+ new ByteArray(CoderHelpers.toByteArray(key, keyCoder));
+ final byte[] valueBytes =
+ CoderHelpers.toByteArray(windowedValue, wvCoder);
+ return Tuple2.apply(keyBytes, valueBytes);
+ }));
+
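+ // Snapshot of the current watermarks, read from the driver-side GlobalWatermarkHolder for
+ // this batch interval and handed to the state update function below.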
+ final Map<Integer, GlobalWatermarkHolder.SparkWatermarks> watermarks =
+ GlobalWatermarkHolder.get(getBatchDuration(options));
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ final JavaMapWithStateDStream<
+ ByteArray, Option<byte[]>, State<StateAndTimers>, List<Tuple2<TupleTag<?>, byte[]>>>
+ processedPairDStream =
+ serializedDStream.mapWithState(
+ StateSpec.function(
+ new ParDoStateUpdateFn<>(
+ metricsAccum,
+ stepName,
+ doFn,
+ keyCoder,
+ (WindowedValue.FullWindowedValueCoder) wvCoder,
+ options,
+ transform.getMainOutputTag(),
+ transform.getAdditionalOutputTags().getAll(),
+ inputCoder,
+ outputCoders,
+ sideInputs,
+ windowingStrategy,
+ doFnSchemaInformation,
+ sideInputMapping,
+ watermarks,
+ sourceIds)));
+
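+ // The state function emits (output tag, serialized WindowedValue) pairs; decode each one
+ // back into a tagged WindowedValue using the output coder registered for that tag.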
+ all =
+ processedPairDStream.flatMapToPair(
+ (List<Tuple2<TupleTag<?>, byte[]>> list) ->
+ Iterators.transform(
+ list.iterator(),
+ (Tuple2<TupleTag<?>, byte[]> tuple) -> {
+ final Coder<?> outputCoder = outputCoders.get(tuple._1());
+ @SuppressWarnings("nullness")
+ final WindowedValue<?> windowedValue =
+ CoderHelpers.fromByteArray(
+ tuple._2(),
+ WindowedValue.FullWindowedValueCoder.of(
+ outputCoder, windowFn.windowCoder()));
+ return Tuple2.apply(tuple._1(), windowedValue);
+ }));
+
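+ // Register the results with the evaluation context: one dataset per output tag, with an
+ // encode/cache/decode round trip only when there are multiple outputs.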
+ Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs(transform);
+ if (hasMultipleOutputs(outputs)) {
+ // Caching can trigger serialization, so we encode to bytes first;
+ // more details in https://issues.apache.org/jira/browse/BEAM-2669
+ Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
+ TranslationUtils.getTupleTagCoders(outputs);
+ all =
+ all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
+ .cache()
+ .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
+
+ for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
+ all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
+ @SuppressWarnings("unchecked")
+ // Object is the best we can do since different outputs can have different tags
+ JavaDStream<WindowedValue<Object>> values =
+ (JavaDStream<WindowedValue<Object>>)
+ (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
+ context.putDataset(output.getValue(), new UnboundedDataset<>(values, sourceIds));
+ }
+ } else {
+ @SuppressWarnings("unchecked")
+ final JavaDStream<WindowedValue<OutputT>> values =
+ (JavaDStream<WindowedValue<OutputT>>) (JavaDStream<?>) TranslationUtils.dStreamValues(all);
+
+ context.putDataset(
+ Iterables.getOnlyElement(outputs.entrySet()).getValue(),
+ new UnboundedDataset<>(values, sourceIds));
+ }
+ }
+
+ @Override
+ public String toNativeString() {
+ return "mapPartitions(new ())";
+ }
+
+ private boolean hasMultipleOutputs(Map<TupleTag<?>, PCollection<?>> outputs) {
+ return outputs.size() > 1;
+ }
+}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 5be8e718dec6..539f8ff3efe6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.spark.translation.streaming;
-import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
@@ -65,6 +64,7 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -434,11 +434,27 @@ private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
public void evaluate(
final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
final DoFn<InputT, OutputT> doFn = transform.getFn();
+ final DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn);
checkArgument(
- !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(),
+ !signature.processElement().isSplittable(),
"Splittable DoFn not yet supported in streaming mode: %s",
doFn);
- rejectStateAndTimers(doFn);
+ checkState(
+ signature.onWindowExpiration() == null,
+ "onWindowExpiration is not supported: %s",
+ doFn);
+
+ boolean stateful =
+ signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
+
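+ // Delegate stateful (or timer-declaring) DoFns to the dedicated streaming evaluator, which
+ // itself rejects timers via rejectTimers(doFn).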
+ if (stateful) {
+ final StatefulStreamingParDoEvaluator<?, ?, ?> delegate =
+ new StatefulStreamingParDoEvaluator<>();
+
+ delegate.evaluate((ParDo.MultiOutput) transform, context);
+ return;
+ }
+
final SerializablePipelineOptions options = context.getSerializableOptions();
final SparkPCollectionView pviews = context.getPViews();
final WindowingStrategy<?, ?> windowingStrategy =
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
index a3d7724e4363..243f3a3e533f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
@@ -527,7 +527,7 @@ public void process(ProcessContext context) {
}
}
- private static PipelineOptions streamingOptions() {
+ static PipelineOptions streamingOptions() {
PipelineOptions options = TestPipeline.testingPipelineOptions();
options.as(TestSparkPipelineOptions.class).setStreaming(true);
return options;
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java
new file mode 100644
index 000000000000..e1f000d16675
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StatefulStreamingParDoEvaluatorTest.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation.streaming;
+
+import static org.apache.beam.runners.spark.translation.streaming.CreateStreamTest.streamingOptions;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.StreamingTest;
+import org.apache.beam.runners.spark.io.CreateStream;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+@SuppressWarnings({"unchecked", "unused"})
+public class StatefulStreamingParDoEvaluatorTest {
+
+ @Rule public final transient TestPipeline p = TestPipeline.fromOptions(streamingOptions());
+
+ private PTransform<PBegin, PCollection<KV<Integer, Integer>>> createStreamingSource(
+ Pipeline pipeline) {
+ Instant instant = new Instant(0);
+ final KvCoder<Integer, Integer> coder = KvCoder.of(VarIntCoder.of(), VarIntCoder.of());
+ final Duration batchDuration = batchDuration(pipeline);
+ return CreateStream.of(coder, batchDuration)
+ .emptyBatch()
+ .advanceWatermarkForNextBatch(instant)
+ .nextBatch(
+ TimestampedValue.of(KV.of(1, 1), instant),
+ TimestampedValue.of(KV.of(1, 2), instant),
+ TimestampedValue.of(KV.of(1, 3), instant))
+ .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L)))
+ .nextBatch(
+ TimestampedValue.of(KV.of(2, 4), instant.plus(Duration.standardSeconds(1L))),
+ TimestampedValue.of(KV.of(2, 5), instant.plus(Duration.standardSeconds(1L))),
+ TimestampedValue.of(KV.of(2, 6), instant.plus(Duration.standardSeconds(1L))))
+ .advanceNextBatchWatermarkToInfinity();
+ }
+
+ private PTransform<PBegin, PCollection<KV<Integer, Integer>>> createStreamingSource(
+ Pipeline pipeline, int iterCount) {
+ Instant instant = new Instant(0);
+ final KvCoder<Integer, Integer> coder = KvCoder.of(VarIntCoder.of(), VarIntCoder.of());
+ final Duration batchDuration = batchDuration(pipeline);
+
+ CreateStream<KV<Integer, Integer>> createStream =
+ CreateStream.of(coder, batchDuration).emptyBatch().advanceWatermarkForNextBatch(instant);
+
+ int value = 1;
+ for (int i = 0; i < iterCount; i++) {
+ createStream =
+ createStream.nextBatch(
+ TimestampedValue.of(KV.of(1, value++), instant),
+ TimestampedValue.of(KV.of(1, value++), instant),
+ TimestampedValue.of(KV.of(1, value++), instant));
+
+ instant = instant.plus(Duration.standardSeconds(1L));
+ createStream = createStream.advanceWatermarkForNextBatch(instant);
+
+ createStream =
+ createStream.nextBatch(
+ TimestampedValue.of(KV.of(2, value++), instant),
+ TimestampedValue.of(KV.of(2, value++), instant),
+ TimestampedValue.of(KV.of(2, value++), instant));
+
+ instant = instant.plus(Duration.standardSeconds(1L));
+ createStream = createStream.advanceWatermarkForNextBatch(instant);
+ }
+
+ return createStream.advanceNextBatchWatermarkToInfinity();
+ }
+
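+ /** A stateful DoFn that also declares a processing-time timer, used to verify timers are rejected. */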
+ private static class StatefulWithTimerDoFn<InputT, OutputT> extends DoFn<InputT, OutputT> {
+ @StateId("some-state")
+ private final StateSpec<ValueState<String>> someStringStateSpec =
+ StateSpecs.value(StringUtf8Coder.of());
+
+ @TimerId("some-timer")
+ private final TimerSpec someTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+ @ProcessElement
+ public void process(
+ @Element InputT element, @StateId("some-state") ValueState<String> someStringStage) {
+ // ignore...
+ }
+
+ @OnTimer("some-timer")
+ public void onTimer() {
+ // ignore...
+ }
+ }
+
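+ /** Keeps a per-key running sum in value state and emits the updated sum for every element. */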
+ private static class StatefulDoFn extends DoFn<KV<Integer, Integer>, KV<Integer, Integer>> {
+
+ @StateId("test-state")
+ private final StateSpec<ValueState<Integer>> testState = StateSpecs.value();
+
+ @ProcessElement
+ public void process(
+ @Element KV<Integer, Integer> element,
+ @StateId("test-state") ValueState<Integer> testState,
+ OutputReceiver<KV<Integer, Integer>> output) {
+ final Integer value = element.getValue();
+ final Integer currentState = firstNonNull(testState.read(), 0);
+ final Integer newState = currentState + value;
+ testState.write(newState);
+
+ final KV<Integer, Integer> result = KV.of(element.getKey(), newState);
+ output.output(result);
+ }
+ }
+
+ @Category(StreamingTest.class)
+ @Test
+ public void shouldRejectTimer() {
+ p.apply(createStreamingSource(p)).apply(ParDo.of(new StatefulWithTimerDoFn<>()));
+
+ final UnsupportedOperationException exception =
+ assertThrows(UnsupportedOperationException.class, p::run);
+
+ assertEquals(
+ "Found TimerId annotations on "
+ + StatefulWithTimerDoFn.class.getName()
+ + ", but DoFn cannot yet be used with timers in the SparkRunner.",
+ exception.getMessage());
+ }
+
+ @Category(StreamingTest.class)
+ @Test
+ public void shouldProcessGlobalWindowStatefulParDo() {
+ final PCollection<KV<Integer, Integer>> result =
+ p.apply(createStreamingSource(p)).apply(ParDo.of(new StatefulDoFn()));
+
+ PAssert.that(result)
+ .containsInAnyOrder(
+ // key 1
+ KV.of(1, 1), // 1
+ KV.of(1, 3), // 1 + 2
+ KV.of(1, 6), // 3 + 3
+ // key 2
+ KV.of(2, 4), // 4
+ KV.of(2, 9), // 4 + 5
+ KV.of(2, 15)); // 9 + 6
+
+ p.run().waitUntilFinish();
+ }
+
+ @Category(StreamingTest.class)
+ @Test
+ public void shouldProcessWindowedStatefulParDo() {
+ final PCollection<KV<Integer, Integer>> result =
+ p.apply(createStreamingSource(p, 2))
+ .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1L))))
+ .apply(ParDo.of(new StatefulDoFn()));
+
+ PAssert.that(result)
+ .containsInAnyOrder(
+ // Windowed Key 1
+ KV.of(1, 1), // 1
+ KV.of(1, 3), // 1 + 2
+ KV.of(1, 6), // 3 + 3
+
+ // Windowed Key 2
+ KV.of(2, 4), // 4
+ KV.of(2, 9), // 4 + 5
+ KV.of(2, 15), // 9 + 6
+
+ // Windowed Key 1
+ KV.of(1, 7), // 7
+ KV.of(1, 15), // 7 + 8
+ KV.of(1, 24), // 15 + 9
+
+ // Windowed Key 2
+ KV.of(2, 10), // 10
+ KV.of(2, 21), // 10 + 11
+ KV.of(2, 33) // 21 + 12
+ );
+
+ p.run().waitUntilFinish();
+ }
+
+ private Duration batchDuration(Pipeline pipeline) {
+ return Duration.millis(
+ pipeline.getOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis());
+ }
+}
diff --git a/sdks/go/pkg/beam/core/core.go b/sdks/go/pkg/beam/core/core.go
index 1b478f483077..a183ddf384ed 100644
--- a/sdks/go/pkg/beam/core/core.go
+++ b/sdks/go/pkg/beam/core/core.go
@@ -27,7 +27,7 @@ const (
// SdkName is the human readable name of the SDK for UserAgents.
SdkName = "Apache Beam SDK for Go"
// SdkVersion is the current version of the SDK.
- SdkVersion = "2.62.0.dev"
+ SdkVersion = "2.63.0.dev"
// DefaultDockerImage represents the associated image for this release.
DefaultDockerImage = "apache/beam_go_sdk:" + SdkVersion
diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle
index 0cfa8da4eb7d..319848b7626b 100644
--- a/sdks/java/io/iceberg/build.gradle
+++ b/sdks/java/io/iceberg/build.gradle
@@ -37,7 +37,7 @@ def hadoopVersions = [
hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")}
-def iceberg_version = "1.4.2"
+def iceberg_version = "1.6.1"
def parquet_version = "1.12.0"
def orc_version = "1.9.2"
diff --git a/sdks/java/io/iceberg/hive/build.gradle b/sdks/java/io/iceberg/hive/build.gradle
index 2d0d2bcc5cde..9884b45af7a1 100644
--- a/sdks/java/io/iceberg/hive/build.gradle
+++ b/sdks/java/io/iceberg/hive/build.gradle
@@ -30,7 +30,7 @@ ext.summary = "Runtime dependencies needed for Hive catalog integration."
def hive_version = "3.1.3"
def hbase_version = "2.6.1-hadoop3"
def hadoop_version = "3.4.1"
-def iceberg_version = "1.4.2"
+def iceberg_version = "1.6.1"
def avatica_version = "1.25.0"
dependencies {
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 22d041f34f8b..57d8197a3a00 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -1438,17 +1438,6 @@ def __hash__(self):
return hash(
(self.wrapped_value_coder, self.timestamp_coder, self.window_coder))
- @classmethod
- def from_type_hint(cls, typehint, registry):
- # type: (Any, CoderRegistry) -> WindowedValueCoder
- # Ideally this'd take two parameters so that one could hint at
- # the window type as well instead of falling back to the
- # pickle coders.
- return cls(registry.get_coder(typehint.inner_type))
-
- def to_type_hint(self):
- return typehints.WindowedValue[self.wrapped_value_coder.to_type_hint()]
-
Coder.register_structured_urn(
common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)
diff --git a/sdks/python/apache_beam/coders/coders_test.py b/sdks/python/apache_beam/coders/coders_test.py
index bddd2cb57e06..dc9780e36be3 100644
--- a/sdks/python/apache_beam/coders/coders_test.py
+++ b/sdks/python/apache_beam/coders/coders_test.py
@@ -258,12 +258,6 @@ def test_numpy_int(self):
_ = indata | "CombinePerKey" >> beam.CombinePerKey(sum)
-class WindowedValueCoderTest(unittest.TestCase):
- def test_to_type_hint(self):
- coder = coders.WindowedValueCoder(coders.VarIntCoder())
- self.assertEqual(coder.to_type_hint(), typehints.WindowedValue[int]) # type: ignore[misc]
-
-
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 892f508d0136..1667cb7a916a 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -94,8 +94,6 @@ def register_standard_coders(self, fallback_coder):
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
- self._register_coder_internal(
- typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
# Default fallback coders applied in that order until the first matching
# coder found.
default_fallback_coders = [
diff --git a/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py
index f7e12dcead03..223b973e5fbe 100644
--- a/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py
+++ b/sdks/python/apache_beam/testing/benchmarks/inference/tensorflow_mnist_classification_cost_benchmark.py
@@ -31,7 +31,7 @@ def test(self):
extra_opts['input'] = self.pipeline.get_option('input_file')
extra_opts['output'] = self.pipeline.get_option('output_file')
extra_opts['model_path'] = self.pipeline.get_option('model')
- tensorflow_mnist_classification.run(
+ self.result = tensorflow_mnist_classification.run(
self.pipeline.get_full_options_as_args(**extra_opts),
save_main_session=False)
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index c9fd2c76b0db..a03652de2496 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -33,7 +33,6 @@
from typing import Callable
from typing import Iterable
from typing import List
-from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
@@ -74,13 +73,11 @@
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
-from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import shared
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
from apache_beam.utils.sharded_key import ShardedKey
-from apache_beam.utils.timestamp import Timestamp
if TYPE_CHECKING:
from apache_beam.runners.pipeline_context import PipelineContext
@@ -956,10 +953,6 @@ def restore_timestamps(element):
window.GlobalWindows.windowed_value((key, value), timestamp)
for (value, timestamp) in values
]
-
- ungrouped = pcoll | Map(reify_timestamps).with_input_types(
- Tuple[K, V]).with_output_types(
- Tuple[K, Tuple[V, Optional[Timestamp]]])
else:
# typing: All conditional function variants must have identical signatures
@@ -973,8 +966,7 @@ def restore_timestamps(element):
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]
- ungrouped = pcoll | Map(reify_timestamps).with_input_types(
- Tuple[K, V]).with_output_types(Tuple[K, TypedWindowedValue[V]])
+ ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
@@ -1026,8 +1018,7 @@ def expand(self, pcoll):
pcoll | 'AddRandomKeys' >>
Map(lambda t: (random.randrange(0, self.num_buckets), t)
).with_input_types(T).with_output_types(Tuple[int, T])
- | ReshufflePerKey().with_input_types(Tuple[int, T]).with_output_types(
- Tuple[int, T])
+ | ReshufflePerKey()
| 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types(
Tuple[int, T]).with_output_types(T))
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index db73310dfe25..d86509c7dde3 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -1010,60 +1010,6 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
equal_to(expected_data),
label="formatted_after_reshuffle")
- global _Unpicklable
- global _UnpicklableCoder
-
- class _Unpicklable(object):
- def __init__(self, value):
- self.value = value
-
- def __getstate__(self):
- raise NotImplementedError()
-
- def __setstate__(self, state):
- raise NotImplementedError()
-
- class _UnpicklableCoder(beam.coders.Coder):
- def encode(self, value):
- return str(value.value).encode()
-
- def decode(self, encoded):
- return _Unpicklable(int(encoded.decode()))
-
- def to_type_hint(self):
- return _Unpicklable
-
- def is_deterministic(self):
- return True
-
- def test_reshuffle_unpicklable_in_global_window(self):
- beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
-
- with TestPipeline() as pipeline:
- data = [_Unpicklable(i) for i in range(5)]
- expected_data = [0, 10, 20, 30, 40]
- result = (
- pipeline
- | beam.Create(data)
- | beam.WindowInto(GlobalWindows())
- | beam.Reshuffle()
- | beam.Map(lambda u: u.value * 10))
- assert_that(result, equal_to(expected_data))
-
- def test_reshuffle_unpicklable_in_non_global_window(self):
- beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
-
- with TestPipeline() as pipeline:
- data = [_Unpicklable(i) for i in range(5)]
- expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40]
- result = (
- pipeline
- | beam.Create(data)
- | beam.WindowInto(window.SlidingWindows(size=3, period=1))
- | beam.Reshuffle()
- | beam.Map(lambda u: u.value * 10))
- assert_that(result, equal_to(expected_data))
-
class WithKeysTest(unittest.TestCase):
def setUp(self):
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 381d4f7aae2b..6f704b37a969 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -24,13 +24,9 @@
import sys
import types
import typing
-from typing import Generic
-from typing import TypeVar
from apache_beam.typehints import typehints
-T = TypeVar('T')
-
_LOGGER = logging.getLogger(__name__)
# Describes an entry in the type map in convert_to_beam_type.
@@ -220,18 +216,6 @@ def convert_collections_to_typing(typ):
return typ
-# During type inference of WindowedValue, we need to pass in the inner value
-# type. This cannot be achieved immediately with WindowedValue class because it
-# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
-# could work in theory. However, the class is cythonized and it seems that
-# cython does not handle generic classes well.
-# The workaround here is to create a separate class solely for the type
-# inference purpose. This class should never be used for creating instances.
-class TypedWindowedValue(Generic[T]):
- def __init__(self, *args, **kwargs):
- raise NotImplementedError("This class is solely for type inference")
-
-
def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.
@@ -283,12 +267,6 @@ def convert_to_beam_type(typ):
# TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
_LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
return typehints.Any
- elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
- getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
- # Need to pass through WindowedValue class so that it can be converted
- # to the correct type constraint in Beam
- # This is needed to fix https://github.com/apache/beam/issues/33356
- pass
elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
# Only translate types from the typing and collections.abc modules.
return typ
@@ -346,10 +324,6 @@ def convert_to_beam_type(typ):
match=_match_is_exactly_collection,
arity=1,
beam_type=typehints.Collection),
- _TypeMapEntry(
- match=_match_issubclass(TypedWindowedValue),
- arity=1,
- beam_type=typehints.WindowedValue),
]
# Find the first matching entry.
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index a65a0f753826..0e18e887c2a0 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1213,15 +1213,6 @@ def type_check(self, instance):
repr(self.inner_type),
instance.value.__class__.__name__))
- def bind_type_variables(self, bindings):
- bound_inner_type = bind_type_variables(self.inner_type, bindings)
- if bound_inner_type == self.inner_type:
- return self
- return WindowedValue[bound_inner_type]
-
- def __repr__(self):
- return 'WindowedValue[%s]' % repr(self.inner_type)
-
class GeneratorHint(IteratorHint):
"""A Generator type hint.
diff --git a/sdks/python/apache_beam/version.py b/sdks/python/apache_beam/version.py
index 9974bb68bccf..39185712b141 100644
--- a/sdks/python/apache_beam/version.py
+++ b/sdks/python/apache_beam/version.py
@@ -17,4 +17,4 @@
"""Apache Beam SDK version information and utilities."""
-__version__ = '2.62.0.dev'
+__version__ = '2.63.0.dev'
diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py
index 27e17029f387..fe5727f3ef92 100644
--- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py
+++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py
@@ -250,7 +250,7 @@ def main():
if options.markdown_file or options.html_file:
if '-' in transforms[0]:
extra_docs = 'Supported languages: ' + ', '.join(
- t.split('-')[-1] for t in sorted(transforms))
+ t.split('-')[-1] for t in sorted(transforms)) + '.'
else:
extra_docs = ''
markdown_out.write(
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py
index 8f4a2118c236..7f7da7aca6a9 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -23,6 +23,7 @@
from typing import Callable
from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
@@ -619,6 +620,13 @@ def _PyJsFilter(
See more complete documentation on
[YAML Filtering](https://beam.apache.org/documentation/sdks/yaml-udf/#filtering).
+
+ Args:
+ keep: An expression evaluating to true for those records that should be kept.
+ language: The language of the above expression.
+ Defaults to generic.
+ error_handling: Whether and where to output records that throw errors when
+ the above expressions are evaluated.
""" # pylint: disable=line-too-long
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic')
return pcoll | beam.Filter(keep_fn)
@@ -664,14 +672,32 @@ def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
@beam.ptransform.ptransform_fn
@maybe_with_exception_handling_transform_fn
-def _PyJsMapToFields(pcoll, language='generic', **mapping_args):
+def _PyJsMapToFields(
+ pcoll,
+ fields: Mapping[str, Union[str, Mapping[str, str]]],
+ append: Optional[bool] = False,
+ drop: Optional[Iterable[str]] = None,
+ language: Optional[str] = None):
"""Creates records with new fields defined in terms of the input fields.
See more complete documentation on
[YAML Mapping Functions](https://beam.apache.org/documentation/sdks/yaml-udf/#mapping-functions).
+
+ Args:
+ fields: The output fields to compute, each mapping to the expression or
+ callable that creates them.
+ append: Whether to append the created fields to the set of
+ fields already present, outputting a union of both the new fields and
+ the original fields for each record. Defaults to False.
+ drop: If `append` is true, enumerates a subset of fields from the
+ original record that should not be kept.
+ language: The language used to define (and execute) the
+ expressions and/or callables in `fields`. Defaults to generic.
+ error_handling: Whether and where to output records that throw errors when
+ the above expressions are evaluated.
""" # pylint: disable=line-too-long
input_schema, fields = normalize_fields(
- pcoll, language=language, **mapping_args)
+ pcoll, fields, drop or (), append, language=language or 'generic')
if language == 'javascript':
options.YamlOptions.check_enabled(pcoll.pipeline, 'javascript')
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 7cb96a7efb32..327023742bc6 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -956,6 +956,21 @@ def preprocess_languages(spec):
else:
return spec
+ def validate_transform_references(spec):
+ name = spec.get('name', '')
+ transform_type = spec.get('type')
+ inputs = spec.get('input').get('input', [])
+
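+ # A transform may not list itself (by name or type) among its own inputs.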
+ if not is_empty(inputs):
+ input_values = [inputs] if isinstance(inputs, str) else inputs
+ for input_value in input_values:
+ if input_value in (name, transform_type):
+ raise ValueError(
+ f"Circular reference detected: Transform {name} "
+ f"references itself as input in {identify_object(spec)}")
+
+ return spec
+
for phase in [
ensure_transforms_have_types,
normalize_mapping,
@@ -966,6 +981,7 @@ def preprocess_languages(spec):
preprocess_chain,
tag_explicit_inputs,
normalize_inputs_outputs,
+ validate_transform_references,
preprocess_flattened_inputs,
ensure_errors_consumed,
preprocess_windowing,
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 7fcea7e2b662..b9caca4ca9f4 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -259,6 +259,51 @@ def test_csv_to_json(self):
lines=True).sort_values('rank').reindex()
pd.testing.assert_frame_equal(data, result)
+ def test_circular_reference_validation(self):
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle')) as p:
+ # pylint: disable=expression-not-assigned
+ with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
+ p | YamlTransform(
+ '''
+ type: composite
+ transforms:
+ - type: Create
+ name: Create
+ config:
+ elements: [0, 1, 3, 4]
+ input: Create
+ - type: PyMap
+ name: PyMap
+ config:
+ fn: "lambda row: row.element * row.element"
+ input: Create
+ output: PyMap
+ ''',
+ providers=TEST_PROVIDERS)
+
+ def test_circular_reference_multi_inputs_validation(self):
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle')) as p:
+ # pylint: disable=expression-not-assigned
+ with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
+ p | YamlTransform(
+ '''
+ type: composite
+ transforms:
+ - type: Create
+ name: Create
+ config:
+ elements: [0, 1, 3, 4]
+ - type: PyMap
+ name: PyMap
+ config:
+ fn: "lambda row: row.element * row.element"
+ input: [Create, PyMap]
+ output: PyMap
+ ''',
+ providers=TEST_PROVIDERS)
+
def test_name_is_not_ambiguous(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
@@ -285,7 +330,7 @@ def test_name_is_ambiguous(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# pylint: disable=expression-not-assigned
- with self.assertRaisesRegex(ValueError, r'Ambiguous.*'):
+ with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
p | YamlTransform(
'''
type: composite
diff --git a/sdks/typescript/package.json b/sdks/typescript/package.json
index 9ccfcaa663d1..3ed0a0e427f4 100644
--- a/sdks/typescript/package.json
+++ b/sdks/typescript/package.json
@@ -1,6 +1,6 @@
{
"name": "apache-beam",
- "version": "2.62.0-SNAPSHOT",
+ "version": "2.63.0-SNAPSHOT",
"devDependencies": {
"@google-cloud/bigquery": "^5.12.0",
"@types/mocha": "^9.0.0",
diff --git a/start-build-env.sh b/start-build-env.sh
index b788146eb988..0f23f32a269c 100755
--- a/start-build-env.sh
+++ b/start-build-env.sh
@@ -91,7 +91,7 @@ RUN echo "${USER_NAME} ALL=NOPASSWD: ALL" > "/etc/sudoers.d/beam-build-${USER_ID
ENV HOME "${DOCKER_HOME_DIR}"
ENV GOPATH ${DOCKER_HOME_DIR}/beam/sdks/go/examples/.gogradle/project_gopath
# This next command still runs as root causing the ~/.cache/go-build to be owned by root
-RUN go get github.com/linkedin/goavro/v2
+RUN go mod init beam-build-${USER_ID} && go get github.com/linkedin/goavro/v2
RUN chown -R ${USER_NAME}:${GROUP_ID} ${DOCKER_HOME_DIR}/.cache
UserSpecificDocker