diff --git a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json index b98aece75634..d5ac7fc60d7f 100644 --- a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json @@ -1,5 +1,6 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", "modification": 1, + "https://github.com/apache/beam/pull/32440": "testing datastream optimizations", "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json index dd9afb90e638..300fbf52b011 100644 --- a/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json @@ -1,3 +1,4 @@ { - "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" + "https://github.com/apache/beam/pull/32440": "testing datastream optimizations", + "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json index b26833333238..1d083be7e29a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", + "https://github.com/apache/beam/pull/32440": "test new datastream runner for batch", "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json index bdd2197e534a..76caf9b6954d 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json +++ 
b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json @@ -1,4 +1,6 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": "1" + "modification": "1", + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32440": "test new datastream runner for batch" } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json index e3d6056a5de9..2dd3a2471d89 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 1, + "https://github.com/apache/beam/pull/32440": "test new datastream runner for batch" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json index 9200c368abbe..cb7966397921 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json @@ -1,5 +1,6 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32440": "testing datastream optimizations", "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 1eb60f6e4959..0e40218bf035 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,6 @@ { "comment": "Modify this file in a trivial way to cause this test suite to 
run.", - "modification": 3 + "modification": 4, + "https://github.com/apache/beam/pull/32440": "test new datastream runner for batch" } diff --git a/.github/trigger_files/beam_PostCommit_XVR_Flink.json b/.github/trigger_files/beam_PostCommit_XVR_Flink.json index 236b7bee8af8..bb1b9f4c25e9 100644 --- a/.github/trigger_files/beam_PostCommit_XVR_Flink.json +++ b/.github/trigger_files/beam_PostCommit_XVR_Flink.json @@ -1,3 +1,4 @@ { - "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" + "https://github.com/apache/beam/pull/32440": "testing datastream optimizations", + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java index 0759487565b0..3e42bb54494e 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/GroupAlsoByWindowViaWindowSetNewDoFn.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.core; import java.util.Collection; +import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachines; import org.apache.beam.sdk.transforms.DoFn; @@ -41,6 +42,7 @@ public class GroupAlsoByWindowViaWindowSetNewDoFn< extends DoFn> { private static final long serialVersionUID = 1L; + private final RunnerApi.Trigger triggerProto; public static DoFn, KV> create( @@ -86,6 +88,7 @@ public GroupAlsoByWindowViaWindowSetNewDoFn( this.windowingStrategy = noWildcard; this.reduceFn = reduceFn; this.stateInternalsFactory = stateInternalsFactory; + this.triggerProto = TriggerTranslation.toProto(windowingStrategy.getTrigger()); } private OutputWindowedValue> 
outputWindowedValue() { @@ -124,8 +127,7 @@ public void processElement(ProcessContext c) throws Exception { key, windowingStrategy, ExecutableTriggerStateMachine.create( - TriggerStateMachines.stateMachineForTrigger( - TriggerTranslation.toProto(windowingStrategy.getTrigger()))), + TriggerStateMachines.stateMachineForTrigger(triggerProto)), stateInternals, timerInternals, outputWindowedValue(), diff --git a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java index 0f87271a9779..30dde7ace394 100644 --- a/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ b/runners/flink/1.17/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -47,23 +47,21 @@ public class CoderTypeSerializer extends TypeSerializer { private final Coder coder; - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. 
- */ - private final SerializablePipelineOptions pipelineOptions; - private final boolean fasterCopy; public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { + this( + coder, + Preconditions.checkNotNull(pipelineOptions) + .get() + .as(FlinkPipelineOptions.class) + .getFasterCopy()); + } + + public CoderTypeSerializer(Coder coder, boolean fasterCopy) { Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); + this.fasterCopy = fasterCopy; } @Override @@ -73,7 +71,7 @@ public boolean isImmutableType() { @Override public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); + return new CoderTypeSerializer<>(coder, fasterCopy); } @Override diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index d13e1c5faf6e..9a868500ae82 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -236,6 +236,10 @@ class ValidatesRunnerConfig { def sickbayTests = [ // TODO(https://github.com/apache/beam/issues/21306) 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testOnWindowTimestampSkew', + // Flink errors are not deterministic. Exception may just be + // org.apache.flink.runtime.operators.coordination.TaskNotRunningException: Task is not running, but in state FAILED + // instead of the actual cause. Real cause is visible in the logs. 
+ 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests' ] def createValidatesRunnerTask(Map m) { @@ -249,7 +253,7 @@ def createValidatesRunnerTask(Map m) { def pipelineOptionsArray = ["--runner=TestFlinkRunner", "--streaming=${config.streaming}", "--useDataStreamForBatch=${config.useDataStreamForBatch}", - "--parallelism=2", + "--parallelism=1", ] if (config.checkpointing) { pipelineOptionsArray.addAll([ @@ -266,6 +270,8 @@ def createValidatesRunnerTask(Map m) { ) // maxParallelForks decreased from 4 in order to avoid OOM errors maxParallelForks 2 + def flinkConfDir = System.getProperty("user.dir") + "/runners/flink/src/test/validatesRunnerConfig" + environment("FLINK_CONF_DIR", flinkConfDir) useJUnit { if (config.checkpointing) { includeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java index 102340329b6b..014b1f95fc92 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkExecutionEnvironments.java @@ -237,6 +237,16 @@ public static StreamExecutionEnvironment createStreamExecutionEnvironment( flinkStreamEnv.setParallelism(parallelism); if (options.getMaxParallelism() > 0) { flinkStreamEnv.setMaxParallelism(options.getMaxParallelism()); + } else if (!options.isStreaming()) { + // In Flink maxParallelism defines the number of keyGroups. 
+ // (see + // https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L76) + // The default value (parallelism * 1.5) + // (see + // https://github.com/apache/flink/blob/e9dd4683f758b463d0b5ee18e49cecef6a70c5cf/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeAssignment.java#L137-L147) + // creates a lot of skew, so we force maxParallelism = parallelism in Batch mode. + LOG.info("Setting maxParallelism to {}", parallelism); + flinkStreamEnv.setMaxParallelism(parallelism); } // set parallelism in the options (required by some execution code) options.setParallelism(parallelism); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java index 519afa795bc3..901207a91f00 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineOptions.java @@ -262,7 +262,7 @@ public Long create(PipelineOptions options) { if (options.as(StreamingOptions.class).isStreaming()) { return 1000L; } else { - return 1000000L; + return 5000L; } } } @@ -382,6 +382,13 @@ public Long create(PipelineOptions options) { void setEnableStableInputDrain(Boolean enableStableInputDrain); + @Description( + "Set a slot sharing group for all bounded sources. 
This is required when using Datastream to have the same scheduling behaviour as the Dataset API.") + @Default.Boolean(true) + Boolean getForceSlotSharingGroup(); + + void setForceSlotSharingGroup(Boolean enableStableInputDrain); + static FlinkPipelineOptions defaults() { return PipelineOptionsFactory.as(FlinkPipelineOptions.class); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java new file mode 100644 index 000000000000..1683ced890c7 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingAggregationsTranslators.java @@ -0,0 +1,544 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToFlinkKeyKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.PartialReduceBundleOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator; +import org.apache.beam.runners.flink.translation.wrappers.streaming.WorkItemKeySelector; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.join.RawUnionValue; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.AppliedCombineFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.KeyedStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.transformations.TwoInputTransformation; +import org.apache.flink.util.Collector; + +public class FlinkStreamingAggregationsTranslators { + public static class ConcatenateAsIterable + extends Combine.CombineFn, Iterable> { + @Override + public Iterable createAccumulator() { + return new ArrayList<>(); + } + + @Override + public Iterable addInput(Iterable accumulator, T input) { + ArrayList arr = Lists.newArrayList(accumulator); + arr.add(input); + return arr; + } + + @Override + public Iterable mergeAccumulators(Iterable> accumulators) { + return Iterables.concat(accumulators); + } + + @Override + public Iterable extractOutput(Iterable accumulator) { + return accumulator; + } + + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return IterableCoder.of(inputCoder); + } + + @Override + public Coder> getDefaultOutputCoder(CoderRegistry registry, Coder inputCoder) { + return IterableCoder.of(inputCoder); + } + } + + private static + CombineFnBase.GlobalCombineFn toFinalFlinkCombineFn( + CombineFnBase.GlobalCombineFn combineFn, + Coder inputTCoder) { + + if (combineFn instanceof Combine.CombineFn) { + return new Combine.CombineFn() { + + @SuppressWarnings("unchecked") + final Combine.CombineFn fn = + (Combine.CombineFn) combineFn; + + @Override + public Object 
createAccumulator() { + return fn.createAccumulator(); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return fn.getAccumulatorCoder(registry, inputTCoder); + } + + @Override + public Object addInput(Object mutableAccumulator, Object input) { + return fn.mergeAccumulators(ImmutableList.of(mutableAccumulator, input)); + } + + @Override + public Object mergeAccumulators(Iterable accumulators) { + return fn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(Object accumulator) { + return fn.extractOutput(accumulator); + } + }; + } else if (combineFn instanceof CombineWithContext.CombineFnWithContext) { + return new CombineWithContext.CombineFnWithContext() { + + @SuppressWarnings("unchecked") + final CombineWithContext.CombineFnWithContext fn = + (CombineWithContext.CombineFnWithContext) combineFn; + + @Override + public Object createAccumulator(CombineWithContext.Context c) { + return fn.createAccumulator(c); + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return fn.getAccumulatorCoder(registry, inputTCoder); + } + + @Override + public Object addInput(Object accumulator, Object input, CombineWithContext.Context c) { + return fn.mergeAccumulators(ImmutableList.of(accumulator, input), c); + } + + @Override + public Object mergeAccumulators( + Iterable accumulators, CombineWithContext.Context c) { + return fn.mergeAccumulators(accumulators, c); + } + + @Override + public OutputT extractOutput(Object accumulator, CombineWithContext.Context c) { + return fn.extractOutput(accumulator, c); + } + }; + } + throw new IllegalArgumentException( + "Unsupported CombineFn implementation: " + combineFn.getClass()); + } + + /** + * Creates a DoFnOperator instance that groups elements per window and applies a combine function + * to them. 
+ */ + public static + WindowDoFnOperator getWindowedAggregateDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + SystemReduceFn reduceFn, + Map> sideInputTagMapping, + List> sideInputs) { + + // Naming + String fullName = FlinkStreamingTransformTranslators.getCurrentTransformName(context); + TupleTag> mainTag = new TupleTag<>("main output"); + + // input infos + PCollection> input = context.getInput(transform); + + @SuppressWarnings("unchecked") + WindowingStrategy windowingStrategy = + (WindowingStrategy) input.getWindowingStrategy(); + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + + // Coders + Coder keyCoder = inputKvCoder.getKeyCoder(); + + SingletonKeyedWorkItemCoder workItemCoder = + SingletonKeyedWorkItemCoder.of( + keyCoder, inputKvCoder.getValueCoder(), windowingStrategy.getWindowFn().windowCoder()); + + WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = + WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); + + // Key selector + WorkItemKeySelector workItemKeySelector = new WorkItemKeySelector<>(keyCoder); + + return new WindowDoFnOperator<>( + reduceFn, + fullName, + (Coder) windowedWorkItemCoder, + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, outputCoder, serializablePipelineOptions), + windowingStrategy, + sideInputTagMapping, + sideInputs, + context.getPipelineOptions(), + keyCoder, + workItemKeySelector); + } + + public static + WindowDoFnOperator getWindowedAggregateDoFnOperator( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + KvCoder inputKvCoder, + Coder>> outputCoder, + CombineFnBase.GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { + + // Combining fn + SystemReduceFn reduceFn = + SystemReduceFn.combining( + 
inputKvCoder.getKeyCoder(), + AppliedCombineFn.withInputCoder( + combineFn, + context.getInput(transform).getPipeline().getCoderRegistry(), + inputKvCoder)); + + return getWindowedAggregateDoFnOperator( + context, transform, inputKvCoder, outputCoder, reduceFn, sideInputTagMapping, sideInputs); + } + + private static class FlattenIterable + implements FlatMapFunction< + WindowedValue>>>, + WindowedValue>>> { + @Override + public void flatMap( + WindowedValue>>> w, + Collector>>> collector) + throws Exception { + WindowedValue>> flattened = + w.withValue(KV.of(w.getValue().getKey(), Iterables.concat(w.getValue().getValue()))); + collector.collect(flattened); + } + } + + public static + SingleOutputStreamOperator>> getBatchCombinePerKeyOperator( + FlinkStreamingTranslationContext context, + PCollection> input, + Map> sideInputTagMapping, + List> sideInputs, + Coder>> windowedAccumCoder, + CombineFnBase.GlobalCombineFn combineFn, + WindowDoFnOperator finalDoFnOperator, + TypeInformation>> outputTypeInfo) { + + String fullName = FlinkStreamingTransformTranslators.getCurrentTransformName(context); + DataStream>> inputDataStream = context.getInputDataStream(input); + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + + TupleTag> mainTag = new TupleTag<>("main output"); + String partialName = "Combine: " + fullName; + + KvToFlinkKeyKeySelector accumKeySelector = + new KvToFlinkKeyKeySelector<>(inputKvCoder.getKeyCoder()); + + CoderTypeInformation>> partialTypeInfo = + new CoderTypeInformation<>(windowedAccumCoder, context.getPipelineOptions()); + + PartialReduceBundleOperator partialDoFnOperator = + new PartialReduceBundleOperator<>( + combineFn, + fullName, + context.getWindowedInputCoder(input), + mainTag, + Collections.emptyList(), + new DoFnOperator.MultiOutputOutputManagerFactory<>( + mainTag, windowedAccumCoder, 
serializablePipelineOptions), + input.getWindowingStrategy(), + sideInputTagMapping, + sideInputs, + context.getPipelineOptions()); + + if (sideInputs.isEmpty()) { + return inputDataStream + .transform(partialName, partialTypeInfo, partialDoFnOperator) + .uid(partialName) + .name(partialName) + .keyBy(accumKeySelector) + .transform(fullName, outputTypeInfo, finalDoFnOperator) + .uid(fullName) + .name(fullName); + } else { + + Tuple2>, DataStream> transformSideInputs = + FlinkStreamingTransformTranslators.transformSideInputs(sideInputs, context); + + TwoInputTransformation< + WindowedValue>, RawUnionValue, WindowedValue>> + rawPartialFlinkTransform = + new TwoInputTransformation<>( + inputDataStream.getTransformation(), + transformSideInputs.f1.broadcast().getTransformation(), + partialName, + partialDoFnOperator, + partialTypeInfo, + inputDataStream.getParallelism()); + + SingleOutputStreamOperator>> partialyCombinedStream = + new SingleOutputStreamOperator>>( + inputDataStream.getExecutionEnvironment(), + rawPartialFlinkTransform) {}; // we have to cheat around the ctor being protected + + inputDataStream.getExecutionEnvironment().addOperator(rawPartialFlinkTransform); + + return buildTwoInputStream( + partialyCombinedStream.keyBy(accumKeySelector), + transformSideInputs.f1, + fullName, + finalDoFnOperator, + outputTypeInfo); + } + } + + /** + * Creates a two-step GBK operation. Elements are first aggregated locally to save on serialized + * size since in batch it's very likely that all the elements will be within the same window and + * pane. The only difference with batchCombinePerKey is the nature of the SystemReduceFn used. It + * uses SystemReduceFn.buffering() instead of SystemReduceFn.combining() so that new elements can + * simply be appended without accessing the existing state. 
+ */ + public static + SingleOutputStreamOperator>>> batchGroupByKey( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>>> transform) { + + Map> sideInputTagMapping = new HashMap<>(); + List> sideInputs = Collections.emptyList(); + + PCollection> input = context.getInput(transform); + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + + SerializablePipelineOptions serializablePipelineOptions = + new SerializablePipelineOptions(context.getPipelineOptions()); + + TypeInformation>>> outputTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + Coder> accumulatorCoder = IterableCoder.of(inputKvCoder.getValueCoder()); + KvCoder> accumKvCoder = + KvCoder.of(inputKvCoder.getKeyCoder(), accumulatorCoder); + + Coder>>> windowedAccumCoder = + WindowedValue.getFullCoder( + accumKvCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + + Coder>>>> outputCoder = + WindowedValue.getFullCoder( + KvCoder.of(inputKvCoder.getKeyCoder(), IterableCoder.of(accumulatorCoder)), + input.getWindowingStrategy().getWindowFn().windowCoder()); + + TypeInformation>>>> accumulatedTypeInfo = + new CoderTypeInformation<>( + WindowedValue.getFullCoder( + KvCoder.of( + inputKvCoder.getKeyCoder(), + IterableCoder.of(IterableCoder.of(inputKvCoder.getValueCoder()))), + input.getWindowingStrategy().getWindowFn().windowCoder()), + serializablePipelineOptions); + + // final aggregation + WindowDoFnOperator, Iterable>> finalDoFnOperator = + getWindowedAccumulateDoFnOperator( + context, transform, accumKvCoder, outputCoder, sideInputTagMapping, sideInputs); + + return getBatchCombinePerKeyOperator( + context, + input, + sideInputTagMapping, + sideInputs, + windowedAccumCoder, + new ConcatenateAsIterable<>(), + finalDoFnOperator, + accumulatedTypeInfo) + .flatMap(new FlattenIterable<>(), outputTypeInfo) + .name("concatenate"); + } + + private static + WindowDoFnOperator, Iterable>> + getWindowedAccumulateDoFnOperator( + FlinkStreamingTranslationContext context, 
+ PTransform>, PCollection>>> + transform, + KvCoder> accumKvCoder, + Coder>>>> outputCoder, + Map> sideInputTagMapping, + List> sideInputs) { + + // Combining fn + SystemReduceFn< + K, + Iterable, + Iterable>, + Iterable>, + BoundedWindow> + reduceFn = SystemReduceFn.buffering(accumKvCoder.getValueCoder()); + + return getWindowedAggregateDoFnOperator( + context, transform, accumKvCoder, outputCoder, reduceFn, sideInputTagMapping, sideInputs); + } + + public static + SingleOutputStreamOperator>> batchCombinePerKey( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + CombineFnBase.GlobalCombineFn combineFn, + Map> sideInputTagMapping, + List> sideInputs) { + + Coder>> windowedAccumCoder; + KvCoder accumKvCoder; + + PCollection> input = context.getInput(transform); + KvCoder inputKvCoder = (KvCoder) input.getCoder(); + TypeInformation>> outputTypeInfo = + context.getTypeInfo(context.getOutput(transform)); + + Coder>> outputCoder = + context.getWindowedInputCoder(context.getOutput(transform)); + + Coder accumulatorCoder; + try { + accumulatorCoder = + combineFn.getAccumulatorCoder( + input.getPipeline().getCoderRegistry(), inputKvCoder.getValueCoder()); + + accumKvCoder = KvCoder.of(inputKvCoder.getKeyCoder(), accumulatorCoder); + + windowedAccumCoder = + WindowedValue.getFullCoder( + accumKvCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + + // final aggregation from AccumT to OutputT + WindowDoFnOperator finalDoFnOperator = + getWindowedAggregateDoFnOperator( + context, + transform, + accumKvCoder, + outputCoder, + toFinalFlinkCombineFn(combineFn, inputKvCoder.getValueCoder()), + sideInputTagMapping, + sideInputs); + + return getBatchCombinePerKeyOperator( + context, + context.getInput(transform), + sideInputTagMapping, + sideInputs, + windowedAccumCoder, + combineFn, + finalDoFnOperator, + outputTypeInfo); + } + + @SuppressWarnings({ 
+ "nullness" // TODO(https://github.com/apache/beam/issues/20497) + }) + public static + SingleOutputStreamOperator>> buildTwoInputStream( + KeyedStream>, FlinkKey> keyedStream, + DataStream sideInputStream, + String name, + WindowDoFnOperator operator, + TypeInformation>> outputTypeInfo) { + // we have to manually construct the two-input transform because we're not + // allowed to have only one input keyed, normally. + TwoInputTransformation< + WindowedValue>, RawUnionValue, WindowedValue>> + rawFlinkTransform = + new TwoInputTransformation<>( + keyedStream.getTransformation(), + sideInputStream.broadcast().getTransformation(), + name, + operator, + outputTypeInfo, + keyedStream.getParallelism()); + + rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); + rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); + + @SuppressWarnings({"unchecked", "rawtypes"}) + SingleOutputStreamOperator>> outDataStream = + new SingleOutputStreamOperator( + keyedStream.getExecutionEnvironment(), + rawFlinkTransform) {}; // we have to cheat around the ctor being protected + + keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + + return outDataStream; + } + + public static + SingleOutputStreamOperator>> batchCombinePerKeyNoSideInputs( + FlinkStreamingTranslationContext context, + PTransform>, PCollection>> transform, + CombineFnBase.GlobalCombineFn combineFn) { + return batchCombinePerKey( + context, transform, combineFn, new HashMap<>(), Collections.emptyList()); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java index 3ed00a3c5ef2..0607838987f1 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java @@ -20,7 +20,6 @@ 
import static org.apache.beam.sdk.util.construction.PTransformTranslation.WRITE_FILES_TRANSFORM_URN; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -28,7 +27,7 @@ import java.util.Objects; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.ShardedKeyCoder; @@ -399,7 +398,7 @@ private Map> generateShardedKeys(int key, int shard // create effective key in the same way Beam/Flink will do so we can see if it gets // allocated to the partition we want - ByteBuffer effectiveKey = FlinkKeyUtils.encodeKey(shk, shardedKeyCoder); + FlinkKey effectiveKey = FlinkKey.of(shk, shardedKeyCoder); int partition = KeyGroupRangeAssignment.assignKeyToParallelOperator( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java index 836c825300db..a74be9f7e9e0 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java @@ -27,7 +27,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.auto.service.AutoService; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -41,14 +40,15 @@ import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import 
org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory; import org.apache.beam.runners.flink.translation.functions.ImpulseSourceFunction; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; import org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator; -import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector; -import org.apache.beam.runners.flink.translation.wrappers.streaming.SdfByteBufferKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToFlinkKeyKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.SdfFlinkKeyKeySelector; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator; import org.apache.beam.runners.flink.translation.wrappers.streaming.WorkItemKeySelector; @@ -430,24 +430,11 @@ private SingleOutputStreamOperator>>> add WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder()); - CoderTypeInformation>> workItemTypeInfo = - new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); - - DataStream>> workItemStream = - inputDataStream - .flatMap( - new FlinkStreamingTransformTranslators.ToKeyedWorkItem<>( - context.getPipelineOptions())) - .returns(workItemTypeInfo) - .name("ToKeyedWorkItem"); - WorkItemKeySelector keySelector = - new WorkItemKeySelector<>( - inputElementCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions())); + new WorkItemKeySelector<>(inputElementCoder.getKeyCoder()); - 
KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy(keySelector); + KeyedStream>, FlinkKey> keyedWorkItemStream = + inputDataStream.keyBy(new KvToFlinkKeyKeySelector(inputElementCoder.getKeyCoder())); SystemReduceFn, Iterable, BoundedWindow> reduceFn = SystemReduceFn.buffering(inputElementCoder.getValueCoder()); @@ -841,9 +828,7 @@ private void translateExecutableStage( } if (stateful) { keyCoder = ((KvCoder) valueCoder).getKeyCoder(); - keySelector = - new KvToByteBufferKeySelector( - keyCoder, new SerializablePipelineOptions(context.getPipelineOptions())); + keySelector = new KvToFlinkKeyKeySelector(keyCoder); } else { // For an SDF, we know that the input element should be // KV>, size>. We are going to use the element @@ -857,9 +842,7 @@ private void translateExecutableStage( valueCoder.getClass().getSimpleName())); } keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder(); - keySelector = - new SdfByteBufferKeySelector( - keyCoder, new SerializablePipelineOptions(context.getPipelineOptions())); + keySelector = new SdfFlinkKeyKeySelector(keyCoder); } inputDataStream = inputDataStream.keyBy(keySelector); } @@ -872,7 +855,7 @@ private void translateExecutableStage( tagsToIds, new SerializablePipelineOptions(context.getPipelineOptions())); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new ExecutableStageDoFnOperator<>( transform.getUniqueName(), windowedInputCoder, diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index f9089d11a25e..36cf035a33be 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -33,11 +33,12 @@ import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import 
org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; import org.apache.beam.runners.flink.translation.functions.ImpulseSourceFunction; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; -import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector; +import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToFlinkKeyKeySelector; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItem; import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder; import org.apache.beam.runners.flink.translation.wrappers.streaming.SplittableDoFnOperator; @@ -76,7 +77,6 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.util.AppliedCombineFn; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.construction.PTransformTranslation; @@ -105,8 +105,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.CheckpointListener; import org.apache.flink.runtime.state.FunctionInitializationContext; @@ -169,6 +169,8 @@ class FlinkStreamingTransformTranslators { 
TRANSLATORS.put(PTransformTranslation.TEST_STREAM_TRANSFORM_URN, new TestStreamTranslator()); } + private static final String FORCED_SLOT_GROUP = "beam"; + public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getTranslator( PTransform transform) { @Nullable String urn = PTransformTranslation.urnForTransformOrNull(transform); @@ -176,7 +178,7 @@ public static FlinkStreamingPipelineTranslator.StreamTransformTranslator getT } @SuppressWarnings("unchecked") - private static String getCurrentTransformName(FlinkStreamingTranslationContext context) { + public static String getCurrentTransformName(FlinkStreamingTranslationContext context) { return context.getCurrentTransform().getFullName(); } @@ -261,17 +263,17 @@ public void translateNode( } static class ValueWithRecordIdKeySelector - implements KeySelector>, ByteBuffer>, - ResultTypeQueryable { + implements KeySelector>, FlinkKey>, + ResultTypeQueryable { @Override - public ByteBuffer getKey(WindowedValue> value) throws Exception { - return ByteBuffer.wrap(value.getValue().getId()); + public FlinkKey getKey(WindowedValue> value) throws Exception { + return FlinkKey.of(ByteBuffer.wrap(value.getValue().getId())); } @Override - public TypeInformation getProducedType() { - return new GenericTypeInfo<>(ByteBuffer.class); + public TypeInformation getProducedType() { + return ValueTypeInfo.of(FlinkKey.class); } } @@ -309,7 +311,7 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) WindowedValue.getFullCoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE), context.getPipelineOptions()); - final SingleOutputStreamOperator> impulseOperator; + SingleOutputStreamOperator> impulseOperator; if (context.isStreaming()) { long shutdownAfterIdleSourcesMs = context @@ -328,6 +330,14 @@ void translateNode(Impulse transform, FlinkStreamingTranslationContext context) .getExecutionEnvironment() .fromSource(impulseSource, WatermarkStrategy.noWatermarks(), "Impulse") .returns(typeInfo); + + 
if (!context.isStreaming() + && context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getForceSlotSharingGroup()) { + impulseOperator = impulseOperator.slotSharingGroup(FORCED_SLOT_GROUP); + } } context.setOutputDataStream(context.getOutput(transform), impulseOperator); } @@ -389,14 +399,25 @@ public void translateNode( new SerializablePipelineOptions(context.getPipelineOptions()), parallelism); - DataStream> source; + TypeInformation> typeInfo = context.getTypeInfo(output); + + SingleOutputStreamOperator> source; try { source = context .getExecutionEnvironment() .fromSource( flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo) - .uid(fullName); + .uid(fullName) + .returns(typeInfo); + + if (!context.isStreaming() + && context + .getPipelineOptions() + .as(FlinkPipelineOptions.class) + .getForceSlotSharingGroup()) { + source = source.slotSharingGroup(FORCED_SLOT_GROUP); + } } catch (Exception e) { throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e); } @@ -427,7 +448,7 @@ public RawUnionValue map(T o) throws Exception { } } - private static Tuple2>, DataStream> + public static Tuple2>, DataStream> transformSideInputs( Collection> sideInputs, FlinkStreamingTranslationContext context) { @@ -492,7 +513,7 @@ public RawUnionValue map(T o) throws Exception { static class ParDoTranslationHelper { interface DoFnOperatorFactory { - DoFnOperator createDoFnOperator( + DoFnOperator createDoFnOperator( DoFn doFn, String stepName, List> sideInputs, @@ -569,9 +590,7 @@ static void translateParDo( // Based on the fact that the signature is stateful, DoFnSignatures ensures // that it is also keyed keyCoder = ((KvCoder) input.getCoder()).getKeyCoder(); - keySelector = - new KvToByteBufferKeySelector( - keyCoder, new SerializablePipelineOptions(context.getPipelineOptions())); + keySelector = new KvToFlinkKeyKeySelector<>(keyCoder); final PTransform> producer = context.getProducer(input); final String 
previousUrn = producer != null @@ -588,9 +607,7 @@ static void translateParDo( } else if (doFn instanceof SplittableParDoViaKeyedWorkItems.ProcessFn) { // we know that it is keyed on byte[] keyCoder = ByteArrayCoder.of(); - keySelector = - new WorkItemKeySelector<>( - keyCoder, new SerializablePipelineOptions(context.getPipelineOptions())); + keySelector = new WorkItemKeySelector<>(keyCoder); stateful = true; } @@ -600,7 +617,7 @@ static void translateParDo( context.getPipelineOptions()); if (sideInputs.isEmpty()) { - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = doFnOperatorFactory.createDoFnOperator( doFn, getCurrentTransformName(context), @@ -627,7 +644,7 @@ static void translateParDo( Tuple2>, DataStream> transformedSideInputs = transformSideInputs(sideInputs, context); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = doFnOperatorFactory.createDoFnOperator( doFn, getCurrentTransformName(context), @@ -932,86 +949,47 @@ public void translateNode( FlinkStreamingTranslationContext context) { PCollection> input = context.getInput(transform); - @SuppressWarnings("unchecked") WindowingStrategy windowingStrategy = (WindowingStrategy) input.getWindowingStrategy(); - KvCoder inputKvCoder = (KvCoder) input.getCoder(); - - SingletonKeyedWorkItemCoder workItemCoder = - SingletonKeyedWorkItemCoder.of( - inputKvCoder.getKeyCoder(), - ByteArrayCoder.of(), - input.getWindowingStrategy().getWindowFn().windowCoder()); - DataStream>> inputDataStream = context.getInputDataStream(input); + String fullName = getCurrentTransformName(context); - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = - WindowedValue.getFullCoder( - workItemCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - - CoderTypeInformation>> workItemTypeInfo = - new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); - - DataStream>> workItemStream = - inputDataStream - .flatMap( - new ToBinaryKeyedWorkItem<>( - context.getPipelineOptions(), 
inputKvCoder.getValueCoder())) - .returns(workItemTypeInfo) - .name("ToBinaryKeyedWorkItem"); - - WorkItemKeySelector keySelector = - new WorkItemKeySelector<>( - inputKvCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions())); - - KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy(keySelector); - - SystemReduceFn, Iterable, BoundedWindow> reduceFn = - SystemReduceFn.buffering(ByteArrayCoder.of()); - - Coder>>> outputCoder = - WindowedValue.getFullCoder( - KvCoder.of(inputKvCoder.getKeyCoder(), IterableCoder.of(ByteArrayCoder.of())), - windowingStrategy.getWindowFn().windowCoder()); - - TypeInformation>>> outputTypeInfo = - new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); - - TupleTag>> mainTag = new TupleTag<>("main output"); + SingleOutputStreamOperator>>> outDataStream; + // Pre-aggregate before shuffle similar to group combine + if (!context.isStreaming()) { + outDataStream = FlinkStreamingAggregationsTranslators.batchGroupByKey(context, transform); + } else { + // No pre-aggregation in Streaming mode. 
+ KvToFlinkKeyKeySelector keySelector = + new KvToFlinkKeyKeySelector<>(inputKvCoder.getKeyCoder()); - String fullName = getCurrentTransformName(context); - WindowDoFnOperator> doFnOperator = - new WindowDoFnOperator<>( - reduceFn, - fullName, - windowedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, - outputCoder, - new SerializablePipelineOptions(context.getPipelineOptions())), - windowingStrategy, - new HashMap<>(), /* side-input mapping */ - Collections.emptyList(), /* side inputs */ - context.getPipelineOptions(), - inputKvCoder.getKeyCoder(), - keySelector); + Coder>>> outputCoder = + WindowedValue.getFullCoder( + KvCoder.of( + inputKvCoder.getKeyCoder(), IterableCoder.of(inputKvCoder.getValueCoder())), + windowingStrategy.getWindowFn().windowCoder()); - final SingleOutputStreamOperator>>> outDataStream = - keyedWorkItemStream - .transform(fullName, outputTypeInfo, doFnOperator) - .uid(fullName) - .flatMap( - new ToGroupByKeyResult<>( - context.getPipelineOptions(), inputKvCoder.getValueCoder())) - .returns(context.getTypeInfo(context.getOutput(transform))) - .name("ToGBKResult"); + TypeInformation>>> outputTypeInfo = + new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); + WindowDoFnOperator> doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + SystemReduceFn.buffering(inputKvCoder.getValueCoder()), + new HashMap<>(), + Collections.emptyList()); + + outDataStream = + inputDataStream + .keyBy(keySelector) + .transform(fullName, outputTypeInfo, doFnOperator) + .uid(fullName); + } context.setOutputDataStream(context.getOutput(transform), outDataStream); } } @@ -1042,128 +1020,79 @@ public void translateNode( PTransform>, PCollection>> transform, FlinkStreamingTranslationContext context) { String fullName = getCurrentTransformName(context); - PCollection> input = 
context.getInput(transform); - @SuppressWarnings("unchecked") - WindowingStrategy windowingStrategy = - (WindowingStrategy) input.getWindowingStrategy(); + PCollection> input = context.getInput(transform); KvCoder inputKvCoder = (KvCoder) input.getCoder(); - - SingletonKeyedWorkItemCoder workItemCoder = - SingletonKeyedWorkItemCoder.of( - inputKvCoder.getKeyCoder(), - inputKvCoder.getValueCoder(), - input.getWindowingStrategy().getWindowFn().windowCoder()); + Coder keyCoder = inputKvCoder.getKeyCoder(); + Coder>> outputCoder = + context.getWindowedInputCoder(context.getOutput(transform)); DataStream>> inputDataStream = context.getInputDataStream(input); - WindowedValue.FullWindowedValueCoder> windowedWorkItemCoder = - WindowedValue.getFullCoder( - workItemCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - - CoderTypeInformation>> workItemTypeInfo = - new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions()); - - DataStream>> workItemStream = - inputDataStream - .flatMap(new ToKeyedWorkItem<>(context.getPipelineOptions())) - .returns(workItemTypeInfo) - .name("ToKeyedWorkItem"); - - WorkItemKeySelector keySelector = - new WorkItemKeySelector<>( - inputKvCoder.getKeyCoder(), - new SerializablePipelineOptions(context.getPipelineOptions())); - KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy(keySelector); - - GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); - SystemReduceFn reduceFn = - SystemReduceFn.combining( - inputKvCoder.getKeyCoder(), - AppliedCombineFn.withInputCoder( - combineFn, input.getPipeline().getCoderRegistry(), inputKvCoder)); + @SuppressWarnings("unchecked") + GlobalCombineFn combineFn = ((Combine.PerKey) transform).getFn(); - Coder>> outputCoder = - context.getWindowedInputCoder(context.getOutput(transform)); TypeInformation>> outputTypeInfo = context.getTypeInfo(context.getOutput(transform)); + @SuppressWarnings("unchecked") List> sideInputs = ((Combine.PerKey) 
transform).getSideInputs(); + KeyedStream>, FlinkKey> keyedStream = + inputDataStream.keyBy(new KvToFlinkKeyKeySelector<>(keyCoder)); + if (sideInputs.isEmpty()) { - TupleTag> mainTag = new TupleTag<>("main output"); - WindowDoFnOperator doFnOperator = - new WindowDoFnOperator<>( - reduceFn, - fullName, - (Coder) windowedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, - outputCoder, - new SerializablePipelineOptions(context.getPipelineOptions())), - windowingStrategy, - new HashMap<>(), /* side-input mapping */ - Collections.emptyList(), /* side inputs */ - context.getPipelineOptions(), - inputKvCoder.getKeyCoder(), - keySelector); - - SingleOutputStreamOperator>> outDataStream = - keyedWorkItemStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + SingleOutputStreamOperator>> outDataStream; + + if (!context.isStreaming()) { + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKeyNoSideInputs( + context, transform, combineFn); + } else { + WindowDoFnOperator doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + new HashMap<>(), + Collections.emptyList()); + + outDataStream = + keyedStream.transform(fullName, outputTypeInfo, doFnOperator).uid(fullName); + } + context.setOutputDataStream(context.getOutput(transform), outDataStream); } else { Tuple2>, DataStream> transformSideInputs = transformSideInputs(sideInputs, context); + SingleOutputStreamOperator>> outDataStream; - TupleTag> mainTag = new TupleTag<>("main output"); - WindowDoFnOperator doFnOperator = - new WindowDoFnOperator<>( - reduceFn, - fullName, - (Coder) windowedWorkItemCoder, - mainTag, - Collections.emptyList(), - new DoFnOperator.MultiOutputOutputManagerFactory<>( - mainTag, - outputCoder, - new SerializablePipelineOptions(context.getPipelineOptions())), - windowingStrategy, - 
transformSideInputs.f0, - sideInputs, - context.getPipelineOptions(), - inputKvCoder.getKeyCoder(), - keySelector); - - // we have to manually contruct the two-input transform because we're not - // allowed to have only one input keyed, normally. - - TwoInputTransformation< - WindowedValue>, - RawUnionValue, - WindowedValue>> - rawFlinkTransform = - new TwoInputTransformation<>( - keyedWorkItemStream.getTransformation(), - transformSideInputs.f1.broadcast().getTransformation(), - transform.getName(), - doFnOperator, - outputTypeInfo, - keyedWorkItemStream.getParallelism()); - - rawFlinkTransform.setStateKeyType(keyedWorkItemStream.getKeyType()); - rawFlinkTransform.setStateKeySelectors(keyedWorkItemStream.getKeySelector(), null); - - @SuppressWarnings({"unchecked", "rawtypes"}) - SingleOutputStreamOperator>> outDataStream = - new SingleOutputStreamOperator( - keyedWorkItemStream.getExecutionEnvironment(), - rawFlinkTransform) {}; // we have to cheat around the ctor being protected - - keyedWorkItemStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + if (!context.isStreaming()) { + outDataStream = + FlinkStreamingAggregationsTranslators.batchCombinePerKey( + context, transform, combineFn, transformSideInputs.f0, sideInputs); + } else { + WindowDoFnOperator doFnOperator = + FlinkStreamingAggregationsTranslators.getWindowedAggregateDoFnOperator( + context, + transform, + inputKvCoder, + outputCoder, + combineFn, + transformSideInputs.f0, + sideInputs); + + outDataStream = + FlinkStreamingAggregationsTranslators.buildTwoInputStream( + keyedStream, + transformSideInputs.f1, + transform.getName(), + doFnOperator, + outputTypeInfo); + } context.setOutputDataStream(context.getOutput(transform), outDataStream); } @@ -1210,11 +1139,8 @@ public void translateNode( .returns(workItemTypeInfo) .name("ToKeyedWorkItem"); - KeyedStream>, ByteBuffer> keyedWorkItemStream = - workItemStream.keyBy( - new WorkItemKeySelector<>( - inputKvCoder.getKeyCoder(), - new 
SerializablePipelineOptions(context.getPipelineOptions()))); + KeyedStream>, FlinkKey> keyedWorkItemStream = + workItemStream.keyBy(new WorkItemKeySelector<>(inputKvCoder.getKeyCoder())); context.setOutputDataStream(context.getOutput(transform), keyedWorkItemStream); } @@ -1328,115 +1254,6 @@ public void flatMap(T t, Collector collector) throws Exception { } } - static class ToKeyedWorkItem - extends RichFlatMapFunction< - WindowedValue>, WindowedValue>> { - - private final SerializablePipelineOptions options; - - ToKeyedWorkItem(PipelineOptions options) { - this.options = new SerializablePipelineOptions(options); - } - - @Override - public void open(Configuration parameters) { - // Initialize FileSystems for any coders which may want to use the FileSystem, - // see https://issues.apache.org/jira/browse/BEAM-8303 - FileSystems.setDefaultPipelineOptions(options.get()); - } - - @Override - public void flatMap( - WindowedValue> inWithMultipleWindows, - Collector>> out) { - - // we need to wrap each one work item per window for now - // since otherwise the PushbackSideInputRunner will not correctly - // determine whether side inputs are ready - // - // this is tracked as https://github.com/apache/beam/issues/18358 - for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { - SingletonKeyedWorkItem workItem = - new SingletonKeyedWorkItem<>( - in.getValue().getKey(), in.withValue(in.getValue().getValue())); - - out.collect(in.withValue(workItem)); - } - } - } - - static class ToBinaryKeyedWorkItem - extends RichFlatMapFunction< - WindowedValue>, WindowedValue>> { - - private final SerializablePipelineOptions options; - private final Coder valueCoder; - - ToBinaryKeyedWorkItem(PipelineOptions options, Coder valueCoder) { - this.options = new SerializablePipelineOptions(options); - this.valueCoder = valueCoder; - } - - @Override - public void open(Configuration parameters) { - // Initialize FileSystems for any coders which may want to use the FileSystem, - // see 
https://issues.apache.org/jira/browse/BEAM-8303 - FileSystems.setDefaultPipelineOptions(options.get()); - } - - @Override - public void flatMap( - WindowedValue> inWithMultipleWindows, - Collector>> out) - throws CoderException { - - // we need to wrap each one work item per window for now - // since otherwise the PushbackSideInputRunner will not correctly - // determine whether side inputs are ready - // - // this is tracked as https://github.com/apache/beam/issues/18358 - for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { - final byte[] binaryValue = - CoderUtils.encodeToByteArray(valueCoder, in.getValue().getValue()); - final SingletonKeyedWorkItem workItem = - new SingletonKeyedWorkItem<>(in.getValue().getKey(), in.withValue(binaryValue)); - out.collect(in.withValue(workItem)); - } - } - } - - static class ToGroupByKeyResult - extends RichFlatMapFunction< - WindowedValue>>, WindowedValue>>> { - - private final SerializablePipelineOptions options; - private final Coder valueCoder; - - ToGroupByKeyResult(PipelineOptions options, Coder valueCoder) { - this.options = new SerializablePipelineOptions(options); - this.valueCoder = valueCoder; - } - - @Override - public void open(Configuration parameters) { - // Initialize FileSystems for any coders which may want to use the FileSystem, - // see https://issues.apache.org/jira/browse/BEAM-8303 - FileSystems.setDefaultPipelineOptions(options.get()); - } - - @Override - public void flatMap( - WindowedValue>> element, - Collector>>> collector) - throws CoderException { - final List result = new ArrayList<>(); - for (byte[] binaryValue : element.getValue().getValue()) { - result.add(CoderUtils.decodeFromByteArray(valueCoder, binaryValue)); - } - collector.collect(element.withValue(KV.of(element.getValue().getKey(), result))); - } - } - /** Registers classes specialized to the Flink runner. 
*/ @AutoService(TransformPayloadTranslatorRegistrar.class) public static class FlinkTransformsRegistrar implements TransformPayloadTranslatorRegistrar { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/FlinkKey.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/FlinkKey.java new file mode 100644 index 000000000000..6a5e8d0458f3 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/FlinkKey.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.adapter; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + +import java.io.IOException; +import java.nio.ByteBuffer; +import javax.annotation.Nullable; +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hashing; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.types.Value; + +public class FlinkKey implements Value { + + private final CoderTypeSerializer serializer; + + @SuppressWarnings("initialization.fields.uninitialized") + private ByteBuffer underlying; + + public FlinkKey() { + this.serializer = new CoderTypeSerializer<>(ByteArrayCoder.of(), false); + } + + private FlinkKey(ByteBuffer underlying) { + this(); + this.underlying = underlying; + } + + public ByteBuffer getSerializedKey() { + return underlying; + } + + public static FlinkKey of(ByteBuffer bytes) { + return new FlinkKey(bytes); + } + + public static FlinkKey of(K key, Coder coder) { + return new FlinkKey(FlinkKeyUtils.encodeKey(key, coder)); + } + + @Override + public void write(DataOutputView out) throws IOException { + checkNotNull(underlying); + serializer.serialize(underlying.array(), out); + } + + @Override + public void read(DataInputView in) throws IOException { + this.underlying = ByteBuffer.wrap(serializer.deserialize(in)); + } + + public K getKey(Coder coder) { + return FlinkKeyUtils.decodeKey(underlying, coder); + } + + @Override + public int hashCode() { + // return underlying.hashCode(); + return Hashing.murmur3_128().hashBytes(underlying.array()).asInt(); + } + + @Override + public boolean equals(@Nullable Object obj) { + return obj 
instanceof FlinkKey && ((FlinkKey) obj).underlying.equals(underlying); + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index 1072702c3e66..01b12cfa717a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -24,9 +24,9 @@ import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; @@ -56,6 +56,7 @@ import org.apache.beam.runners.core.TimerInternals.TimerData; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate; import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; @@ -144,12 +145,14 @@ "keyfor", "nullness" }) // TODO(https://github.com/apache/beam/issues/20497) -public class DoFnOperator extends AbstractStreamOperator> - implements OneInputStreamOperator, WindowedValue>, - TwoInputStreamOperator, RawUnionValue, WindowedValue>, - Triggerable { +public class DoFnOperator + extends AbstractStreamOperator> + implements OneInputStreamOperator, WindowedValue>, + TwoInputStreamOperator, RawUnionValue, WindowedValue>, + Triggerable { private static final Logger LOG = LoggerFactory.getLogger(DoFnOperator.class); + private final boolean isStreaming; protected DoFn doFn; @@ -270,7 +273,7 @@ 
public class DoFnOperator extends AbstractStreamOperator doFn, + @Nullable DoFn doFn, String stepName, Coder> inputWindowedCoder, Map, Coder> outputCoders, @@ -281,8 +284,8 @@ public DoFnOperator( Map> sideInputTagMapping, Collection> sideInputs, PipelineOptions options, - Coder keyCoder, - KeySelector, ?> keySelector, + @Nullable Coder keyCoder, + @Nullable KeySelector, ?> keySelector, DoFnSchemaInformation doFnSchemaInformation, Map> sideInputMapping) { this.doFn = doFn; @@ -294,6 +297,7 @@ public DoFnOperator( this.sideInputTagMapping = sideInputTagMapping; this.sideInputs = sideInputs; this.serializedOptions = new SerializablePipelineOptions(options); + this.isStreaming = serializedOptions.get().as(FlinkPipelineOptions.class).isStreaming(); this.windowingStrategy = windowingStrategy; this.outputManagerFactory = outputManagerFactory; @@ -358,6 +362,11 @@ protected DoFn getDoFn() { return doFn; } + protected Iterable> preProcess(WindowedValue input) { + // Assume Input is PreInputT + return Collections.singletonList((WindowedValue) input); + } + // allow overriding this, for example SplittableDoFnOperator will not create a // stateful DoFn runner because ProcessFn, which is used for executing a Splittable DoFn // doesn't play by the normal DoFn rules and WindowDoFnOperator uses LateDataDroppingDoFnRunner @@ -417,6 +426,10 @@ public void setup( super.setup(containingTask, config, output); } + protected boolean shoudBundleElements() { + return isStreaming; + } + @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); @@ -465,7 +478,10 @@ public void initializeState(StateInitializationContext context) throws Exception if (keyCoder != null) { keyedStateInternals = new FlinkStateInternals<>( - (KeyedStateBackend) getKeyedStateBackend(), keyCoder, serializedOptions); + (KeyedStateBackend) getKeyedStateBackend(), + keyCoder, + windowingStrategy.getWindowFn().windowCoder(), + serializedOptions); if 
(timerService == null) { timerService = @@ -595,7 +611,10 @@ private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAc if (doFn != null) { DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); FlinkStateInternals.EarlyBinder earlyBinder = - new FlinkStateInternals.EarlyBinder(getKeyedStateBackend(), serializedOptions); + new FlinkStateInternals.EarlyBinder( + getKeyedStateBackend(), + serializedOptions, + windowingStrategy.getWindowFn().windowCoder()); for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) { StateSpec spec = (StateSpec) signature.stateDeclarations().get(value.id()).field().get(doFn); @@ -727,30 +746,34 @@ protected final void setBundleFinishedCallback(Runnable callback) { } @Override - public final void processElement(StreamRecord> streamRecord) { - checkInvokeStartBundle(); - LOG.trace("Processing element {} in {}", streamRecord.getValue().getValue(), doFn.getClass()); - long oldHold = keyCoder != null ? keyedStateInternals.minWatermarkHoldMs() : -1L; - doFnRunner.processElement(streamRecord.getValue()); - checkInvokeFinishBundleByCount(); - emitWatermarkIfHoldChanged(oldHold); + public final void processElement(StreamRecord> streamRecord) { + for (WindowedValue e : preProcess(streamRecord.getValue())) { + checkInvokeStartBundle(); + LOG.trace("Processing element {} in {}", streamRecord.getValue().getValue(), doFn.getClass()); + long oldHold = keyCoder != null ? 
keyedStateInternals.minWatermarkHoldMs() : -1L; + doFnRunner.processElement(e); + checkInvokeFinishBundleByCount(); + emitWatermarkIfHoldChanged(oldHold); + } } @Override - public final void processElement1(StreamRecord> streamRecord) + public final void processElement1(StreamRecord> streamRecord) throws Exception { - checkInvokeStartBundle(); - Iterable> justPushedBack = - pushbackDoFnRunner.processElementInReadyWindows(streamRecord.getValue()); + for (WindowedValue e : preProcess(streamRecord.getValue())) { + checkInvokeStartBundle(); + Iterable> justPushedBack = + pushbackDoFnRunner.processElementInReadyWindows(e); - long min = pushedBackWatermark; - for (WindowedValue pushedBackValue : justPushedBack) { - min = Math.min(min, pushedBackValue.getTimestamp().getMillis()); - pushedBackElementsHandler.pushBack(pushedBackValue); - } - pushedBackWatermark = min; + long min = pushedBackWatermark; + for (WindowedValue pushedBackValue : justPushedBack) { + min = Math.min(min, pushedBackValue.getTimestamp().getMillis()); + pushedBackElementsHandler.pushBack(pushedBackValue); + } + pushedBackWatermark = min; - checkInvokeFinishBundleByCount(); + checkInvokeFinishBundleByCount(); + } } /** @@ -789,7 +812,9 @@ public final void processElement2(StreamRecord streamRecord) thro WindowedValue element = it.next(); // we need to set the correct key in case the operator is // a (keyed) window operator - setKeyContextElement1(new StreamRecord<>(element)); + if (keySelector != null) { + setCurrentKey(keySelector.getKey(element)); + } Iterable> justPushedBack = pushbackDoFnRunner.processElementInReadyWindows(element); @@ -969,6 +994,9 @@ private void checkInvokeStartBundle() { @SuppressWarnings("NonAtomicVolatileUpdate") @SuppressFBWarnings("VO_VOLATILE_INCREMENT") private void checkInvokeFinishBundleByCount() { + if (!shoudBundleElements()) { + return; + } // We do not access this statement concurrently, but we want to make sure that each thread // sees the latest value, which is why 
we use volatile. See the class field section above // for more information. @@ -982,6 +1010,9 @@ private void checkInvokeFinishBundleByCount() { /** Check whether invoke finishBundle by timeout. */ private void checkInvokeFinishBundleByTime() { + if (!shoudBundleElements()) { + return; + } long now = getProcessingTimeService().getCurrentProcessingTime(); if (now - lastFinishBundleTime >= maxBundleTimeMills) { invokeFinishBundle(); @@ -1045,7 +1076,7 @@ public void prepareSnapshotPreBarrier(long checkpointId) { } @Override - public final void snapshotState(StateSnapshotContext context) throws Exception { + public void snapshotState(StateSnapshotContext context) throws Exception { if (checkpointStats != null) { checkpointStats.snapshotStart(context.getCheckpointId()); } @@ -1117,19 +1148,19 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception { } @Override - public void onEventTime(InternalTimer timer) { + public void onEventTime(InternalTimer timer) { checkInvokeStartBundle(); fireTimerInternal(timer.getKey(), timer.getNamespace()); } @Override - public void onProcessingTime(InternalTimer timer) { + public void onProcessingTime(InternalTimer timer) { checkInvokeStartBundle(); fireTimerInternal(timer.getKey(), timer.getNamespace()); } // allow overriding this in ExecutableStageDoFnOperator to set the key context - protected void fireTimerInternal(ByteBuffer key, TimerData timerData) { + protected void fireTimerInternal(FlinkKey key, TimerData timerData) { long oldHold = keyCoder != null ? keyedStateInternals.minWatermarkHoldMs() : -1L; fireTimer(timerData); emitWatermarkIfHoldChanged(oldHold); @@ -1210,6 +1241,8 @@ public static class BufferedOutputManager implements DoFnRunners.Output */ private final Lock bufferLock; + private final boolean isStreaming; + private Map> idsToTags; /** Elements buffered during a snapshot, by output id. 
*/ @VisibleForTesting @@ -1228,7 +1261,8 @@ public static class BufferedOutputManager implements DoFnRunners.Output Map, OutputTag>> tagsToOutputTags, Map, Integer> tagsToIds, Lock bufferLock, - PushedBackElementsHandler>> pushedBackElementsHandler) { + PushedBackElementsHandler>> pushedBackElementsHandler, + boolean isStreaming) { this.output = output; this.mainTag = mainTag; this.tagsToOutputTags = tagsToOutputTags; @@ -1239,6 +1273,7 @@ public static class BufferedOutputManager implements DoFnRunners.Output idsToTags.put(entry.getValue(), entry.getKey()); } this.pushedBackElementsHandler = pushedBackElementsHandler; + this.isStreaming = isStreaming; } void openBuffer() { @@ -1251,7 +1286,8 @@ void closeBuffer() { @Override public void output(TupleTag tag, WindowedValue value) { - if (!openBuffer) { + // Don't buffer elements in Batch mode + if (!openBuffer || !isStreaming) { emit(tag, value); } else { buffer(KV.of(tagsToIds.get(tag), value)); @@ -1360,6 +1396,7 @@ public static class MultiOutputOutputManagerFactory private final Map, OutputTag>> tagsToOutputTags; private final Map, Coder>> tagsToCoders; private final SerializablePipelineOptions pipelineOptions; + private final boolean isStreaming; // There is no side output. 
@SuppressWarnings("unchecked") @@ -1388,6 +1425,7 @@ public MultiOutputOutputManagerFactory( this.tagsToCoders = tagsToCoders; this.tagsToIds = tagsToIds; this.pipelineOptions = pipelineOptions; + this.isStreaming = pipelineOptions.get().as(FlinkPipelineOptions.class).isStreaming(); } @Override @@ -1410,7 +1448,13 @@ public BufferedOutputManager create( NonKeyedPushedBackElementsHandler.create(listStateBuffer); return new BufferedOutputManager<>( - output, mainTag, tagsToOutputTags, tagsToIds, bufferLock, pushedBackElementsHandler); + output, + mainTag, + tagsToOutputTags, + tagsToIds, + bufferLock, + pushedBackElementsHandler, + isStreaming); } private TaggedKvCoder buildTaggedKvCoder() { @@ -1491,7 +1535,7 @@ void processPendingProcessingTimeTimers() { keyedStateBackend.setCurrentKey(internalTimer.getKey()); TimerData timer = internalTimer.getNamespace(); checkInvokeStartBundle(); - fireTimerInternal((ByteBuffer) internalTimer.getKey(), timer); + fireTimerInternal((FlinkKey) internalTimer.getKey(), timer); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index 456f75b0ee67..53e09f3f818c 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -24,7 +24,6 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collection; @@ -59,6 +58,7 @@ import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternalsFactory; import 
org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.runners.flink.translation.utils.Locker; @@ -111,7 +111,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; @@ -138,7 +137,8 @@ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -public class ExecutableStageDoFnOperator extends DoFnOperator { +public class ExecutableStageDoFnOperator + extends DoFnOperator { private static final Logger LOG = LoggerFactory.getLogger(ExecutableStageDoFnOperator.class); @@ -247,7 +247,7 @@ protected Lock getLockToAcquireForStateAccessDuringBundles() { public void open() throws Exception { executableStage = ExecutableStage.fromPayload(payload); hasSdfProcessFn = hasSDF(executableStage); - initializeUserState(executableStage, getKeyedStateBackend(), pipelineOptions); + initializeUserState(executableStage, getKeyedStateBackend(), pipelineOptions, windowCoder); // TODO: Wire this into the distributed cache and make it pluggable. // TODO: Do we really want this layer of indirection when accessing the stage bundle factory? 
// It's a little strange because this operator is responsible for the lifetime of the stage @@ -369,17 +369,17 @@ static class BagUserStateFactory implements StateRequestHandlers.BagUserStateHandlerFactory { private final StateInternals stateInternals; - private final KeyedStateBackend keyedStateBackend; + private final KeyedStateBackend keyedStateBackend; /** Lock to hold whenever accessing the state backend. */ private final Lock stateBackendLock; /** For debugging: The key coder used by the Runner. */ private final @Nullable Coder runnerKeyCoder; /** For debugging: Same as keyedStateBackend but upcasted, to access key group meta info. */ - private final @Nullable AbstractKeyedStateBackend keyStateBackendWithKeyGroupInfo; + private final @Nullable AbstractKeyedStateBackend keyStateBackendWithKeyGroupInfo; BagUserStateFactory( StateInternals stateInternals, - KeyedStateBackend keyedStateBackend, + KeyedStateBackend keyedStateBackend, Lock stateBackendLock, @Nullable Coder runnerKeyCoder) { this.stateInternals = stateInternals; @@ -389,7 +389,7 @@ static class BagUserStateFactory // This will always succeed, unless a custom state backend is used which does not extend // AbstractKeyedStateBackend. This is unlikely but we should still consider this case. 
this.keyStateBackendWithKeyGroupInfo = - (AbstractKeyedStateBackend) keyedStateBackend; + (AbstractKeyedStateBackend) keyedStateBackend; } else { this.keyStateBackendWithKeyGroupInfo = null; } @@ -417,7 +417,7 @@ public Iterable get(ByteString key, W window) { "State get for {} {} {} {}", pTransformId, userStateId, - Arrays.toString(keyedStateBackend.getCurrentKey().array()), + Arrays.toString(keyedStateBackend.getCurrentKey().getSerializedKey().array()), window); } BagState bagState = @@ -437,7 +437,7 @@ public void append(ByteString key, W window, Iterator values) { "State append for {} {} {} {}", pTransformId, userStateId, - Arrays.toString(keyedStateBackend.getCurrentKey().array()), + Arrays.toString(keyedStateBackend.getCurrentKey().getSerializedKey().array()), window); } BagState bagState = @@ -458,7 +458,7 @@ public void clear(ByteString key, W window) { "State clear for {} {} {} {}", pTransformId, userStateId, - Arrays.toString(keyedStateBackend.getCurrentKey().array()), + Arrays.toString(keyedStateBackend.getCurrentKey().getSerializedKey().array()), window); } BagState bagState = @@ -469,7 +469,7 @@ public void clear(ByteString key, W window) { private void prepareStateBackend(ByteString key) { // Key for state request is shipped encoded with NESTED context. - ByteBuffer encodedKey = FlinkKeyUtils.fromEncodedKey(key); + FlinkKey encodedKey = FlinkKey.of(FlinkKeyUtils.fromEncodedKey(key)); keyedStateBackend.setCurrentKey(encodedKey); if (keyStateBackendWithKeyGroupInfo != null) { int currentKeyGroupIndex = keyStateBackendWithKeyGroupInfo.getCurrentKeyGroupIndex(); @@ -511,13 +511,13 @@ public void setKeyContextElement1(StreamRecord record) {} public void setCurrentKey(Object key) {} @Override - public ByteBuffer getCurrentKey() { + public FlinkKey getCurrentKey() { // This is the key retrieved by HeapInternalTimerService when setting a Flink timer. // Note: Only called by the TimerService. Must be guarded by a lock. 
Preconditions.checkState( stateBackendLock.isLocked(), "State backend must be locked when retrieving the current key."); - return this.getKeyedStateBackend().getCurrentKey(); + return this.getKeyedStateBackend().getCurrentKey(); } void setTimer(Timer timerElement, TimerInternals.TimerData timerData) { @@ -527,8 +527,8 @@ void setTimer(Timer timerElement, TimerInternals.TimerData timerData) { LOG.debug("Setting timer: {} {}", timerElement, timerData); // KvToByteBufferKeySelector returns the key encoded, it doesn't care about the // window, timestamp or pane information. - ByteBuffer encodedKey = - (ByteBuffer) + FlinkKey encodedKey = + (FlinkKey) keySelector.getKey( WindowedValue.valueInGlobalWindow( (InputT) KV.of(timerElement.getUserKey(), null))); @@ -562,8 +562,7 @@ class SdfFlinkTimerInternalsFactory implements TimerInternalsFactory { @Override public TimerInternals timerInternalsForKey(InputT key) { try { - ByteBuffer encodedKey = - (ByteBuffer) keySelector.getKey(WindowedValue.valueInGlobalWindow(key)); + FlinkKey encodedKey = (FlinkKey) keySelector.getKey(WindowedValue.valueInGlobalWindow(key)); return new SdfFlinkTimerInternals(encodedKey); } catch (Exception e) { throw new RuntimeException("Couldn't get a timer internals", e); @@ -576,9 +575,9 @@ public TimerInternals timerInternalsForKey(InputT key) { * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}. 
*/ class SdfFlinkTimerInternals implements TimerInternals { - private final ByteBuffer key; + private final FlinkKey key; - SdfFlinkTimerInternals(ByteBuffer key) { + SdfFlinkTimerInternals(FlinkKey key) { this.key = key; } @@ -659,8 +658,7 @@ class SdfFlinkStateInternalsFactory implements StateInternalsFactory { @Override public StateInternals stateInternalsForKey(InputT key) { try { - ByteBuffer encodedKey = - (ByteBuffer) keySelector.getKey(WindowedValue.valueInGlobalWindow(key)); + FlinkKey encodedKey = (FlinkKey) keySelector.getKey(WindowedValue.valueInGlobalWindow(key)); return new SdfFlinkStateInternals(encodedKey); } catch (Exception e) { throw new RuntimeException("Couldn't get a state internals", e); @@ -671,9 +669,9 @@ public StateInternals stateInternalsForKey(InputT key) { /** A {@link StateInternals} for keeping {@link DelayedBundleApplication}s as states. */ class SdfFlinkStateInternals implements StateInternals { - private final ByteBuffer key; + private final FlinkKey key; - SdfFlinkStateInternals(ByteBuffer key) { + SdfFlinkStateInternals(FlinkKey key) { this.key = key; } @@ -697,7 +695,7 @@ public T state( } @Override - protected void fireTimerInternal(ByteBuffer key, TimerInternals.TimerData timer) { + protected void fireTimerInternal(FlinkKey key, TimerInternals.TimerData timer) { // We have to synchronize to ensure the state backend is not concurrently accessed by the state // requests try (Locker locker = Locker.locked(stateBackendLock)) { @@ -774,7 +772,7 @@ DoFnRunner createBufferingDoFnRunnerIfNeeded( serializedOptions, keyedBufferingBackend != null ? () -> Locker.locked(stateBackendLock) : null, keyedBufferingBackend != null - ? input -> FlinkKeyUtils.encodeKey(((KV) input).getKey(), (Coder) keyCoder) + ? 
input -> FlinkKey.of(((KV) input).getKey(), (Coder) keyCoder) : null, sdkHarnessRunner::emitResults); } @@ -797,7 +795,7 @@ protected DoFnRunner createWrappingDoFnRunner( windowCoder, inputCoder, this::setTimer, - () -> FlinkKeyUtils.decodeKey(getCurrentKey(), keyCoder), + () -> FlinkKeyUtils.decodeKey(getCurrentKey().getSerializedKey(), keyCoder), keyedStateInternals); return ensureStateDoFnRunner(sdkHarnessRunner, payload, stepContext); @@ -1116,7 +1114,7 @@ private DoFnRunner ensureStateDoFnRunner( .map(UserStateReference::localName) .collect(Collectors.toList()); - KeyedStateBackend stateBackend = getKeyedStateBackend(); + KeyedStateBackend stateBackend = getKeyedStateBackend(); StateCleaner stateCleaner = new StateCleaner( @@ -1159,7 +1157,7 @@ static class CleanupTimer implements StatefulDoFnRunner.CleanupTimer keyedStateBackend; + private final KeyedStateBackend keyedStateBackend; CleanupTimer( TimerInternals timerInternals, @@ -1167,7 +1165,7 @@ static class CleanupTimer implements StatefulDoFnRunner.CleanupTimer keyedStateBackend) { + KeyedStateBackend keyedStateBackend) { this.timerInternals = timerInternals; this.stateBackendLock = stateBackendLock; this.windowingStrategy = windowingStrategy; @@ -1186,7 +1184,7 @@ public void setForWindow(InputT input, BoundedWindow window) { return; } // needs to match the encoding in prepareStateBackend for state request handler - final ByteBuffer key = FlinkKeyUtils.encodeKey(((KV) input).getKey(), keyCoder); + final FlinkKey key = FlinkKey.of(((KV) input).getKey(), keyCoder); // Ensure the state backend is not concurrently accessed by the state requests try (Locker locker = Locker.locked(stateBackendLock)) { keyedStateBackend.setCurrentKey(key); @@ -1221,15 +1219,15 @@ static class StateCleaner implements StatefulDoFnRunner.StateCleaner userStateNames; private final Coder windowCoder; - private final ArrayDeque> cleanupQueue; - private final Supplier currentKeySupplier; + private final ArrayDeque> cleanupQueue; + 
private final Supplier currentKeySupplier; private final ThrowingFunction hasPendingEventTimeTimers; private final CleanupTimer cleanupTimer; StateCleaner( List userStateNames, Coder windowCoder, - Supplier currentKeySupplier, + Supplier currentKeySupplier, ThrowingFunction hasPendingEventTimeTimers, CleanupTimer cleanupTimer) { this.userStateNames = userStateNames; @@ -1247,11 +1245,10 @@ public void clearForWindow(BoundedWindow window) { cleanupQueue.add(KV.of(currentKeySupplier.get(), window)); } - @SuppressWarnings("ByteBufferBackingArray") - void cleanupState(StateInternals stateInternals, Consumer keyContextConsumer) + void cleanupState(StateInternals stateInternals, Consumer keyContextConsumer) throws Exception { while (!cleanupQueue.isEmpty()) { - KV kv = Preconditions.checkNotNull(cleanupQueue.remove()); + KV kv = Preconditions.checkNotNull(cleanupQueue.remove()); BoundedWindow window = Preconditions.checkNotNull(kv.getValue()); keyContextConsumer.accept(kv.getKey()); // Check whether we have pending timers which were set during the bundle. @@ -1260,7 +1257,10 @@ void cleanupState(StateInternals stateInternals, Consumer keyContext cleanupTimer.setCleanupTimer(window); } else { if (LOG.isDebugEnabled()) { - LOG.debug("State cleanup for {} {}", Arrays.toString(kv.getKey().array()), window); + LOG.debug( + "State cleanup for {} {}", + Arrays.toString(kv.getKey().getSerializedKey().array()), + window); } // No more timers (finally!). Time to clean up. 
for (String userState : userStateNames) { @@ -1280,14 +1280,15 @@ void cleanupState(StateInternals stateInternals, Consumer keyContext private static void initializeUserState( ExecutableStage executableStage, @Nullable KeyedStateBackend keyedStateBackend, - SerializablePipelineOptions pipelineOptions) { + SerializablePipelineOptions pipelineOptions, + Coder windowCoder) { executableStage .getUserStates() .forEach( ref -> { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + new FlinkStateInternals.FlinkStateNamespaceKeySerializer(windowCoder), new ListStateDescriptor<>( ref.localName(), new CoderTypeSerializer<>(ByteStringCoder.of(), pipelineOptions))); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToFlinkKeyKeySelector.java similarity index 62% rename from runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java rename to runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToFlinkKeyKeySelector.java index 204247b1d836..a852a724c040 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToByteBufferKeySelector.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/KvToFlinkKeyKeySelector.java @@ -17,40 +17,37 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming; -import java.nio.ByteBuffer; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import 
org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; /** * {@link KeySelector} that retrieves a key from a {@link KV}. This will return the key as encoded - * by the provided {@link Coder} in a {@link ByteBuffer}. This ensures that all key + * by the provided {@link Coder} in a {@link FlinkKey}. This ensures that all key * comparisons/hashing happen on the encoded form. */ -public class KvToByteBufferKeySelector - implements KeySelector>, ByteBuffer>, ResultTypeQueryable { +public class KvToFlinkKeyKeySelector + implements KeySelector>, FlinkKey>, ResultTypeQueryable { private final Coder keyCoder; - private final SerializablePipelineOptions pipelineOptions; - public KvToByteBufferKeySelector(Coder keyCoder, SerializablePipelineOptions pipelineOptions) { + public KvToFlinkKeyKeySelector(Coder keyCoder) { this.keyCoder = keyCoder; - this.pipelineOptions = pipelineOptions; } @Override - public ByteBuffer getKey(WindowedValue> value) { + public FlinkKey getKey(WindowedValue> value) { K key = value.getValue().getKey(); - return FlinkKeyUtils.encodeKey(key, keyCoder); + return FlinkKey.of(key, keyCoder); } @Override - public TypeInformation getProducedType() { - return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions); + public TypeInformation getProducedType() { + return ValueTypeInfo.of(FlinkKey.class); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java new file mode 100644 index 000000000000..03570143231b --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/PartialReduceBundleOperator.java @@ -0,0 
+1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.beam.runners.flink.translation.functions.AbstractFlinkCombineRunner; +import org.apache.beam.runners.flink.translation.functions.HashingFlinkCombineRunner; +import org.apache.beam.runners.flink.translation.functions.SortingFlinkCombineRunner; +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import 
org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.util.Collector; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class PartialReduceBundleOperator + extends DoFnOperator, KV, KV> { + + private final CombineFnBase.GlobalCombineFn combineFn; + + private Multimap>> state; + private transient @Nullable ListState>> checkpointedState; + + public PartialReduceBundleOperator( + CombineFnBase.GlobalCombineFn combineFn, + String stepName, + Coder>> windowedInputCoder, + TupleTag> mainOutputTag, + List> additionalOutputTags, + OutputManagerFactory> outputManagerFactory, + WindowingStrategy windowingStrategy, + Map> sideInputTagMapping, + Collection> sideInputs, + PipelineOptions options) { + super( + null, + stepName, + windowedInputCoder, + Collections.emptyMap(), + mainOutputTag, + additionalOutputTags, + outputManagerFactory, + windowingStrategy, + sideInputTagMapping, + sideInputs, + options, + null, + null, + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + this.combineFn = combineFn; + this.state = ArrayListMultimap.create(); + this.checkpointedState = null; + } + + @Override + public void open() throws Exception { + clearState(); + setBundleFinishedCallback(this::finishBundle); + super.open(); + } + + @Override + protected boolean shoudBundleElements() { + return true; + } + + private void finishBundle() { + AbstractFlinkCombineRunner reduceRunner; + try { + if (windowingStrategy.needsMerge() && windowingStrategy.getWindowFn() instanceof Sessions) { + reduceRunner = new 
SortingFlinkCombineRunner<>(); + } else { + reduceRunner = new HashingFlinkCombineRunner<>(); + } + + for (Map.Entry>>> e : state.asMap().entrySet()) { + //noinspection unchecked + reduceRunner.combine( + new AbstractFlinkCombineRunner.PartialFlinkCombiner<>(combineFn), + (WindowingStrategy) windowingStrategy, + sideInputReader, + serializedOptions.get(), + e.getValue(), + new Collector>>() { + @Override + public void collect(WindowedValue> record) { + outputManager.output(mainOutputTag, record); + } + + @Override + public void close() {} + }); + } + + } catch (Exception e) { + throw new RuntimeException(e); + } + clearState(); + } + + private void clearState() { + this.state = ArrayListMultimap.create(); + if (this.checkpointedState != null) { + this.checkpointedState.clear(); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + ListStateDescriptor>> descriptor = + new ListStateDescriptor<>( + "buffered-elements", new CoderTypeSerializer<>(windowedInputCoder, serializedOptions)); + + checkpointedState = context.getOperatorStateStore().getListState(descriptor); + + if (context.isRestored() && this.checkpointedState != null) { + for (WindowedValue> wkv : this.checkpointedState.get()) { + this.state.put(wkv.getValue().getKey(), wkv); + } + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + if (this.checkpointedState != null) { + this.checkpointedState.update(new ArrayList<>(this.state.values())); + } + } + + @Override + protected DoFn, KV> getDoFn() { + return new DoFn, KV>() { + @ProcessElement + public void processElement(ProcessContext c, BoundedWindow window) throws Exception { + WindowedValue> windowedValue = + WindowedValue.of(c.element(), c.timestamp(), window, c.pane()); + state.put(Objects.requireNonNull(c.element()).getKey(), windowedValue); + } + }; + } +} diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfFlinkKeyKeySelector.java similarity index 65% rename from runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java rename to runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfFlinkKeyKeySelector.java index 8c6f10abf448..b316726e74f8 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfFlinkKeyKeySelector.java @@ -17,45 +17,42 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming; -import java.nio.ByteBuffer; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; /** * {@link KeySelector} that retrieves a key from a {@code KV>, size>}. This will return the element as encoded by the provided {@link Coder} - * in a {@link ByteBuffer}. This ensures that all key comparisons/hashing happen on the encoded - * form. Note that the reason we don't use the whole {@code KV>, Double>} as the key is when checkpoint happens, we will get different * restriction/watermarkState/size, which Flink treats as a new key. 
Using new key to set state and * timer may cause defined behavior. */ -public class SdfByteBufferKeySelector - implements KeySelector, Double>>, ByteBuffer>, - ResultTypeQueryable { +public class SdfFlinkKeyKeySelector + implements KeySelector, Double>>, FlinkKey>, + ResultTypeQueryable { private final Coder keyCoder; - private final SerializablePipelineOptions pipelineOptions; - public SdfByteBufferKeySelector(Coder keyCoder, SerializablePipelineOptions pipelineOptions) { + public SdfFlinkKeyKeySelector(Coder keyCoder) { this.keyCoder = keyCoder; - this.pipelineOptions = pipelineOptions; } @Override - public ByteBuffer getKey(WindowedValue, Double>> value) { + public FlinkKey getKey(WindowedValue, Double>> value) { K key = value.getValue().getKey().getKey(); - return FlinkKeyUtils.encodeKey(key, keyCoder); + return FlinkKey.of(key, keyCoder); } @Override - public TypeInformation getProducedType() { - return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions); + public TypeInformation getProducedType() { + return ValueTypeInfo.of(FlinkKey.class); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java index 8eae5be177a5..d80dd60a5925 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java @@ -65,7 +65,10 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class SplittableDoFnOperator - extends DoFnOperator>, OutputT> { + extends DoFnOperator< + KeyedWorkItem>, + KeyedWorkItem>, + OutputT> { private static final Logger LOG = LoggerFactory.getLogger(SplittableDoFnOperator.class); diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java index d8f4885ea057..60b20f375f22 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java @@ -19,6 +19,7 @@ import static org.apache.beam.runners.core.TimerInternals.TimerData; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -50,7 +51,7 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class WindowDoFnOperator - extends DoFnOperator, KV> { + extends DoFnOperator, KeyedWorkItem, KV> { private final SystemReduceFn systemReduceFn; @@ -87,6 +88,25 @@ public WindowDoFnOperator( this.systemReduceFn = systemReduceFn; } + @Override + protected Iterable>> preProcess( + WindowedValue> inWithMultipleWindows) { + // we need to wrap each one work item per window for now + // since otherwise the PushbackSideInputRunner will not correctly + // determine whether side inputs are ready + // + // this is tracked as https://github.com/apache/beam/issues/18358 + ArrayList>> inputs = new ArrayList<>(); + for (WindowedValue> in : inWithMultipleWindows.explodeWindows()) { + SingletonKeyedWorkItem workItem = + new SingletonKeyedWorkItem<>( + in.getValue().getKey(), in.withValue(in.getValue().getValue())); + + inputs.add(in.withValue(workItem)); + } + return inputs; + } + @Override protected DoFnRunner, KV> createWrappingDoFnRunner( DoFnRunner, KV> wrappedRunner, StepContext stepContext) { diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java index 64ea6ca26d4d..d809f4287983 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WorkItemKeySelector.java @@ -19,13 +19,13 @@ import java.nio.ByteBuffer; import org.apache.beam.runners.core.KeyedWorkItem; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.util.WindowedValue; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; /** * {@link KeySelector} that retrieves a key from a {@link KeyedWorkItem}. This will return the key @@ -33,25 +33,23 @@ * comparisons/hashing happen on the encoded form. 
*/ public class WorkItemKeySelector - implements KeySelector>, ByteBuffer>, - ResultTypeQueryable { + implements KeySelector>, FlinkKey>, + ResultTypeQueryable { private final Coder keyCoder; - private final SerializablePipelineOptions pipelineOptions; - public WorkItemKeySelector(Coder keyCoder, SerializablePipelineOptions pipelineOptions) { + public WorkItemKeySelector(Coder keyCoder) { this.keyCoder = keyCoder; - this.pipelineOptions = pipelineOptions; } @Override - public ByteBuffer getKey(WindowedValue> value) throws Exception { + public FlinkKey getKey(WindowedValue> value) throws Exception { K key = value.getValue().key(); - return FlinkKeyUtils.encodeKey(key, keyCoder); + return FlinkKey.of(key, keyCoder); } @Override - public TypeInformation getProducedType() { - return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions); + public TypeInformation getProducedType() { + return ValueTypeInfo.of(FlinkKey.class); } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/DedupingOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/DedupingOperator.java index d43723964844..9d238aa36110 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/DedupingOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/DedupingOperator.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.io; -import java.nio.ByteBuffer; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; @@ -43,7 +43,7 @@ }) public class DedupingOperator extends AbstractStreamOperator> implements 
OneInputStreamOperator>, WindowedValue>, - Triggerable { + Triggerable { private static final long MAX_RETENTION_SINCE_ACCESS = Duration.standardMinutes(10L).getMillis(); private final SerializablePipelineOptions options; @@ -94,12 +94,12 @@ public void processElement(StreamRecord>> str } @Override - public void onEventTime(InternalTimer internalTimer) { + public void onEventTime(InternalTimer internalTimer) { // will never happen } @Override - public void onProcessingTime(InternalTimer internalTimer) + public void onProcessingTime(InternalTimer internalTimer) throws Exception { ValueState dedupingState = getPartitionedState(dedupingStateDescriptor); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java index 506b651da68f..74eba2491d3d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSource.java @@ -118,8 +118,14 @@ public Boundedness getBoundedness() { @Override public SplitEnumerator, Map>>> createEnumerator(SplitEnumeratorContext> enumContext) throws Exception { - return new FlinkSourceSplitEnumerator<>( - enumContext, beamSource, serializablePipelineOptions.get(), numSplits); + + if (boundedness == Boundedness.BOUNDED) { + return new LazyFlinkSourceSplitEnumerator<>( + enumContext, beamSource, serializablePipelineOptions.get(), numSplits); + } else { + return new FlinkSourceSplitEnumerator<>( + enumContext, beamSource, serializablePipelineOptions.get(), numSplits); + } } @Override @@ -128,9 +134,8 @@ public Boundedness getBoundedness() { SplitEnumeratorContext> enumContext, Map>> checkpoint) throws Exception { - FlinkSourceSplitEnumerator enumerator = - new 
FlinkSourceSplitEnumerator<>( - enumContext, beamSource, serializablePipelineOptions.get(), numSplits); + SplitEnumerator, Map>>> enumerator = + createEnumerator(enumContext); checkpoint.forEach( (subtaskId, splitsForSubtask) -> enumerator.addSplitsBack(splitsForSubtask, subtaskId)); return enumerator; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java new file mode 100644 index 000000000000..b7046cff2cb2 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/LazyFlinkSourceSplitEnumerator.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io.source; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.FileBasedSource; +import org.apache.beam.sdk.io.Source; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.flink.api.connector.source.SplitEnumerator; +import org.apache.flink.api.connector.source.SplitEnumeratorContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Flink {@link org.apache.flink.api.connector.source.SplitEnumerator SplitEnumerator} + * implementation that holds a Beam {@link Source} and does the following: + * + *
<ul>
+ *   <li>Split the Beam {@link Source} to desired number of splits.
+ *   <li>Lazily assign the splits to the Flink Source Reader.
+ * </ul>
+ *
+ * @param The output type of the encapsulated Beam {@link Source}.
+ */
+public class LazyFlinkSourceSplitEnumerator
+    implements SplitEnumerator, Map>>> {
+  private static final Logger LOG = LoggerFactory.getLogger(LazyFlinkSourceSplitEnumerator.class);
+  private final SplitEnumeratorContext> context;
+  private final Source beamSource;
+  private final PipelineOptions pipelineOptions;
+  private final int numSplits;
+  private final List> pendingSplits;
+
+  public LazyFlinkSourceSplitEnumerator(
+      SplitEnumeratorContext> context,
+      Source beamSource,
+      PipelineOptions pipelineOptions,
+      int numSplits) {
+    this.context = context;
+    this.beamSource = beamSource;
+    this.pipelineOptions = pipelineOptions;
+    this.numSplits = numSplits;
+    this.pendingSplits = new ArrayList<>(numSplits);
+  }
+
+  /**
+   * Splits the Beam source through {@code context.callAsync} and buffers the resulting splits in
+   * {@code pendingSplits}, from which {@link #handleSplitRequest} hands them out one at a time.
+   */
+  @Override
+  public void start() {
+    context.callAsync(
+        () -> {
+          try {
+            LOG.info("Starting source {}", beamSource);
+            List> beamSplitSourceList = splitBeamSource();
+            int i = 0;
+            for (Source beamSplitSource : beamSplitSourceList) {
+              pendingSplits.add(new FlinkSourceSplit<>(i, beamSplitSource));
+              i++;
+            }
+            return pendingSplits;
+          } catch (Exception e) {
+            throw new RuntimeException(e);
+          }
+        },
+        (sourceSplits, error) -> {
+          // On success sourceSplits is pendingSplits itself, already filled by the callable above,
+          // so adding it again would enqueue every split twice. When the callable failed there is
+          // no result to use at all; only the error has to be surfaced here.
+          if (error != null) {
+            throw new RuntimeException("Failed to start source enumerator.", error);
+          }
+        });
+  }
+
+  @Override
+  public void handleSplitRequest(int subtask, @Nullable String hostname) {
+    if (!context.registeredReaders().containsKey(subtask)) {
+      // reader failed between sending the request and now. skip this request.
+      return;
+    }
+
+    if (LOG.isInfoEnabled()) {
+      final String hostInfo =
+          hostname == null ?
"(no host locality info)" : "(on host '" + hostname + "')"; + LOG.info("Subtask {} {} is requesting a file source split", subtask, hostInfo); + } + + if (!pendingSplits.isEmpty()) { + final FlinkSourceSplit split = pendingSplits.remove(pendingSplits.size() - 1); + context.assignSplit(split, subtask); + LOG.info("Assigned split to subtask {} : {}", subtask, split); + } else { + context.signalNoMoreSplits(subtask); + LOG.info("No more splits available for subtask {}", subtask); + } + } + + @Override + public void addSplitsBack(List> splits, int subtaskId) { + LOG.info("Adding splits {} back from subtask {}", splits, subtaskId); + pendingSplits.addAll(splits); + } + + @Override + public void addReader(int subtaskId) { + // this source is purely lazy-pull-based, nothing to do upon registration + } + + @Override + public Map>> snapshotState(long checkpointId) throws Exception { + LOG.info("Taking snapshot for checkpoint {}", checkpointId); + return snapshotState(); + } + + public Map>> snapshotState() throws Exception { + // For type compatibility reasons, we return a Map but we do not actually care about the key + Map>> state = new HashMap<>(1); + state.put(1, pendingSplits); + return state; + } + + @Override + public void close() throws IOException { + // NoOp + } + + private long getDesiredSizeBytes(int numSplits, BoundedSource boundedSource) throws Exception { + long totalSize = boundedSource.getEstimatedSizeBytes(pipelineOptions); + long defaultSplitSize = totalSize / numSplits; + long maxSplitSize = 0; + if (pipelineOptions != null) { + maxSplitSize = pipelineOptions.as(FlinkPipelineOptions.class).getFileInputSplitMaxSizeMB(); + } + if (beamSource instanceof FileBasedSource && maxSplitSize > 0) { + // Most of the time parallelism is < number of files in source. + // Each file becomes a unique split which commonly create skew. + // This limits the size of splits to reduce skew. 
+ return Math.min(defaultSplitSize, maxSplitSize * 1024 * 1024); + } else { + return defaultSplitSize; + } + } + + // -------------- Private helper methods ---------------------- + private List> splitBeamSource() throws Exception { + if (beamSource instanceof BoundedSource) { + BoundedSource boundedSource = (BoundedSource) beamSource; + long desiredSizeBytes = getDesiredSizeBytes(numSplits, boundedSource); + List> splits = + ((BoundedSource) beamSource).split(desiredSizeBytes, pipelineOptions); + LOG.info("Split bounded source {} in {} splits", beamSource, splits.size()); + return splits; + } else if (beamSource instanceof UnboundedSource) { + List> splits = + ((UnboundedSource) beamSource).split(numSplits, pipelineOptions); + LOG.info("Split source {} to {} splits", beamSource, splits); + return splits; + } else { + throw new IllegalStateException("Unknown source type " + beamSource.getClass()); + } + } +} diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java index e4bd4496ae90..6b23dd13c9b8 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java @@ -100,6 +100,11 @@ protected FlinkBoundedSourceReader( @Override public InputStatus pollNext(ReaderOutput> output) throws Exception { checkExceptionAndMaybeThrow(); + + if (currentReader == null && currentSplitId == -1) { + context.sendSplitRequest(); + } + if (currentReader == null && !moveToNextNonEmptyReader()) { // Nothing to read for now. 
if (noMoreSplits()) { @@ -137,6 +142,7 @@ public InputStatus pollNext(ReaderOutput> output) throws Except LOG.debug("Finished reading from {}", currentSplitId); currentReader = null; currentSplitId = -1; + context.sendSplitRequest(); } // Always return MORE_AVAILABLE here regardless of the availability of next record. If there // is no more diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 205270c22332..47390428d4bd 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -17,7 +17,7 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.state; -import java.nio.ByteBuffer; +import java.io.IOException; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; @@ -33,6 +33,8 @@ import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; import org.apache.beam.sdk.coders.Coder; @@ -55,6 +57,7 @@ import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import 
org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CombineContextFactory; @@ -74,8 +77,13 @@ import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.JavaSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; @@ -89,7 +97,7 @@ * {@link StateInternals} that uses a Flink {@link KeyedStateBackend} to manage state. * *

Note: In the Flink streaming runner the key is always encoded using an {@link Coder} and - * stored in a {@link ByteBuffer}. + * stored in a {@link FlinkKey}. */ @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) @@ -100,8 +108,9 @@ public class FlinkStateInternals implements StateInternals { private static final StateNamespace globalWindowNamespace = StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE); - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; private final Coder keyCoder; + FlinkStateNamespaceKeySerializer namespaceKeySerializer; private static class StateAndNamespaceDescriptor { static StateAndNamespaceDescriptor of( @@ -162,22 +171,24 @@ public String toString() { // State to persist combined watermark holds for all keys of this partition private final MapStateDescriptor watermarkHoldStateDescriptor; - private final SerializablePipelineOptions pipelineOptions; + private final boolean fasterCopy; public FlinkStateInternals( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, Coder keyCoder, + Coder windowCoder, SerializablePipelineOptions pipelineOptions) throws Exception { this.flinkStateBackend = Objects.requireNonNull(flinkStateBackend); this.keyCoder = Objects.requireNonNull(keyCoder); + this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + this.namespaceKeySerializer = new FlinkStateNamespaceKeySerializer(windowCoder); + watermarkHoldStateDescriptor = new MapStateDescriptor<>( "watermark-holds", StringSerializer.INSTANCE, - new CoderTypeSerializer<>(InstantCoder.of(), pipelineOptions)); - this.pipelineOptions = pipelineOptions; - + new CoderTypeSerializer<>(InstantCoder.of(), fasterCopy)); restoreWatermarkHoldsView(); } @@ -192,8 +203,8 @@ public Long minWatermarkHoldMs() { @Override public K getKey() { - ByteBuffer keyBytes = flinkStateBackend.getCurrentKey(); - return 
FlinkKeyUtils.decodeKey(keyBytes, keyCoder); + FlinkKey keyBytes = flinkStateBackend.getCurrentKey(); + return FlinkKeyUtils.decodeKey(keyBytes.getSerializedKey(), keyCoder); } @Override @@ -241,29 +252,30 @@ private FlinkStateBinder(StateNamespace namespace, StateContext stateContext) public ValueState bindValue( String id, StateSpec> spec, Coder coder) { FlinkValueState valueState = - new FlinkValueState<>(flinkStateBackend, id, namespace, coder, pipelineOptions); + new FlinkValueState<>( + flinkStateBackend, id, namespace, coder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - valueState.flinkStateDescriptor, - valueState.namespace.stringKey(), - StringSerializer.INSTANCE); + valueState.flinkStateDescriptor, valueState.namespace, namespaceKeySerializer); return valueState; } @Override public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { FlinkBagState bagState = - new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder, pipelineOptions); + new FlinkBagState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - bagState.flinkStateDescriptor, bagState.namespace.stringKey(), StringSerializer.INSTANCE); + bagState.flinkStateDescriptor, bagState.namespace, namespaceKeySerializer); return bagState; } @Override public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { FlinkSetState setState = - new FlinkSetState<>(flinkStateBackend, id, namespace, elemCoder, pipelineOptions); + new FlinkSetState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( - setState.flinkStateDescriptor, setState.namespace.stringKey(), StringSerializer.INSTANCE); + setState.flinkStateDescriptor, setState.namespace, namespaceKeySerializer); return setState; } @@ -275,9 +287,15 @@ public MapState bindMap( Coder mapValueCoder) { FlinkMapState mapState = new FlinkMapState<>( - flinkStateBackend, 
id, namespace, mapKeyCoder, mapValueCoder, pipelineOptions); + flinkStateBackend, + id, + namespace, + mapKeyCoder, + mapValueCoder, + namespaceKeySerializer, + fasterCopy); collectGlobalWindowStateDescriptor( - mapState.flinkStateDescriptor, mapState.namespace.stringKey(), StringSerializer.INSTANCE); + mapState.flinkStateDescriptor, mapState.namespace, namespaceKeySerializer); return mapState; } @@ -285,11 +303,12 @@ public MapState bindMap( public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { FlinkOrderedListState flinkOrderedListState = - new FlinkOrderedListState<>(flinkStateBackend, id, namespace, elemCoder, pipelineOptions); + new FlinkOrderedListState<>( + flinkStateBackend, id, namespace, elemCoder, namespaceKeySerializer, fasterCopy); collectGlobalWindowStateDescriptor( flinkOrderedListState.flinkStateDescriptor, - flinkOrderedListState.namespace.stringKey(), - StringSerializer.INSTANCE); + flinkOrderedListState.namespace, + namespaceKeySerializer); return flinkOrderedListState; } @@ -311,11 +330,15 @@ public CombiningState bindCom Combine.CombineFn combineFn) { FlinkCombiningState combiningState = new FlinkCombiningState<>( - flinkStateBackend, id, combineFn, namespace, accumCoder, pipelineOptions); + flinkStateBackend, + id, + combineFn, + namespace, + accumCoder, + namespaceKeySerializer, + fasterCopy); collectGlobalWindowStateDescriptor( - combiningState.flinkStateDescriptor, - combiningState.namespace.stringKey(), - StringSerializer.INSTANCE); + combiningState.flinkStateDescriptor, combiningState.namespace, namespaceKeySerializer); return combiningState; } @@ -333,12 +356,13 @@ CombiningState bindCombiningWithContext( combineFn, namespace, accumCoder, + namespaceKeySerializer, CombineContextFactory.createFromStateContext(stateContext), - pipelineOptions); + fasterCopy); collectGlobalWindowStateDescriptor( combiningStateWithContext.flinkStateDescriptor, - combiningStateWithContext.namespace.stringKey(), - 
StringSerializer.INSTANCE); + combiningStateWithContext.namespace, + namespaceKeySerializer); return combiningStateWithContext; } @@ -368,34 +392,156 @@ private void collectGlobalWindowStateDescriptor( } } + public static class FlinkStateNamespaceKeySerializer extends TypeSerializer { + + public Coder getCoder() { + return coder; + } + + private final Coder coder; + + public FlinkStateNamespaceKeySerializer(Coder coder) { + this.coder = coder; + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return this; + } + + @Override + public StateNamespace createInstance() { + return null; + } + + @Override + public StateNamespace copy(StateNamespace from) { + return from; + } + + @Override + public StateNamespace copy(StateNamespace from, StateNamespace reuse) { + return from; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(StateNamespace record, DataOutputView target) throws IOException { + StringSerializer.INSTANCE.serialize(record.stringKey(), target); + } + + @Override + public StateNamespace deserialize(DataInputView source) throws IOException { + return StateNamespaces.fromString(StringSerializer.INSTANCE.deserialize(source), coder); + } + + @Override + public StateNamespace deserialize(StateNamespace reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + throw new UnsupportedOperationException("copy is not supported for FlinkStateNamespace key"); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof FlinkStateNamespaceKeySerializer; + } + + @Override + public int hashCode() { + return Objects.hashCode(getClass()); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new FlinkStateNameSpaceSerializerSnapshot(this); + } + + /** Serializer configuration 
snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class FlinkStateNameSpaceSerializerSnapshot + implements TypeSerializerSnapshot { + + @Nullable private Coder windowCoder; + + public FlinkStateNameSpaceSerializerSnapshot() {} + + FlinkStateNameSpaceSerializerSnapshot(FlinkStateNamespaceKeySerializer ser) { + this.windowCoder = ser.getCoder(); + } + + @Override + public int getCurrentVersion() { + return 0; + } + + @Override + public void writeSnapshot(DataOutputView out) throws IOException { + new JavaSerializer>().serialize(windowCoder, out); + } + + @Override + public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) + throws IOException { + this.windowCoder = new JavaSerializer>().deserialize(in); + } + + @Override + public TypeSerializer restoreSerializer() { + return new FlinkStateNamespaceKeySerializer(windowCoder); + } + + @Override + public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( + TypeSerializer newSerializer) { + return TypeSerializerSchemaCompatibility.compatibleAsIs(); + } + } + } + private static class FlinkValueState implements ValueState { private final StateNamespace namespace; private final String stateId; private final ValueStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkValueState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; + this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = - new ValueStateDescriptor<>(stateId, new 
CoderTypeSerializer<>(coder, pipelineOptions)); + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); } @Override public void write(T input) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .update(input); } catch (Exception e) { throw new RuntimeException("Error updating state.", e); @@ -411,8 +557,7 @@ public ValueState readLater() { public T read() { try { return flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value(); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -423,8 +568,7 @@ public T read() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -456,19 +600,22 @@ public int hashCode() { private static class FlinkOrderedListState implements OrderedListState { private final StateNamespace namespace; private final ListStateDescriptor> flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkOrderedListState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { this.namespace = namespace; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new ListStateDescriptor<>( - stateId, 
new CoderTypeSerializer<>(TimestampedValueCoder.of(coder), pipelineOptions)); + stateId, new CoderTypeSerializer<>(TimestampedValueCoder.of(coder), fasterCopy)); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -483,7 +630,7 @@ public void clearRange(Instant minTimestamp, Instant limitTimestamp) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); partitionedState.update(Lists.newArrayList(sortedMap.values())); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -500,7 +647,7 @@ public void add(TimestampedValue value) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); partitionedState.add(value); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -515,8 +662,7 @@ public Boolean read() { try { Iterable> result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -542,7 +688,7 @@ private SortedMap> readAsMap() { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); listValues = MoreObjects.firstNonNull(partitionedState.get(), Collections.emptyList()); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -564,8 +710,7 @@ public GroupingState, Iterable>> readLat public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, 
flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -578,22 +723,25 @@ private static class FlinkBagState implements BagState { private final StateNamespace namespace; private final String stateId; private final ListStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; private final boolean storesVoidValues; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkBagState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.storesVoidValues = coder instanceof VoidCoder; this.flinkStateDescriptor = - new ListStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, pipelineOptions)); + new ListStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder, fasterCopy)); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -601,7 +749,7 @@ public void add(T input) { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); if (storesVoidValues) { Preconditions.checkState(input == null, "Expected to a null value but was: %s", input); // Flink does not allow storing null values @@ -625,7 +773,7 @@ public Iterable read() { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); Iterable result = partitionedState.get(); 
if (storesVoidValues) { return () -> { @@ -661,8 +809,7 @@ public Boolean read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -681,8 +828,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -718,24 +864,26 @@ private static class FlinkCombiningState private final String stateId; private final Combine.CombineFn combineFn; private final ValueStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkCombiningState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, Combine.CombineFn combineFn, StateNamespace namespace, Coder accumCoder, - SerializablePipelineOptions pipelineOptions) { + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; + this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = - new ValueStateDescriptor<>( - stateId, new CoderTypeSerializer<>(accumCoder, pipelineOptions)); + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); } @Override @@ -748,7 +896,7 @@ public void add(InputT value) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), 
StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -766,7 +914,7 @@ public void addAccum(AccumT accum) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -785,8 +933,7 @@ public AccumT getAccum() { try { AccumT accum = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(); } catch (Exception e) { @@ -804,7 +951,7 @@ public OutputT read() { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT accum = state.value(); if (accum != null) { @@ -824,8 +971,7 @@ public ReadableState isEmpty() { public Boolean read() { try { return flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -844,8 +990,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -881,27 +1026,29 @@ private static class FlinkCombiningStateWithContext private final String stateId; private final 
CombineWithContext.CombineFnWithContext combineFn; private final ValueStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; private final CombineWithContext.Context context; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkCombiningStateWithContext( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, CombineWithContext.CombineFnWithContext combineFn, StateNamespace namespace, Coder accumCoder, + FlinkStateNamespaceKeySerializer namespaceSerializer, CombineWithContext.Context context, - SerializablePipelineOptions pipelineOptions) { + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; this.context = context; + this.namespaceSerializer = namespaceSerializer; flinkStateDescriptor = - new ValueStateDescriptor<>( - stateId, new CoderTypeSerializer<>(accumCoder, pipelineOptions)); + new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(accumCoder, fasterCopy)); } @Override @@ -914,7 +1061,7 @@ public void add(InputT value) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -932,7 +1079,7 @@ public void addAccum(AccumT accum) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -951,8 +1098,7 @@ public AccumT getAccum() { try { AccumT accum = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, 
flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(context); } catch (Exception e) { @@ -970,7 +1116,7 @@ public OutputT read() { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); AccumT accum = state.value(); if (accum != null) { @@ -990,8 +1136,7 @@ public ReadableState isEmpty() { public Boolean read() { try { return flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -1010,8 +1155,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1048,7 +1192,7 @@ private class FlinkWatermarkHoldState implements WatermarkHoldState { private org.apache.flink.api.common.state.MapState watermarkHoldsState; public FlinkWatermarkHoldState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, MapStateDescriptor watermarkHoldStateDescriptor, String stateId, StateNamespace namespace, @@ -1170,23 +1314,26 @@ private static class FlinkMapState implements MapState flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkMapState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, 
Coder mapKeyCoder, Coder mapValueCoder, - SerializablePipelineOptions pipelineOptions) { + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new MapStateDescriptor<>( stateId, - new CoderTypeSerializer<>(mapKeyCoder, pipelineOptions), - new CoderTypeSerializer<>(mapValueCoder, pipelineOptions)); + new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), + new CoderTypeSerializer<>(mapValueCoder, fasterCopy)); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -1203,8 +1350,7 @@ public ReadableState get(final KeyT input) { try { ValueT value = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(key); return (value != null) ? value : defaultValue; } catch (Exception e) { @@ -1223,8 +1369,7 @@ public ReadableState get(final KeyT input) { public void put(KeyT key, ValueT value) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .put(key, value); } catch (Exception e) { throw new RuntimeException("Error put kv to state.", e); @@ -1237,14 +1382,12 @@ public ReadableState computeIfAbsent( try { ValueT current = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(key); if (current == null) { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .put(key, mappingFunction.apply(key)); } return ReadableStates.immediate(current); @@ -1257,8 +1400,7 
@@ public ReadableState computeIfAbsent( public void remove(KeyT key) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .remove(key); } catch (Exception e) { throw new RuntimeException("Error remove map state key.", e); @@ -1273,8 +1415,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1297,8 +1438,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .values(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1321,8 +1461,7 @@ public Iterable> read() { try { Iterable> result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .entries(); return result != null ? 
ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1360,8 +1499,7 @@ public ReadableState>> readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1395,22 +1533,23 @@ private static class FlinkSetState implements SetState { private final StateNamespace namespace; private final String stateId; private final MapStateDescriptor flinkStateDescriptor; - private final KeyedStateBackend flinkStateBackend; + private final KeyedStateBackend flinkStateBackend; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; FlinkSetState( - KeyedStateBackend flinkStateBackend, + KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, Coder coder, - SerializablePipelineOptions pipelineOptions) { + FlinkStateNamespaceKeySerializer namespaceSerializer, + boolean fasterCopy) { this.namespace = namespace; this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new MapStateDescriptor<>( - stateId, - new CoderTypeSerializer<>(coder, pipelineOptions), - BooleanSerializer.INSTANCE); + stateId, new CoderTypeSerializer<>(coder, fasterCopy), BooleanSerializer.INSTANCE); + this.namespaceSerializer = namespaceSerializer; } @Override @@ -1418,8 +1557,7 @@ public ReadableState contains(final T t) { try { Boolean result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .get(t); return ReadableStates.immediate(result != null && result); } catch (Exception e) { @@ -1432,7 +1570,7 @@ public ReadableState addIfAbsent(final T t) { try { org.apache.flink.api.common.state.MapState state = 
flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespace, namespaceSerializer, flinkStateDescriptor); boolean alreadyContained = state.contains(t); if (!alreadyContained) { state.put(t, true); @@ -1447,8 +1585,7 @@ public ReadableState addIfAbsent(final T t) { public void remove(T t) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .remove(t); } catch (Exception e) { throw new RuntimeException("Error remove value to state.", e); @@ -1464,8 +1601,7 @@ public SetState readLater() { public void add(T value) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .put(value, true); } catch (Exception e) { throw new RuntimeException("Error add value to state.", e); @@ -1480,8 +1616,7 @@ public Boolean read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result == null || Iterables.isEmpty(result); } catch (Exception e) { @@ -1501,8 +1636,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .keys(); return result != null ? 
ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1514,8 +1648,7 @@ public Iterable read() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespace, namespaceSerializer, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1557,9 +1690,9 @@ private void restoreWatermarkHoldsView() throws Exception { org.apache.flink.api.common.state.MapState mapState = flinkStateBackend.getPartitionedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, watermarkHoldStateDescriptor); - try (Stream keys = + try (Stream keys = flinkStateBackend.getKeys(watermarkHoldStateDescriptor.getName(), VoidNamespace.INSTANCE)) { - Iterator iterator = keys.iterator(); + Iterator iterator = keys.iterator(); while (iterator.hasNext()) { flinkStateBackend.setCurrentKey(iterator.next()); mapState.values().forEach(this::addWatermarkHoldUsage); @@ -1571,20 +1704,24 @@ private void restoreWatermarkHoldsView() throws Exception { public static class EarlyBinder implements StateBinder { private final KeyedStateBackend keyedStateBackend; - private final SerializablePipelineOptions pipelineOptions; + private final Boolean fasterCopy; + private final FlinkStateNamespaceKeySerializer namespaceSerializer; public EarlyBinder( - KeyedStateBackend keyedStateBackend, SerializablePipelineOptions pipelineOptions) { + KeyedStateBackend keyedStateBackend, + SerializablePipelineOptions pipelineOptions, + Coder windowCoder) { this.keyedStateBackend = keyedStateBackend; - this.pipelineOptions = pipelineOptions; + this.fasterCopy = pipelineOptions.get().as(FlinkPipelineOptions.class).getFasterCopy(); + this.namespaceSerializer = new FlinkStateNamespaceKeySerializer(windowCoder); } @Override public ValueState bindValue(String id, StateSpec> spec, Coder coder) { try { 
keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, - new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, pipelineOptions))); + namespaceSerializer, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1596,8 +1733,8 @@ public ValueState bindValue(String id, StateSpec> spec, Cod public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, - new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, pipelineOptions))); + namespaceSerializer, + new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1609,11 +1746,9 @@ public BagState bindBag(String id, StateSpec> spec, Coder public SetState bindSet(String id, StateSpec> spec, Coder elemCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new MapStateDescriptor<>( - id, - new CoderTypeSerializer<>(elemCoder, pipelineOptions), - BooleanSerializer.INSTANCE)); + id, new CoderTypeSerializer<>(elemCoder, fasterCopy), BooleanSerializer.INSTANCE)); } catch (Exception e) { throw new RuntimeException(e); } @@ -1628,11 +1763,11 @@ public org.apache.beam.sdk.state.MapState bindMap( Coder mapValueCoder) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, + namespaceSerializer, new MapStateDescriptor<>( id, - new CoderTypeSerializer<>(mapKeyCoder, pipelineOptions), - new CoderTypeSerializer<>(mapValueCoder, pipelineOptions))); + new CoderTypeSerializer<>(mapKeyCoder, fasterCopy), + new CoderTypeSerializer<>(mapValueCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1644,10 +1779,9 @@ public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { try { keyedStateBackend.getOrCreateKeyedState( - 
StringSerializer.INSTANCE, + namespaceSerializer, new ListStateDescriptor<>( - id, - new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), pipelineOptions))); + id, new CoderTypeSerializer<>(TimestampedValueCoder.of(elemCoder), fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1673,8 +1807,8 @@ public CombiningState bindCom Combine.CombineFn combineFn) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, - new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, pipelineOptions))); + namespaceSerializer, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1690,8 +1824,8 @@ CombiningState bindCombiningWithContext( CombineWithContext.CombineFnWithContext combineFn) { try { keyedStateBackend.getOrCreateKeyedState( - StringSerializer.INSTANCE, - new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, pipelineOptions))); + namespaceSerializer, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } @@ -1707,7 +1841,7 @@ public WatermarkHoldState bindWatermark( new MapStateDescriptor<>( "watermark-holds", StringSerializer.INSTANCE, - new CoderTypeSerializer<>(InstantCoder.of(), pipelineOptions))); + new CoderTypeSerializer<>(InstantCoder.of(), fasterCopy))); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java index c20bd077c3f2..5d08beb938fd 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java @@ -99,7 +99,7 @@ public void testDefaults() { assertThat(options.getFasterCopy(), 
is(false)); assertThat(options.isStreaming(), is(false)); - assertThat(options.getMaxBundleSize(), is(1000000L)); + assertThat(options.getMaxBundleSize(), is(5000L)); assertThat(options.getMaxBundleTimeMills(), is(10000L)); // In streaming mode bundle size and bundle time are shorter @@ -139,7 +139,7 @@ public void parDoBaseClassPipelineOptionsSerializationTest() throws Exception { TupleTag mainTag = new TupleTag<>("main-output"); Coder> coder = WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new TestDoFn(), "stepName", @@ -161,7 +161,7 @@ mainTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults()) final byte[] serialized = SerializationUtils.serialize(doFnOperator); @SuppressWarnings("unchecked") - DoFnOperator deserialized = SerializationUtils.deserialize(serialized); + DoFnOperator deserialized = SerializationUtils.deserialize(serialized); TypeInformation> typeInformation = TypeInformation.of(new TypeHint>() {}); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java index cf860717def3..8e4c3255fac5 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkSubmissionTest.java @@ -26,6 +26,7 @@ import java.util.Collection; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.options.PipelineOptions; @@ -49,6 +50,8 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.rules.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** End-to-end submission test of Beam jobs on a Flink cluster. 
*/ @SuppressWarnings({ @@ -56,6 +59,8 @@ }) public class FlinkSubmissionTest { + private static final Logger LOG = LoggerFactory.getLogger(FlinkSubmissionTest.class); + @ClassRule public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder(); private static final Map ENV = System.getenv(); private static final SecurityManager SECURITY_MANAGER = System.getSecurityManager(); @@ -66,14 +71,9 @@ public class FlinkSubmissionTest { /** Each test has a timeout of 60 seconds (for safety). */ @Rule public Timeout timeout = new Timeout(60, TimeUnit.SECONDS); - /** Whether to run in streaming or batch translation mode. */ - private static boolean streaming; - /** Counter which keeps track of the number of jobs submitted. */ private static int expectedNumberOfJobs; - public static boolean useDataStreamForBatch; - @BeforeClass public static void beforeClass() throws Exception { Configuration config = new Configuration(); @@ -103,38 +103,38 @@ public static void afterClass() throws Exception { @Test public void testSubmissionBatch() throws Exception { - runSubmission(false, false); + runSubmission(false, false, false); } @Test public void testSubmissionBatchUseDataStream() throws Exception { - FlinkSubmissionTest.useDataStreamForBatch = true; - runSubmission(false, false); + runSubmission(false, false, true); } @Test public void testSubmissionStreaming() throws Exception { - runSubmission(false, true); + runSubmission(false, true, false); } @Test public void testDetachedSubmissionBatch() throws Exception { - runSubmission(true, false); + runSubmission(true, false, false); } @Test public void testDetachedSubmissionBatchUseDataStream() throws Exception { - FlinkSubmissionTest.useDataStreamForBatch = true; - runSubmission(true, false); + runSubmission(true, false, true); } @Test public void testDetachedSubmissionStreaming() throws Exception { - runSubmission(true, true); + runSubmission(true, true, false); } - private void runSubmission(boolean isDetached, boolean 
isStreaming) throws Exception { + private void runSubmission(boolean isDetached, boolean isStreaming, boolean useDataStreamForBatch) + throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); + options.as(FlinkPipelineOptions.class).setStreaming(isStreaming); options.setTempLocation(TEMP_FOLDER.getRoot().getPath()); String jarPath = Iterables.getFirst( @@ -149,8 +149,16 @@ private void runSubmission(boolean isDetached, boolean isStreaming) throws Excep argsBuilder.add("-d"); } argsBuilder.add(jarPath); + argsBuilder.add("--runner=flink"); + + if (isStreaming) { + argsBuilder.add("--streaming"); + } + + if (useDataStreamForBatch) { + argsBuilder.add("--useDataStreamForBatch"); + } - FlinkSubmissionTest.streaming = isStreaming; FlinkSubmissionTest.expectedNumberOfJobs++; // Run end-to-end test CliFrontend.main(argsBuilder.build().toArray(new String[0])); @@ -168,7 +176,10 @@ private void waitUntilJobIsCompleted() throws Exception { Collection allJobsStates = flinkCluster.listJobs().get(); if (allJobsStates.size() == expectedNumberOfJobs && allJobsStates.stream() - .allMatch(jobStatus -> jobStatus.getJobState().name().equals("FINISHED"))) { + .allMatch(jobStatus -> jobStatus.getJobState().isTerminalState())) { + LOG.info( + "All job finished with statuses: {}", + allJobsStates.stream().map(j -> j.getJobState().name()).collect(Collectors.toList())); return; } Thread.sleep(50); @@ -177,10 +188,9 @@ private void waitUntilJobIsCompleted() throws Exception { /** The Flink program which is executed by the CliFrontend. 
*/ public static void main(String[] args) { - FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); - options.setUseDataStreamForBatch(useDataStreamForBatch); + FlinkPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).withValidation().as(FlinkPipelineOptions.class); options.setRunner(FlinkRunner.class); - options.setStreaming(streaming); options.setParallelism(1); Pipeline p = Pipeline.create(options); p.apply(GenerateSequence.from(0).to(1)); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/FlinkKeyTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/FlinkKeyTest.java new file mode 100644 index 000000000000..649332c1e48f --- /dev/null +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/FlinkKeyTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.adapter; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.not; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; +import org.apache.flink.util.MathUtils; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Test; + +public class FlinkKeyTest { + @Test + public void testIsRecognizedAsValue() { + byte[] bs = "foobar".getBytes(StandardCharsets.UTF_8); + ByteBuffer buf = ByteBuffer.wrap(bs); + FlinkKey key = FlinkKey.of(buf); + TypeInformation tpe = TypeExtractor.getForObject(key); + + assertThat(tpe, IsInstanceOf.instanceOf(ValueTypeInfo.class)); + + TypeInformation> tupleTpe = + TypeExtractor.getForObject(Tuple2.of(key, bs)); + assertThat(tupleTpe, not(IsInstanceOf.instanceOf(GenericTypeInfo.class))); + } + + @Test + public void testIsConsistent() { + byte[] bs = "foobar".getBytes(StandardCharsets.UTF_8); + byte[] bs2 = "foobar".getBytes(StandardCharsets.UTF_8); + + FlinkKey key1 = FlinkKey.of(ByteBuffer.wrap(bs)); + FlinkKey key2 = FlinkKey.of(ByteBuffer.wrap(bs2)); + + assertThat(key1, equalTo(key2)); + assertThat(key1.hashCode(), equalTo(key2.hashCode())); + } + + private void checkDistribution(int numKeys) { + int paralellism = 2100; + + Set hashcodes = + IntStream.range(0, numKeys) + .mapToObj(i -> FlinkKey.of(i, VarIntCoder.of())) + .map(k -> k.hashCode()) + .collect(Collectors.toSet()); + + Set keyGroups = + hashcodes.stream() 
+ .map(hash -> MathUtils.murmurHash(hash) % paralellism) + .collect(Collectors.toSet()); + + assertThat((double) hashcodes.size(), greaterThan(numKeys * 0.95)); + assertThat((double) keyGroups.size(), greaterThan(paralellism * 0.95)); + } + + @Test + public void testWillBeWellDistributedForSmallKeyGroups() { + checkDistribution(8192); + } + + @Test + public void testWillBeWellDistributedForLargeKeyGroups() { + checkDistribution(1000000); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index d0338ec3b0d3..2324a262acc0 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -30,6 +30,7 @@ import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -39,7 +40,7 @@ import org.apache.beam.sdk.util.CoderUtils; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; -import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; @@ -64,10 +65,11 @@ public class FlinkStateInternalsTest extends StateInternalsTest { @Override protected StateInternals createStateInternals() { try { - KeyedStateBackend keyedStateBackend 
= createStateBackend(); + KeyedStateBackend keyedStateBackend = createStateBackend(); return new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); } catch (Exception e) { throw new RuntimeException(e); @@ -76,11 +78,12 @@ protected StateInternals createStateInternals() { @Test public void testWatermarkHoldsPersistence() throws Exception { - KeyedStateBackend keyedStateBackend = createStateBackend(); + KeyedStateBackend keyedStateBackend = createStateBackend(); FlinkStateInternals stateInternals = new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); StateTag stateTag = @@ -114,9 +117,9 @@ public void testWatermarkHoldsPersistence() throws Exception { assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); // Watermark hold should be computed across all keys - ByteBuffer firstKey = keyedStateBackend.getCurrentKey(); + FlinkKey firstKey = keyedStateBackend.getCurrentKey(); changeKey(keyedStateBackend); - ByteBuffer secondKey = keyedStateBackend.getCurrentKey(); + FlinkKey secondKey = keyedStateBackend.getCurrentKey(); assertThat(firstKey, is(Matchers.not(secondKey))); assertThat(stateInternals.minWatermarkHoldMs(), is(low.getMillis())); // ..but be tracked per key / window @@ -136,6 +139,7 @@ public void testWatermarkHoldsPersistence() throws Exception { new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); globalWindow = stateInternals.state(StateNamespaces.global(), stateTag); fixedWindow = @@ -168,11 +172,12 @@ public void testWatermarkHoldsPersistence() throws Exception { @Test public void testGlobalWindowWatermarkHoldClear() throws Exception { - KeyedStateBackend keyedStateBackend = createStateBackend(); + KeyedStateBackend 
keyedStateBackend = createStateBackend(); FlinkStateInternals stateInternals = new FlinkStateInternals<>( keyedStateBackend, StringUtf8Coder.of(), + IntervalWindow.getCoder(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); StateTag stateTag = StateTags.watermarkStateInternal("hold", TimestampCombiner.EARLIEST); @@ -183,13 +188,13 @@ public void testGlobalWindowWatermarkHoldClear() throws Exception { assertThat(state.read(), is((Instant) null)); } - public static KeyedStateBackend createStateBackend() throws Exception { - AbstractKeyedStateBackend keyedStateBackend = + public static KeyedStateBackend createStateBackend() throws Exception { + AbstractKeyedStateBackend keyedStateBackend = MemoryStateBackendWrapper.createKeyedStateBackend( new DummyEnvironment("test", 1, 0), new JobID(), "test_op", - new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()), + new ValueTypeInfo<>(FlinkKey.class).createSerializer(new ExecutionConfig()), 2, new KeyGroupRange(0, 1), new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()), @@ -203,10 +208,11 @@ public static KeyedStateBackend createStateBackend() throws Exceptio return keyedStateBackend; } - private static void changeKey(KeyedStateBackend keyedStateBackend) + private static void changeKey(KeyedStateBackend keyedStateBackend) throws CoderException { keyedStateBackend.setCurrentKey( - ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), UUID.randomUUID().toString()))); + FlinkKey.of( + ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), UUID.randomUUID().toString())))); } } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 17cc16cc76e0..f0d8816bdeab 100644 --- 
a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -32,7 +32,6 @@ import com.fasterxml.jackson.databind.type.TypeFactory; import com.fasterxml.jackson.databind.util.LRUMap; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -52,8 +51,8 @@ import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.metrics.FlinkMetricContainer; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; @@ -96,6 +95,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; @@ -149,7 +149,7 @@ public void testSingleOutput() throws Exception { TupleTag outputTag = new TupleTag<>("main-output"); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -211,7 +211,7 @@ public void testMultiOutputOutput() throws Exception { .put(additionalOutput2, 2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new MultiOutputDoFn(additionalOutput1, 
additionalOutput2), "stepName", @@ -348,12 +348,12 @@ public void onProcessingTime(OnTimerContext context) { WindowedValue.getFullCoder( StringUtf8Coder.of(), windowingStrategy.getWindowFn().windowCoder()); - KeySelector, ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder); + KeySelector, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue(), keyCoder); TupleTag outputTag = new TupleTag<>("main-output"); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -376,10 +376,7 @@ public void onProcessingTime(OnTimerContext context) { OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( - doFnOperator, - keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + doFnOperator, keySelector, ValueTypeInfo.of(FlinkKey.class)); testHarness.setup( new CoderTypeSerializer<>( @@ -438,11 +435,12 @@ public void testWatermarkUpdateAfterWatermarkHoldRelease() throws Exception { TupleTag> outputTag = new TupleTag<>("main-output"); List emittedWatermarkHolds = new ArrayList<>(); - KeySelector>, ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue().getKey(), StringUtf8Coder.of()); - DoFnOperator, KV> doFnOperator = - new DoFnOperator, KV>( + KeySelector>, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue().getKey(), StringUtf8Coder.of()); + + DoFnOperator, KV, KV> doFnOperator = + new DoFnOperator, KV, KV>( new IdentityDoFn<>(), "stepName", coder, @@ -544,10 +542,7 @@ void emitWatermarkIfHoldChanged(long currentWatermarkHold) { WindowedValue>, WindowedValue>> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( - doFnOperator, - keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + doFnOperator, keySelector, ValueTypeInfo.of(FlinkKey.class)); testHarness.setup(); @@ -611,12 +606,12 @@ public void 
processElement(ProcessContext context) { WindowedValue.getFullCoder( StringUtf8Coder.of(), windowingStrategy.getWindowFn().windowCoder()); - KeySelector, ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder); + KeySelector, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue(), keyCoder); TupleTag outputTag = new TupleTag<>("main-output"); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -639,10 +634,7 @@ public void processElement(ProcessContext context) { OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( - doFnOperator, - keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + doFnOperator, keySelector, ValueTypeInfo.of(FlinkKey.class)); testHarness.open(); @@ -693,7 +685,7 @@ public void testStateGCForStatefulFn() throws Exception { final int timerOutput = 4093; KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> + FlinkKey, WindowedValue>, WindowedValue>> testHarness = getHarness( windowingStrategy, @@ -758,7 +750,7 @@ public void testGCForGlobalWindow() throws Exception { WindowingStrategy windowingStrategy = WindowingStrategy.globalDefault(); KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> + FlinkKey, WindowedValue>, WindowedValue>> testHarness = getHarness(windowingStrategy, 5000, (window) -> new Instant(50), 4092); testHarness.open(); @@ -818,7 +810,7 @@ public void testGCForGlobalWindow() throws Exception { } private static KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> + FlinkKey, WindowedValue>, WindowedValue>> getHarness( WindowingStrategy windowingStrategy, int elementOffset, @@ -863,10 +855,13 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState TupleTag> outputTag = new TupleTag<>("main-output"); - KeySelector>, 
ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue().getKey(), StringUtf8Coder.of()); + KeySelector>, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue().getKey(), StringUtf8Coder.of()); + + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setStreaming(true); - DoFnOperator, KV> doFnOperator = + DoFnOperator, KV, KV> doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -875,21 +870,18 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState outputTag, Collections.emptyList(), new DoFnOperator.MultiOutputOutputManagerFactory<>( - outputTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + outputTag, coder, new SerializablePipelineOptions(options)), windowingStrategy, new HashMap<>(), /* side-input mapping */ Collections.emptyList(), /* side inputs */ - FlinkPipelineOptions.defaults(), + options, StringUtf8Coder.of(), /* key coder */ keySelector, DoFnSchemaInformation.create(), Collections.emptyMap()); return new KeyedOneInputStreamOperatorTestHarness<>( - doFnOperator, - keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + doFnOperator, keySelector, ValueTypeInfo.of(FlinkKey.class)); } @Test @@ -912,12 +904,12 @@ void testSideInputs(boolean keyed) throws Exception { ImmutableMap.>builder().put(1, view1).put(2, view2).build(); Coder keyCoder = StringUtf8Coder.of(); - KeySelector, ByteBuffer> keySelector = null; + KeySelector, FlinkKey> keySelector = null; if (keyed) { - keySelector = value -> FlinkKeyUtils.encodeKey(value.getValue(), keyCoder); + keySelector = value -> FlinkKey.of(value.getValue(), keyCoder); } - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -943,11 +935,7 @@ outputTag, coder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults( // we use a dummy key for the second input since it is considered to be broadcast 
testHarness = new KeyedTwoInputStreamOperatorTestHarness<>( - doFnOperator, - keySelector, - null, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + doFnOperator, keySelector, null, ValueTypeInfo.of(FlinkKey.class)); } testHarness.open(); @@ -1036,16 +1024,14 @@ public void processElement( TupleTag> outputTag = new TupleTag<>("main-output"); StringUtf8Coder keyCoder = StringUtf8Coder.of(); - KvToByteBufferKeySelector keySelector = - new KvToByteBufferKeySelector<>(keyCoder, null); + KvToFlinkKeyKeySelector keySelector = new KvToFlinkKeyKeySelector<>(keyCoder); + KvCoder coder = KvCoder.of(keyCoder, VarLongCoder.of()); FullWindowedValueCoder> kvCoder = WindowedValue.getFullCoder(coder, windowingStrategy.getWindowFn().windowCoder()); - CoderTypeInformation keyCoderInfo = - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults()); + TypeInformation keyCoderInfo = ValueTypeInfo.of(FlinkKey.class); OneInputStreamOperatorTestHarness< WindowedValue>, WindowedValue>> @@ -1115,7 +1101,7 @@ public void nonKeyedParDoSideInputCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1149,8 +1135,8 @@ public void keyedParDoSideInputCheckpointing() throws Exception { WindowedValue.getFullCoder(keyCoder, IntervalWindow.getCoder()); TupleTag outputTag = new TupleTag<>("main-output"); - KeySelector, ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder); + KeySelector, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue(), keyCoder); ImmutableMap> sideInputMapping = ImmutableMap.>builder() @@ -1158,7 +1144,7 @@ public void keyedParDoSideInputCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1184,8 +1170,7 @@ public 
void keyedParDoSideInputCheckpointing() throws Exception { keySelector, // we use a dummy key for the second input since it is considered to be broadcast null, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + ValueTypeInfo.of(FlinkKey.class)); }); } @@ -1261,7 +1246,7 @@ public void nonKeyedParDoPushbackDataCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1296,8 +1281,8 @@ public void keyedParDoPushbackDataCheckpointing() throws Exception { TupleTag outputTag = new TupleTag<>("main-output"); - KeySelector, ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder); + KeySelector, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue(), keyCoder); ImmutableMap> sideInputMapping = ImmutableMap.>builder() @@ -1305,7 +1290,7 @@ public void keyedParDoPushbackDataCheckpointing() throws Exception { .put(2, view2) .build(); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn<>(), "stepName", @@ -1331,8 +1316,7 @@ public void keyedParDoPushbackDataCheckpointing() throws Exception { keySelector, // we use a dummy key for the second input since it is considered to be broadcast null, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + ValueTypeInfo.of(FlinkKey.class)); }); } @@ -1434,11 +1418,10 @@ public void onEventTime(OnTimerContext context) { final CoderTypeSerializer> outputSerializer = new CoderTypeSerializer<>( outputCoder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); - CoderTypeInformation keyCoderInfo = - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults()); - KeySelector, ByteBuffer> keySelector = - e -> FlinkKeyUtils.encodeKey(e.getValue(), keyCoder); + TypeInformation keyCoderInfo = 
ValueTypeInfo.of(FlinkKey.class); + + KeySelector, FlinkKey> keySelector = + e -> FlinkKey.of(e.getValue(), keyCoder); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = createTestHarness( @@ -1504,7 +1487,7 @@ OneInputStreamOperatorTestHarness, WindowedValue> creat TypeInformation keyCoderInfo, KeySelector, K> keySelector) throws Exception { - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( fn, "stepName", @@ -1538,6 +1521,7 @@ public void testBundle() throws Exception { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setMaxBundleSize(2L); options.setMaxBundleTimeMills(10L); + options.setStreaming(true); IdentityDoFn doFn = new IdentityDoFn() { @@ -1554,7 +1538,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( doFn, "stepName", @@ -1603,7 +1587,7 @@ public void finishBundle(FinishBundleContext context) { testHarness.close(); - DoFnOperator newDoFnOperator = + DoFnOperator newDoFnOperator = new DoFnOperator<>( doFn, "stepName", @@ -1669,9 +1653,7 @@ public void finishBundle(FinishBundleContext context) { public void testBundleKeyed() throws Exception { StringUtf8Coder keyCoder = StringUtf8Coder.of(); - KvToByteBufferKeySelector keySelector = - new KvToByteBufferKeySelector<>( - keyCoder, new SerializablePipelineOptions(FlinkPipelineOptions.defaults())); + KvToFlinkKeyKeySelector keySelector = new KvToFlinkKeyKeySelector<>(keyCoder); KvCoder kvCoder = KvCoder.of(keyCoder, StringUtf8Coder.of()); WindowedValue.ValueOnlyWindowedValueCoder> windowedValueCoder = WindowedValue.getValueOnlyCoder(kvCoder); @@ -1680,6 +1662,7 @@ public void testBundleKeyed() throws Exception { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setMaxBundleSize(2L); 
options.setMaxBundleTimeMills(10L); + options.setStreaming(true); DoFn, String> doFn = new DoFn, String>() { @@ -1702,7 +1685,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(kvCoder.getValueCoder(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - DoFnOperator, String> doFnOperator = + DoFnOperator, KV, String> doFnOperator = new DoFnOperator<>( doFn, "stepName", @@ -1806,6 +1789,7 @@ public void testCheckpointBufferingWithMultipleBundles() throws Exception { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setMaxBundleSize(10L); options.setCheckpointingInterval(1L); + options.setStreaming(true); TupleTag outputTag = new TupleTag<>("main-output"); @@ -1819,7 +1803,7 @@ public void testCheckpointBufferingWithMultipleBundles() throws Exception { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier> doFnOperatorSupplier = + Supplier> doFnOperatorSupplier = () -> new DoFnOperator<>( new IdentityDoFn<>(), @@ -1838,7 +1822,7 @@ public void testCheckpointBufferingWithMultipleBundles() throws Exception { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -1943,7 +1927,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier> doFnOperatorSupplier = + Supplier> doFnOperatorSupplier = () -> new DoFnOperator<>( doFn, @@ -1962,7 +1946,7 @@ public void finishBundle(FinishBundleContext context) { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator doFnOperator = 
doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -2054,7 +2038,7 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier> doFnOperatorSupplier = + Supplier> doFnOperatorSupplier = () -> new DoFnOperator<>( doFn, @@ -2073,7 +2057,7 @@ public void finishBundle(FinishBundleContext context) { DoFnSchemaInformation.create(), Collections.emptyMap()); - DoFnOperator doFnOperator = doFnOperatorSupplier.get(); + DoFnOperator doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -2116,8 +2100,7 @@ public void testExactlyOnceBufferingKeyed() throws Exception { TupleTag> outputTag = new TupleTag<>("main-output"); StringUtf8Coder keyCoder = StringUtf8Coder.of(); - KvToByteBufferKeySelector keySelector = - new KvToByteBufferKeySelector<>(keyCoder, new SerializablePipelineOptions(options)); + KvToFlinkKeyKeySelector keySelector = new KvToFlinkKeyKeySelector<>(keyCoder); KvCoder kvCoder = KvCoder.of(keyCoder, StringUtf8Coder.of()); WindowedValue.ValueOnlyWindowedValueCoder> windowedValueCoder = WindowedValue.getValueOnlyCoder(kvCoder); @@ -2151,26 +2134,28 @@ public void finishBundle(FinishBundleContext context) { WindowedValue.getFullCoder(kvCoder, GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - Supplier, KV>> doFnOperatorSupplier = - () -> - new DoFnOperator<>( - doFn, - "stepName", - windowedValueCoder, - Collections.emptyMap(), - outputTag, - Collections.emptyList(), - outputManagerFactory, - WindowingStrategy.globalDefault(), - new HashMap<>(), /* side-input mapping */ - Collections.emptyList(), /* side inputs */ - options, - keyCoder, - keySelector, - DoFnSchemaInformation.create(), - 
Collections.emptyMap()); - - DoFnOperator, KV> doFnOperator = doFnOperatorSupplier.get(); + Supplier, KV, KV>> + doFnOperatorSupplier = + () -> + new DoFnOperator<>( + doFn, + "stepName", + windowedValueCoder, + Collections.emptyMap(), + outputTag, + Collections.emptyList(), + outputManagerFactory, + WindowingStrategy.globalDefault(), + new HashMap<>(), /* side-input mapping */ + Collections.emptyList(), /* side inputs */ + options, + keyCoder, + keySelector, + DoFnSchemaInformation.create(), + Collections.emptyMap()); + + DoFnOperator, KV, KV> doFnOperator = + doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness< WindowedValue>, WindowedValue>> testHarness = @@ -2199,7 +2184,7 @@ public void finishBundle(FinishBundleContext context) { assertThat(numStartBundleCalled, is(1)); assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), - contains( + containsInAnyOrder( WindowedValue.valueInGlobalWindow(KV.of("key", "a")), WindowedValue.valueInGlobalWindow(KV.of("key", "b")), WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), @@ -2220,7 +2205,7 @@ public void finishBundle(FinishBundleContext context) { assertThat(numStartBundleCalled, is(2)); assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), - contains( + containsInAnyOrder( WindowedValue.valueInGlobalWindow(KV.of("key", "a")), WindowedValue.valueInGlobalWindow(KV.of("key", "b")), WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), @@ -2233,7 +2218,7 @@ public void finishBundle(FinishBundleContext context) { assertThat(numStartBundleCalled, is(2)); assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), - contains( + containsInAnyOrder( WindowedValue.valueInGlobalWindow(KV.of("key", "a")), WindowedValue.valueInGlobalWindow(KV.of("key", "b")), WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), @@ -2246,8 +2231,7 @@ public void testFailOnRequiresStableInputAndDisabledCheckpointing() { TupleTag> outputTag = new TupleTag<>("main-output"); 
StringUtf8Coder keyCoder = StringUtf8Coder.of(); - KvToByteBufferKeySelector keySelector = - new KvToByteBufferKeySelector<>(keyCoder, null); + KvToFlinkKeyKeySelector keySelector = new KvToFlinkKeyKeySelector<>(keyCoder); KvCoder kvCoder = KvCoder.of(keyCoder, StringUtf8Coder.of()); WindowedValue.ValueOnlyWindowedValueCoder> windowedValueCoder = WindowedValue.getValueOnlyCoder(kvCoder); @@ -2307,7 +2291,7 @@ public void testBundleProcessingExceptionIsFatalDuringCheckpointing() throws Exc WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE), new SerializablePipelineOptions(options)); - DoFnOperator doFnOperator = + DoFnOperator doFnOperator = new DoFnOperator<>( new IdentityDoFn() { @FinishBundle @@ -2346,7 +2330,7 @@ public void finishBundle() { @Test public void testAccumulatorRegistrationOnOperatorClose() throws Exception { - DoFnOperator doFnOperator = getOperatorForCleanupInspection(); + DoFnOperator doFnOperator = getOperatorForCleanupInspection(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -2382,7 +2366,7 @@ public void testRemoveCachedClassReferences() throws Exception { assertThat(typeCache.size(), is(0)); } - private static DoFnOperator getOperatorForCleanupInspection() { + private static DoFnOperator getOperatorForCleanupInspection() { FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); options.setParallelism(4); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java index 2eb0545b7794..a0a955aea1d6 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java +++ 
b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java @@ -63,6 +63,7 @@ import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate; import org.apache.beam.runners.flink.streaming.FlinkStateInternalsTest; import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory; @@ -109,6 +110,7 @@ import org.apache.flink.api.common.cache.DistributedCache; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -683,7 +685,7 @@ public void testEnsureStateCleanupWithKeyedInputCleanupTimer() { cleanupTimer.setForWindow(KV.of("key", "string"), window); Mockito.verify(stateBackendLock).lock(); - ByteBuffer key = FlinkKeyUtils.encodeKey("key", keyCoder); + FlinkKey key = FlinkKey.of("key", keyCoder); Mockito.verify(keyedStateBackend).setCurrentKey(key); assertThat( inMemoryTimerInternals.getNextTimer(TimeDomain.EVENT_TIME), @@ -707,9 +709,9 @@ public void testEnsureStateCleanupWithKeyedInputStateCleaner() throws Exception } ImmutableList> bagStates = bagStateBuilder.build(); - MutableObject key = + MutableObject key = new MutableObject<>( - ByteBuffer.wrap(stateInternals.getKey().getBytes(StandardCharsets.UTF_8))); + FlinkKey.of(ByteBuffer.wrap(stateInternals.getKey().getBytes(StandardCharsets.UTF_8)))); // Test that state is cleaned up correctly ExecutableStageDoFnOperator.StateCleaner stateCleaner = @@ -786,21 
+788,18 @@ private void testEnsureDeferredStateCleanupTimerFiring(boolean withCheckpointing when(bundle.getTimerReceivers()).thenReturn(ImmutableMap.of(timerInputKey, timerReceiver)); KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue> + FlinkKey, WindowedValue>, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness( - operator, - operator.keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + operator, operator.keySelector, ValueTypeInfo.of(FlinkKey.class)); testHarness.open(); Lock stateBackendLock = Whitebox.getInternalState(operator, "stateBackendLock"); stateBackendLock.lock(); - KeyedStateBackend keyedStateBackend = operator.getKeyedStateBackend(); - ByteBuffer key = FlinkKeyUtils.encodeKey(windowedValue.getValue().getKey(), keyCoder); + KeyedStateBackend keyedStateBackend = operator.getKeyedStateBackend(); + FlinkKey key = FlinkKey.of(windowedValue.getValue().getKey(), keyCoder); keyedStateBackend.setCurrentKey(key); DoFnOperator.FlinkTimerInternals timerInternals = @@ -937,13 +936,10 @@ public void testEnsureStateCleanupOnFinalWatermark() throws Exception { WindowedValue.getFullCoder(kvCoder, windowCoder)); KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue> + FlinkKey, WindowedValue>, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness( - operator, - operator.keySelector, - new CoderTypeInformation<>( - FlinkKeyUtils.ByteBufferCoder.of(), FlinkPipelineOptions.defaults())); + operator, operator.keySelector, ValueTypeInfo.of(FlinkKey.class)); RemoteBundle bundle = Mockito.mock(RemoteBundle.class); when(bundle.getInputReceivers()) @@ -955,8 +951,8 @@ public void testEnsureStateCleanupOnFinalWatermark() throws Exception { testHarness.open(); - KeyedStateBackend keyedStateBackend = operator.getKeyedStateBackend(); - ByteBuffer key = FlinkKeyUtils.encodeKey("key1", keyCoder); + KeyedStateBackend 
keyedStateBackend = operator.getKeyedStateBackend(); + FlinkKey key = FlinkKey.of("key1", keyCoder); keyedStateBackend.setCurrentKey(key); // create some state which can be cleaned up @@ -981,7 +977,7 @@ public void testEnsureStateCleanupOnFinalWatermark() throws Exception { @Test public void testCacheTokenHandling() throws Exception { InMemoryStateInternals test = InMemoryStateInternals.forKey("test"); - KeyedStateBackend stateBackend = FlinkStateInternalsTest.createStateBackend(); + KeyedStateBackend stateBackend = FlinkStateInternalsTest.createStateBackend(); ExecutableStageDoFnOperator.BagUserStateFactory bagUserStateFactory = new ExecutableStageDoFnOperator.BagUserStateFactory<>( @@ -1254,7 +1250,7 @@ private ExecutableStageDoFnOperator getOperator( createOutputMap(mainOutput, additionalOutputs), windowingStrategy, keyCoder, - keyCoder != null ? new KvToByteBufferKeySelector<>(keyCoder, null) : null); + keyCoder != null ? new KvToFlinkKeyKeySelector<>(keyCoder) : null); Whitebox.setInternalState(operator, "stateRequestHandler", stateRequestHandler); return operator; diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java index 8fab1bc6c167..6380108ddb94 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperatorTest.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.flink.FlinkPipelineOptions; +import org.apache.beam.runners.flink.adapter.FlinkKey; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator.MultiOutputOutputManagerFactory; 
import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; @@ -52,7 +53,7 @@ import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; @@ -73,8 +74,8 @@ public class WindowDoFnOperatorTest { public void testRestore() throws Exception { // test harness KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> - testHarness = createTestHarness(getWindowDoFnOperator()); + FlinkKey, WindowedValue>, WindowedValue>> + testHarness = createTestHarness(getWindowDoFnOperator(true)); testHarness.open(); // process elements @@ -92,7 +93,7 @@ public void testRestore() throws Exception { testHarness.close(); // restore from the snapshot - testHarness = createTestHarness(getWindowDoFnOperator()); + testHarness = createTestHarness(getWindowDoFnOperator(true)); testHarness.initializeState(snapshot); testHarness.open(); @@ -123,14 +124,14 @@ public void testRestore() throws Exception { @Test public void testTimerCleanupOfPendingTimerList() throws Exception { // test harness - WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(); + WindowDoFnOperator windowDoFnOperator = getWindowDoFnOperator(true); KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> + FlinkKey, WindowedValue>, WindowedValue>> testHarness = createTestHarness(windowDoFnOperator); testHarness.open(); - DoFnOperator, KV>.FlinkTimerInternals timerInternals = - windowDoFnOperator.timerInternals; + DoFnOperator, KeyedWorkItem, KV>.FlinkTimerInternals + 
timerInternals = windowDoFnOperator.timerInternals; // process elements IntervalWindow window = new IntervalWindow(new Instant(0), Duration.millis(100)); @@ -195,7 +196,7 @@ public void testTimerCleanupOfPendingTimerList() throws Exception { testHarness.close(); } - private WindowDoFnOperator getWindowDoFnOperator() { + private WindowDoFnOperator getWindowDoFnOperator(boolean streaming) { WindowingStrategy windowingStrategy = WindowingStrategy.of(FixedWindows.of(standardMinutes(1))); @@ -217,6 +218,9 @@ private WindowDoFnOperator getWindowDoFnOperator() { FullWindowedValueCoder> outputCoder = WindowedValue.getFullCoder(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()), windowCoder); + FlinkPipelineOptions options = FlinkPipelineOptions.defaults(); + options.setStreaming(streaming); + return new WindowDoFnOperator( reduceFn, "stepName", @@ -224,31 +228,28 @@ private WindowDoFnOperator getWindowDoFnOperator() { outputTag, emptyList(), new MultiOutputOutputManagerFactory<>( - outputTag, - outputCoder, - new SerializablePipelineOptions(FlinkPipelineOptions.defaults())), + outputTag, outputCoder, new SerializablePipelineOptions(options)), windowingStrategy, emptyMap(), emptyList(), - FlinkPipelineOptions.defaults(), + options, VarLongCoder.of(), - new WorkItemKeySelector( - VarLongCoder.of(), new SerializablePipelineOptions(FlinkPipelineOptions.defaults()))); + new WorkItemKeySelector(VarLongCoder.of())); } private KeyedOneInputStreamOperatorTestHarness< - ByteBuffer, WindowedValue>, WindowedValue>> + FlinkKey, WindowedValue>, WindowedValue>> createTestHarness(WindowDoFnOperator windowDoFnOperator) throws Exception { return new KeyedOneInputStreamOperatorTestHarness<>( windowDoFnOperator, - (KeySelector>, ByteBuffer>) + (KeySelector>, FlinkKey>) o -> { try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - VarLongCoder.of().encode(o.getValue().key(), baos); - return ByteBuffer.wrap(baos.toByteArray()); + VarLongCoder.of().encode(o.getValue().getKey(), baos); 
+ return FlinkKey.of(ByteBuffer.wrap(baos.toByteArray())); } }, - new GenericTypeInfo<>(ByteBuffer.class)); + ValueTypeInfo.of(FlinkKey.class)); } private static class Item { @@ -262,11 +263,9 @@ static ItemBuilder builder() { private long timestamp; private IntervalWindow window; - StreamRecord>> toStreamRecord() { - WindowedValue item = WindowedValue.of(value, new Instant(timestamp), window, NO_FIRING); - WindowedValue> keyedItem = - WindowedValue.of( - new SingletonKeyedWorkItem<>(key, item), new Instant(timestamp), window, NO_FIRING); + StreamRecord>> toStreamRecord() { + WindowedValue> keyedItem = + WindowedValue.of(KV.of(key, value), new Instant(timestamp), window, NO_FIRING); return new StreamRecord<>(keyedItem); } diff --git a/runners/flink/src/test/validatesRunnerConfig/flink-conf.yaml b/runners/flink/src/test/validatesRunnerConfig/flink-conf.yaml new file mode 100644 index 000000000000..3b075a095721 --- /dev/null +++ b/runners/flink/src/test/validatesRunnerConfig/flink-conf.yaml @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +#parallelism.default: 23 +taskmanager.memory.network.fraction: 0.2 +taskmanager.memory.network.max: 1gb +pipeline.operator-chaining.enabled: false diff --git a/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html b/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html index 60f1fd39bd13..a3526d7d0d28 100644 --- a/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html +++ b/website/www/site/layouts/shortcodes/flink_java_pipeline_options.html @@ -107,6 +107,11 @@ Address of the Flink Master where the Pipeline should be executed. Can either be of the form "host:port" or one of the special values [local], [collection] or [auto]. Default: [auto] + + forceSlotSharingGroup + Set a slot sharing group for all bounded sources. This is required when using Datastream to have the same scheduling behaviour as the Dataset API. + Default: true + forceUnalignedCheckpointEnabled Forces unaligned checkpoints, particularly allowing them for iterative jobs. diff --git a/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html b/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html index 4faad5a994ba..183dacfd5a09 100644 --- a/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html +++ b/website/www/site/layouts/shortcodes/flink_python_pipeline_options.html @@ -107,6 +107,11 @@ Address of the Flink Master where the Pipeline should be executed. Can either be of the form "host:port" or one of the special values [local], [collection] or [auto]. Default: [auto] + + force_slot_sharing_group + Set a slot sharing group for all bounded sources. This is required when using Datastream to have the same scheduling behaviour as the Dataset API. + Default: true + force_unaligned_checkpoint_enabled Forces unaligned checkpoints, particularly allowing them for iterative jobs.