diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 676ceb495c21..a7b080e87841 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -119,7 +119,6 @@
 import org.apache.beam.sdk.transforms.Combine.GroupedValues;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -805,7 +804,7 @@ private List getOverrides(boolean streaming) {
         options, StateMultiplexingGroupByKey.EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING)) {
       overridesBuilder.add(
           PTransformOverride.of(
-              PTransformMatchers.classEqualTo(GroupByKey.class),
+              StateMultiplexingGroupByKeyTransformMatcher.getInstance(),
               new StateMultiplexingGroupByKeyOverrideFactory<>(options)));
     }
     // For update compatibility, always use a Read for Create in streaming mode.
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyTransformMatcher.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyTransformMatcher.java
new file mode 100644
index 000000000000..06953a0928d5
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyTransformMatcher.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow;
+
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformMatcher;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.values.PCollection;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link PTransformMatcher} that matches {@link GroupByKey} applications whose main inputs all
+ * use a {@link KvCoder} whose key coder is not a {@link VoidCoder}.
+ */
+class StateMultiplexingGroupByKeyTransformMatcher implements PTransformMatcher {
+
+  private static final Logger LOG =
+      LoggerFactory.getLogger(StateMultiplexingGroupByKeyTransformMatcher.class);
+  private static final StateMultiplexingGroupByKeyTransformMatcher INSTANCE =
+      new StateMultiplexingGroupByKeyTransformMatcher();
+
+  /**
+   * Reports whether {@code application} is a {@link GroupByKey} eligible for matching.
+   *
+   * @param application the applied transform to inspect
+   * @return {@code true} if the transform is a {@link GroupByKey} and every main input has a
+   *     {@link KvCoder} with a non-{@link VoidCoder} key coder; {@code false} otherwise
+   */
+  @Override
+  public boolean matches(AppliedPTransform<?, ?, ?> application) {
+    if (!(application.getTransform() instanceof GroupByKey)) {
+      LOG.debug("{} returning false", application.getFullName());
+      return false;
+    }
+    for (PCollection<?> input : application.getMainInputs().values()) {
+      Coder<?> coder = input.getCoder();
+      LOG.debug("{} {}", application.getFullName(), coder);
+      if (!(coder instanceof KvCoder)) {
+        return false;
+      }
+      // Don't enable multiplexing on Void keys.
+      if (((KvCoder<?, ?>) coder).getKeyCoder() instanceof VoidCoder) {
+        return false;
+      }
+    }
+    LOG.debug("{} returning true", application.getFullName());
+    return true;
+  }
+
+  /** Returns the shared instance of this matcher. */
+  public static StateMultiplexingGroupByKeyTransformMatcher getInstance() {
+    return INSTANCE;
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java
index 4df6a745c783..831d9493af82 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java
@@ -201,6 +201,8 @@ public KV> apply(KV> kv) {
                 new SimpleFunction, KV>() {
                   @Override
                   public KV apply(KV value) {
+                    // TODO(review): hashCode() can be negative, so this remainder can be negative
+                    // as well. Should we use a different hash code?
                     return KV.of(value.getKey().hashCode() % numVirtualKeys, value.getValue());
                   }
                 }))