diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json index 96e098eb7f97..c98ca2b07f9d 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json @@ -2,5 +2,6 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test", - "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test" + "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test", + "https://github.com/apache/beam/pull/33318": "noting that PR #33318 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json index 96e098eb7f97..c98ca2b07f9d 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json @@ -2,5 +2,6 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test", - "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test" + "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test", + "https://github.com/apache/beam/pull/33318": "noting that PR #33318 should run this test" } diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 811a3c15f836..aeace769b4c1 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -462,7 +462,8 @@ task validatesRunnerStreaming { description "Validates Dataflow runner forcing streaming mode" dependsOn(createLegacyWorkerValidatesRunnerTest( name: 'validatesRunnerLegacyWorkerTestStreaming', - pipelineOptions: legacyPipelineOptions + ['--streaming'], + pipelineOptions: legacyPipelineOptions + ['--streaming'] + + ['--experiments=enable_gbk_state_multiplexing'], excludedCategories: [ 'org.apache.beam.sdk.testing.UsesCommittedMetrics', 'org.apache.beam.sdk.testing.UsesMapState', diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index c01096716c97..a548cbf63eba 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -935,7 +935,20 @@ private void dataflowGroupByKeyHelper( stepContext.addOutput(PropertyNames.OUTPUT, context.getOutput(transform)); WindowingStrategy windowingStrategy = input.getWindowingStrategy(); - stepContext.addInput(PropertyNames.DISALLOW_COMBINER_LIFTING, true); + boolean isStreaming = + context.getPipelineOptions().as(StreamingOptions.class).isStreaming(); + // :TODO do we set this for batch? + boolean allowCombinerLifting = false; + if (isStreaming) { + allowCombinerLifting = + !windowingStrategy.needsMerge() + && windowingStrategy.getWindowFn().assignsToOneWindow(); + allowCombinerLifting &= transform.fewKeys(); + // TODO: Allow combiner lifting on the non-default trigger, as appropriate. + allowCombinerLifting &= (windowingStrategy.getTrigger() instanceof DefaultTrigger); + } + stepContext.addInput(PropertyNames.DISALLOW_COMBINER_LIFTING, !allowCombinerLifting); + stepContext.addInput( PropertyNames.SERIALIZED_FN, byteArrayToJsonString( diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 9f41ea138bd5..676ceb495c21 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -53,6 +53,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.SortedSet; @@ -64,6 +65,7 @@ import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification; import org.apache.beam.runners.dataflow.StreamingViewOverrides.StreamingCreatePCollectionViewFactory; import org.apache.beam.runners.dataflow.TransformTranslator.StepTranslationContext; +import org.apache.beam.runners.dataflow.internal.StateMultiplexingGroupByKey; import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; @@ -215,8 +217,6 @@ public class DataflowRunner extends PipelineRunner { "unsafely_attempt_to_process_unbounded_data_in_batch_mode"; private static final Logger LOG = LoggerFactory.getLogger(DataflowRunner.class); - private static final String EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING = - "enable_gbk_state_multiplexing"; /** Provided configuration options. */ private final DataflowPipelineOptions options; @@ -801,11 +801,12 @@ private List getOverrides(boolean streaming) { new RedistributeByKeyOverrideFactory())); if (streaming) { - if (DataflowRunner.hasExperiment(options, EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING)) { + if (DataflowRunner.hasExperiment( + options, StateMultiplexingGroupByKey.EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING)) { overridesBuilder.add( PTransformOverride.of( PTransformMatchers.classEqualTo(GroupByKey.class), - new StateMultiplexingGroupByKeyOverrideFactory<>())); + new StateMultiplexingGroupByKeyOverrideFactory<>(options))); } // For update compatibility, always use a Read for Create in streaming mode. overridesBuilder @@ -1714,6 +1715,22 @@ public static boolean hasExperiment(DataflowPipelineDebugOptions options, String return experiments.contains(experiment); } + /** Return the value for the specified experiment or null if not present. */ + public static Optional getExperimentValue( + DataflowPipelineDebugOptions options, String experiment) { + List experiments = options.getExperiments(); + if (experiments == null) { + return Optional.empty(); + } + String prefix = experiment + "="; + for (String experimentEntry : experiments) { + if (experimentEntry.startsWith(prefix)) { + return Optional.of(experimentEntry.substring(prefix.length())); + } + } + return Optional.empty(); + } + /** Helper to configure the Dataflow Job Environment based on the user's job options. */ private static Map getEnvironmentVersion(DataflowPipelineOptions options) { DataflowRunnerInfo runnerInfo = DataflowRunnerInfo.getDataflowRunnerInfo(); diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java index 468a0a95d77c..93c97a547fc1 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow; import org.apache.beam.runners.dataflow.internal.StateMultiplexingGroupByKey; +import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.util.construction.PTransformReplacements; @@ -28,6 +29,11 @@ class StateMultiplexingGroupByKeyOverrideFactory extends SingleInputOutputOverrideFactory< PCollection>, PCollection>>, GroupByKey> { + private final DataflowPipelineOptions options; + + StateMultiplexingGroupByKeyOverrideFactory(DataflowPipelineOptions options) { + this.options = options; + } @Override public PTransformReplacement>, PCollection>>> @@ -37,6 +43,6 @@ class StateMultiplexingGroupByKeyOverrideFactory transform) { return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - StateMultiplexingGroupByKey.create(transform.getTransform().fewKeys())); + StateMultiplexingGroupByKey.create(options, transform.getTransform().fewKeys())); } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/DataflowGroupByKey.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/DataflowGroupByKey.java index 89135641689e..811eb888379b 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/DataflowGroupByKey.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/DataflowGroupByKey.java @@ -47,8 +47,11 @@ public class DataflowGroupByKey // Plumbed from Redistribute transform. private final boolean allowDuplicates; - private DataflowGroupByKey(boolean allowDuplicates) { + private final boolean fewKeys; + + private DataflowGroupByKey(boolean allowDuplicates, boolean fewKeys) { this.allowDuplicates = allowDuplicates; + this.fewKeys = fewKeys; } /** @@ -59,7 +62,11 @@ private DataflowGroupByKey(boolean allowDuplicates) { * {@code Iterable}s in the output {@code PCollection} */ public static DataflowGroupByKey create() { - return new DataflowGroupByKey<>(false); + return new DataflowGroupByKey<>(/*allowDuplicates=*/ false, /*fewKeys=*/ false); + } + + static DataflowGroupByKey createWithFewKeys() { + return new DataflowGroupByKey<>(/*allowDuplicates=*/ false, /*fewKeys=*/ true); } /** @@ -71,7 +78,7 @@ public static DataflowGroupByKey create() { * {@code Iterable}s in the output {@code PCollection} */ public static DataflowGroupByKey createWithAllowDuplicates() { - return new DataflowGroupByKey<>(true); + return new DataflowGroupByKey<>(/*allowDuplicates=*/ true, /*fewKeys=*/ false); } /** Returns whether it allows duplicated elements in the output. */ @@ -79,6 +86,10 @@ public boolean allowDuplicates() { return allowDuplicates; } + /** Returns whether it groups just few keys. */ + public boolean fewKeys() { + return fewKeys; + } ///////////////////////////////////////////////////////////////////////////// public static void applicableTo(PCollection input) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/KeyedWindow.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/KeyedWindow.java similarity index 89% rename from sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/KeyedWindow.java rename to runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/KeyedWindow.java index c0e9e513afda..1f7f568c142f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/KeyedWindow.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/KeyedWindow.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.transforms.windowing; +package org.apache.beam.runners.dataflow.internal; import java.io.IOException; import java.io.InputStream; @@ -27,15 +27,18 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.stream.Collectors; -import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.runners.dataflow.util.ByteStringCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.util.VarInt; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IncompatibleWindowException; +import org.apache.beam.sdk.transforms.windowing.NonMergingWindowFn; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.Preconditions; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @@ -214,7 +217,7 @@ public static class KeyedWindowCoder extends Coder coder; public KeyedWindowCoder(Coder windowCoder) { - //:TODO consider swapping the order for improved state locality + // :TODO consider swapping the order for improved state locality this.coder = KvCoder.of(ByteStringCoder.of(), windowCoder); } @@ -244,26 +247,4 @@ public boolean consistentWithEquals() { return coder.getValueCoder().consistentWithEquals(); } } - - public static class ByteStringCoder extends AtomicCoder { - public static ByteStringCoder of() { - return INSTANCE; - } - - private static final ByteStringCoder INSTANCE = new ByteStringCoder(); - - private ByteStringCoder() {} - - @Override - public void encode(ByteString value, OutputStream os) throws IOException { - VarInt.encode(value.size(), os); - value.writeTo(os); - } - - @Override - public ByteString decode(InputStream is) throws IOException { - int size = VarInt.decodeInt(is); - return ByteString.readFrom(ByteStreams.limit(is, size), size); - } - } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java index 661652ce3453..487032dd977d 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java @@ -20,12 +20,13 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import java.util.Map; +import org.apache.beam.runners.dataflow.DataflowRunner; +import org.apache.beam.runners.dataflow.internal.KeyedWindow.KeyedWindowFn; +import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.runners.dataflow.util.ByteStringCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.NonDeterministicException; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.MapElements; @@ -33,20 +34,13 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.windowing.AfterWatermark.AfterWatermarkEarlyAndLate; -import org.apache.beam.sdk.transforms.windowing.AfterWatermark.FromEndOfWindow; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; -import org.apache.beam.sdk.transforms.windowing.KeyedWindow; -import org.apache.beam.sdk.transforms.windowing.KeyedWindow.KeyedWindowFn; -import org.apache.beam.sdk.transforms.windowing.Never.NeverTrigger; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; @@ -54,9 +48,6 @@ import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.Preconditions; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; -import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString.Output; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; -import org.checkerframework.checker.nullness.qual.Nullable; /** * A GroupByKey implementation that multiplexes many small user keys over a fixed set of sharding @@ -65,142 +56,97 @@ public class StateMultiplexingGroupByKey extends PTransform>, PCollection>>> { + public static final String EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING = + "enable_gbk_state_multiplexing"; + private static final String EXPERIMENT_NUM_VIRTUAL_KEYS = "gbk_state_multiplexing_num_keys"; + private static final String EXPERIMENT_SMALL_KEY_BYTES_THRESHOLD = + "gbk_state_multiplexing_small_key_bytes"; + /* * Keys larger than this threshold will not be multiplexed. */ - private static final int SMALL_KEY_BYTES_THRESHOLD = 4096; + private static final int DEFAULT_SMALL_KEY_BYTES_THRESHOLD = 4096; + private static final int DEFAULT_NUM_VIRTUAL_KEYS = 32 << 10; private final boolean fewKeys; - private final int numShardingKeys; + private final int numVirtualKeys; + private final int smallKeyBytesThreshold; - private StateMultiplexingGroupByKey(boolean fewKeys) { - // :TODO plumb fewKeys to DataflowGroupByKey + private StateMultiplexingGroupByKey(DataflowPipelineOptions options, boolean fewKeys) { this.fewKeys = fewKeys; - // :TODO Make this configurable - this.numShardingKeys = 32 << 10; + this.numVirtualKeys = + getExperimentValue(options, EXPERIMENT_NUM_VIRTUAL_KEYS, DEFAULT_NUM_VIRTUAL_KEYS); + this.smallKeyBytesThreshold = + getExperimentValue( + options, EXPERIMENT_SMALL_KEY_BYTES_THRESHOLD, DEFAULT_SMALL_KEY_BYTES_THRESHOLD); + } + + private static int getExperimentValue( + DataflowPipelineOptions options, String experiment, int defaultValue) { + return DataflowRunner.getExperimentValue(options, experiment) + .map(Integer::parseInt) + .orElse(defaultValue); } /** - * Returns a {@code GroupByKey} {@code PTransform}. + * Returns a {@code StateMultiplexingGroupByKey} {@code PTransform}. * * @param the type of the keys of the input and output {@code PCollection}s * @param the type of the values of the input {@code PCollection} and the elements of the * {@code Iterable}s in the output {@code PCollection} */ - public static StateMultiplexingGroupByKey create(boolean fewKeys) { - return new StateMultiplexingGroupByKey<>(fewKeys); - } - - ///////////////////////////////////////////////////////////////////////////// - - public static void applicableTo(PCollection input) { - WindowingStrategy windowingStrategy = input.getWindowingStrategy(); - // Verify that the input PCollection is bounded, or that there is windowing/triggering being - // used. Without this, the watermark (at end of global window) will never be reached. - if (windowingStrategy.getWindowFn() instanceof GlobalWindows - && windowingStrategy.getTrigger() instanceof DefaultTrigger - && input.isBounded() != IsBounded.BOUNDED) { - throw new IllegalStateException( - "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without a" - + " trigger. Use a Window.into or Window.triggering transform prior to GroupByKey."); - } - - // Validate that the trigger does not finish before garbage collection time - if (!triggerIsSafe(windowingStrategy)) { - throw new IllegalArgumentException( - String.format( - "Unsafe trigger '%s' may lose data, did you mean to wrap it in" - + "`Repeatedly.forever(...)`?%nSee " - + "https://s.apache.org/finishing-triggers-drop-data " - + "for details.", - windowingStrategy.getTrigger())); - } - } - - @Override - public void validate( - @Nullable PipelineOptions options, - Map, PCollection> inputs, - Map, PCollection> outputs) { - PCollection input = Iterables.getOnlyElement(inputs.values()); - KvCoder inputCoder = getInputKvCoder(input.getCoder()); - - // Ensure that the output coder key and value types aren't different. - Coder outputCoder = Iterables.getOnlyElement(outputs.values()).getCoder(); - KvCoder expectedOutputCoder = getOutputKvCoder(inputCoder); - if (!expectedOutputCoder.equals(outputCoder)) { - throw new IllegalStateException( - String.format( - "the GroupByKey requires its output coder to be %s but found %s.", - expectedOutputCoder, outputCoder)); - } - } - - // Note that Never trigger finishes *at* GC time so it is OK, and - // AfterWatermark.fromEndOfWindow() finishes at end-of-window time so it is - // OK if there is no allowed lateness. - private static boolean triggerIsSafe(WindowingStrategy windowingStrategy) { - if (!windowingStrategy.getTrigger().mayFinish()) { - return true; - } - - if (windowingStrategy.getTrigger() instanceof NeverTrigger) { - return true; - } - - if (windowingStrategy.getTrigger() instanceof FromEndOfWindow - && windowingStrategy.getAllowedLateness().getMillis() == 0) { - return true; - } - - if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate - && windowingStrategy.getAllowedLateness().getMillis() == 0) { - return true; - } - - if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate - && ((AfterWatermarkEarlyAndLate) windowingStrategy.getTrigger()).getLateTrigger() != null) { - return true; - } - - return false; + public static StateMultiplexingGroupByKey create( + DataflowPipelineOptions options, boolean fewKeys) { + return new StateMultiplexingGroupByKey<>(options, fewKeys); } @Override public PCollection>> expand(PCollection> input) { - applicableTo(input); + DataflowGroupByKey.applicableTo(input); // Verify that the input Coder> is a KvCoder, and that // the key coder is deterministic. - Coder keyCoder = getKeyCoder(input.getCoder()); - Coder valueCoder = getInputValueCoder(input.getCoder()); - KvCoder> outputKvCoder = getOutputKvCoder(input.getCoder()); - + Coder keyCoder = DataflowGroupByKey.getKeyCoder(input.getCoder()); try { keyCoder.verifyDeterministic(); } catch (NonDeterministicException e) { throw new IllegalStateException("the keyCoder of a GroupByKey must be deterministic", e); } - Preconditions.checkArgument(numShardingKeys > 0); - final TupleTag> largeKeys = new TupleTag>() {}; - final TupleTag> smallKeys = new TupleTag>() {}; + Coder valueCoder = DataflowGroupByKey.getInputValueCoder(input.getCoder()); + KvCoder> outputKvCoder = DataflowGroupByKey.getOutputKvCoder(input.getCoder()); + + Preconditions.checkArgument(numVirtualKeys > 0); + final TupleTag> largeKeys = new TupleTag>() { + }; + final TupleTag> smallKeys = new TupleTag>() { + }; WindowingStrategy originalWindowingStrategy = input.getWindowingStrategy(); + WindowFn originalWindowFn = originalWindowingStrategy.getWindowFn(); PCollectionTuple mapKeysToBytes = input.apply( "MapKeysToBytes", ParDo.of( new DoFn, KV>() { + transient ByteStringOutputStream byteStringOutputStream; + + @StartBundle + public void setup() { + byteStringOutputStream = new ByteStringOutputStream(); + } + @ProcessElement public void processElement(ProcessContext c) { KV kv = c.element(); - Output output = ByteString.newOutput(); try { - keyCoder.encode(kv.getKey(), output); + // clear output stream + byteStringOutputStream.toByteStringAndReset(); + keyCoder.encode(kv.getKey(), byteStringOutputStream); } catch (IOException e) { throw new RuntimeException(e); } - KV outputKV = KV.of(output.toByteString(), kv.getValue()); - if (outputKV.getKey().size() <= SMALL_KEY_BYTES_THRESHOLD) { + ByteString keyBytes = byteStringOutputStream.toByteStringAndReset(); + KV outputKV = KV.of(keyBytes, kv.getValue()); + if (keyBytes.size() <= smallKeyBytesThreshold) { c.output(smallKeys, outputKV); } else { c.output(largeKeys, outputKV); @@ -209,11 +155,12 @@ public void processElement(ProcessContext c) { }) .withOutputTags(largeKeys, TupleTagList.of(smallKeys))); + // Pass large keys as it is through DataflowGroupByKey PCollection>> largeKeyBranch = mapKeysToBytes .get(largeKeys) - .setCoder(KvCoder.of(KeyedWindow.ByteStringCoder.of(), valueCoder)) - .apply(DataflowGroupByKey.create()) + .setCoder(KvCoder.of(ByteStringCoder.of(), valueCoder)) + .apply(fewKeys ? DataflowGroupByKey.createWithFewKeys() : DataflowGroupByKey.create()) .apply( "DecodeKey", MapElements.via( @@ -229,30 +176,31 @@ public KV> apply(KV> kv) { })) .setCoder(outputKvCoder); - WindowFn windowFn = originalWindowingStrategy.getWindowFn(); + // Multiplex small keys over `numShardingKeys` virtual keys. + // Original user keys are sent as part of windows. + // After GroupByKey the original keys are restored from windows. PCollection>> smallKeyBranch = mapKeysToBytes .get(smallKeys) - .apply(Window.into(new KeyedWindowFn<>(windowFn))) + .apply(Window.into(new KeyedWindowFn<>(originalWindowFn))) .apply( - "MapKeys", + "MapKeysToVirtualKeys", MapElements.via( new SimpleFunction, KV>() { @Override public KV apply(KV value) { - return KV.of(value.getKey().hashCode() % numShardingKeys, value.getValue()); + return KV.of(value.getKey().hashCode() % numVirtualKeys, value.getValue()); } })) - .apply(DataflowGroupByKey.create()) + .apply(fewKeys ? DataflowGroupByKey.createWithFewKeys() : DataflowGroupByKey.create()) .apply( - "Restore Keys", + "RestoreOriginalKeys", ParDo.of( new DoFn>, KV>>() { @ProcessElement public void processElement(ProcessContext c, BoundedWindow w, PaneInfo pane) { ByteString key = ((KeyedWindow) w).getKey(); try { - // is it correct to use the pane from Keyed window here? c.outputWindowedValue( KV.of(keyCoder.decode(key.newInput()), c.element().getValue()), @@ -270,47 +218,14 @@ public void processElement(ProcessContext c, BoundedWindow w, PaneInfo pane) { .apply(Flatten.pCollections()); } - /** - * Returns the {@code Coder} of the input to this transform, which should be a {@code KvCoder}. - */ - @SuppressWarnings("unchecked") - static KvCoder getInputKvCoder(Coder inputCoder) { - if (!(inputCoder instanceof KvCoder)) { - throw new IllegalStateException("GroupByKey requires its input to use KvCoder"); - } - return (KvCoder) inputCoder; - } - - ///////////////////////////////////////////////////////////////////////////// - - /** - * Returns the {@code Coder} of the keys of the input to this transform, which is also used as the - * {@code Coder} of the keys of the output of this transform. - */ - public static Coder getKeyCoder(Coder> inputCoder) { - return StateMultiplexingGroupByKey.getInputKvCoder(inputCoder).getKeyCoder(); - } - - /** Returns the {@code Coder} of the values of the input to this transform. */ - public static Coder getInputValueCoder(Coder> inputCoder) { - return StateMultiplexingGroupByKey.getInputKvCoder(inputCoder).getValueCoder(); - } - - /** Returns the {@code Coder} of the {@code Iterable} values of the output of this transform. */ - static Coder> getOutputValueCoder(Coder> inputCoder) { - return IterableCoder.of(getInputValueCoder(inputCoder)); - } - - /** Returns the {@code Coder} of the output of this transform. */ - public static KvCoder> getOutputKvCoder(Coder> inputCoder) { - return KvCoder.of(getKeyCoder(inputCoder), getOutputValueCoder(inputCoder)); - } - @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); if (fewKeys) { builder.add(DisplayData.item("fewKeys", true).withLabel("Has Few Keys")); } + builder.add(DisplayData.item("numVirtualKeys", numVirtualKeys).withLabel("Num Virtual Keys")); + builder.add(DisplayData.item("smallKeyBytesThreshold", smallKeyBytesThreshold) + .withLabel("Small Key Bytes Threshold")); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ByteStringCoder.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/ByteStringCoder.java similarity index 97% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ByteStringCoder.java rename to runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/ByteStringCoder.java index b9b1b45902c8..0c1a80aa2ace 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ByteStringCoder.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/ByteStringCoder.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.runners.dataflow.worker; +package org.apache.beam.runners.dataflow.util; import java.io.IOException; import java.io.InputStream;