From cfc337902c6fce8fe245e0581407e025135d35b4 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Thu, 29 Aug 2024 16:03:47 +0200 Subject: [PATCH] Remove expensive shuffle of read data in KafkaIO when using sdf and commit offsets (#31682) --- .../beam/sdk/io/kafka/KafkaCommitOffset.java | 83 ++++++++- .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 80 ++++++--- .../sdk/io/kafka/KafkaCommitOffsetTest.java | 169 +++++++++++++++--- 3 files changed, 278 insertions(+), 54 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java index 3816ee0bb855..fa692d3aaf42 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; @@ -40,7 +41,9 @@ import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.joda.time.Duration; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,9 +52,12 @@ public class KafkaCommitOffset extends PTransform< PCollection>>, PCollection> { private final KafkaIO.ReadSourceDescriptors readSourceDescriptors; + private final boolean use259implementation; - KafkaCommitOffset(KafkaIO.ReadSourceDescriptors readSourceDescriptors) { + KafkaCommitOffset( + KafkaIO.ReadSourceDescriptors readSourceDescriptors, boolean use259implementation) { this.readSourceDescriptors = readSourceDescriptors; + this.use259implementation = use259implementation; } static class CommitOffsetDoFn extends DoFn, Void> { @@ -90,7 +96,7 @@ private Map overrideBootstrapServersConfig( || description.getBootStrapServers() != null); Map config = new HashMap<>(currentConfig); if (description.getBootStrapServers() != null - && description.getBootStrapServers().size() > 0) { + && !description.getBootStrapServers().isEmpty()) { config.put( ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, String.join(",", description.getBootStrapServers())); @@ -99,13 +105,78 @@ private Map overrideBootstrapServersConfig( } } + private static final class MaxOffsetFn + extends DoFn>, KV> { + private static class OffsetAndTimestamp { + OffsetAndTimestamp(long offset, Instant timestamp) { + this.offset = offset; + this.timestamp = timestamp; + } + + void merge(long offset, Instant timestamp) { + if (this.offset < offset) { + this.offset = offset; + this.timestamp = timestamp; + } + } + + long offset; + Instant timestamp; + } + + private transient @MonotonicNonNull Map maxObserved; + + @StartBundle + public void startBundle() { + if (maxObserved == null) { + maxObserved = new HashMap<>(); + } else { + maxObserved.clear(); + } + } + + @RequiresStableInput + @ProcessElement + @SuppressWarnings("nullness") // startBundle guaranteed to initialize + public void processElement( + @Element KV> element, + @Timestamp Instant timestamp) { + maxObserved.compute( + element.getKey(), + (k, v) -> 
{ + long offset = element.getValue().getOffset(); + if (v == null) { + return new OffsetAndTimestamp(offset, timestamp); + } + v.merge(offset, timestamp); + return v; + }); + } + + @FinishBundle + @SuppressWarnings("nullness") // startBundle guaranteed to initialize + public void finishBundle(FinishBundleContext context) { + maxObserved.forEach( + (k, v) -> context.output(KV.of(k, v.offset), v.timestamp, GlobalWindow.INSTANCE)); + } + } + @Override public PCollection expand(PCollection>> input) { try { - return input - .apply( - MapElements.into(new TypeDescriptor>() {}) - .via(element -> KV.of(element.getKey(), element.getValue().getOffset()))) + PCollection> offsets; + if (use259implementation) { + offsets = + input.apply( + MapElements.into(new TypeDescriptor>() {}) + .via(element -> KV.of(element.getKey(), element.getValue().getOffset()))); + } else { + // Reduce the amount of data to combine by calculating a max within the generally dense + // bundles of reading + // from a Kafka partition. + offsets = input.apply(ParDo.of(new MaxOffsetFn<>())); + } + return offsets .setCoder( KvCoder.of( input diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index c1526d5382b4..1fd3e3e044ef 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -31,6 +31,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -60,12 +61,14 @@ import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverride; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; import org.apache.beam.sdk.schemas.transforms.Convert; @@ -103,6 +106,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Comparators; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.kafka.clients.consumer.Consumer; @@ -2136,9 +2140,9 @@ public void populateDisplayData(DisplayData.Builder builder) { * the transform will expand to: * *
<pre>{@code
-   * PCollection<KafkaSourceDescriptor> --> ParDo(ReadFromKafkaDoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord>>) --> Reshuffle() --> Map(output KafkaRecord)
-   *                                                                                                                                         |
-   *                                                                                                                                         --> KafkaCommitOffset
+   * PCollection<KafkaSourceDescriptor> --> ParDo(ReadFromKafkaDoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord>>) --> Map(output KafkaRecord)
+   *                                                                                                          |
+   *                                                                                                          --> KafkaCommitOffset
    * }</pre>
* * . Note that this expansion is not supported when running with x-lang on Dataflow. @@ -2682,33 +2686,61 @@ public PCollection> expand(PCollection .getSchemaRegistry() .getSchemaCoder(KafkaSourceDescriptor.class), recordCoder)); - if (isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute()) { - outputWithDescriptor = - outputWithDescriptor - .apply(Reshuffle.viaRandomKey()) - .setCoder( - KvCoder.of( - input - .getPipeline() - .getSchemaRegistry() - .getSchemaCoder(KafkaSourceDescriptor.class), - recordCoder)); - - PCollection unused = outputWithDescriptor.apply(new KafkaCommitOffset(this)); - unused.setCoder(VoidCoder.of()); + + boolean applyCommitOffsets = + isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute(); + if (!applyCommitOffsets) { + return outputWithDescriptor + .apply(MapElements.into(new TypeDescriptor>() {}).via(KV::getValue)) + .setCoder(recordCoder); + } + + // Add transform for committing offsets to Kafka with consistency with beam pipeline data + // processing. + String requestedVersionString = + input + .getPipeline() + .getOptions() + .as(StreamingOptions.class) + .getUpdateCompatibilityVersion(); + if (requestedVersionString != null) { + List requestedVersion = Arrays.asList(requestedVersionString.split("\\.")); + List targetVersion = Arrays.asList("2", "60", "0"); + + if (Comparators.lexicographical(Comparator.naturalOrder()) + .compare(requestedVersion, targetVersion) + < 0) { + return expand259Commits( + outputWithDescriptor, recordCoder, input.getPipeline().getSchemaRegistry()); + } } - PCollection> output = - outputWithDescriptor - .apply( - MapElements.into(new TypeDescriptor>() {}) - .via(element -> element.getValue())) - .setCoder(recordCoder); - return output; + outputWithDescriptor.apply(new KafkaCommitOffset<>(this, false)).setCoder(VoidCoder.of()); + return outputWithDescriptor + .apply(MapElements.into(new TypeDescriptor>() {}).via(KV::getValue)) + .setCoder(recordCoder); } catch (NoSuchSchemaException e) { throw new RuntimeException(e.getMessage()); } } + private PCollection> expand259Commits( + PCollection>> outputWithDescriptor, + Coder> recordCoder, + SchemaRegistry schemaRegistry) + throws NoSuchSchemaException { + // Reshuffles the data and then branches off applying commit offsets. + outputWithDescriptor = + outputWithDescriptor + .apply(Reshuffle.viaRandomKey()) + .setCoder( + KvCoder.of( + schemaRegistry.getSchemaCoder(KafkaSourceDescriptor.class), recordCoder)); + outputWithDescriptor.apply(new KafkaCommitOffset<>(this, true)).setCoder(VoidCoder.of()); + return outputWithDescriptor + .apply(MapElements.into(new TypeDescriptor>() {}).via(KV::getValue)) + .setCoder(recordCoder); + } + private Coder getKeyCoder(CoderRegistry coderRegistry) { return (getKeyCoder() != null) ? 
getKeyCoder() diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java index f258328c1092..c16e25510ab8 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java @@ -17,13 +17,25 @@ */ package org.apache.beam.sdk.io.kafka; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.kafka.KafkaCommitOffset.CommitOffsetDoFn; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.testing.ExpectedLogs; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.MockConsumer; @@ -33,14 +45,14 @@ import org.junit.Assert; import org.junit.Rule; import org.junit.Test; +import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Unit tests for {@link KafkaCommitOffset}. 
*/ @RunWith(JUnit4.class) public class KafkaCommitOffsetTest { - - private final TopicPartition partition = new TopicPartition("topic", 0); + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(CommitOffsetDoFn.class); private final KafkaCommitOffsetMockConsumer consumer = @@ -48,29 +60,132 @@ public class KafkaCommitOffsetTest { private final KafkaCommitOffsetMockConsumer errorConsumer = new KafkaCommitOffsetMockConsumer(null, true); + private static final KafkaCommitOffsetMockConsumer COMPOSITE_CONSUMER = + new KafkaCommitOffsetMockConsumer(null, false); + private static final KafkaCommitOffsetMockConsumer COMPOSITE_CONSUMER_BOOTSTRAP = + new KafkaCommitOffsetMockConsumer(null, false); + + private static final Map configMap = + ImmutableMap.of(ConsumerConfig.GROUP_ID_CONFIG, "group1"); + @Test public void testCommitOffsetDoFn() { - Map configMap = new HashMap<>(); - configMap.put(ConsumerConfig.GROUP_ID_CONFIG, "group1"); - ReadSourceDescriptors descriptors = ReadSourceDescriptors.read() .withBootstrapServers("bootstrap_server") .withConsumerConfigUpdates(configMap) .withConsumerFactoryFn( - new SerializableFunction, Consumer>() { - @Override - public Consumer apply(Map input) { - Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG)); - return consumer; - } - }); + (SerializableFunction, Consumer>) + input -> { + Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG)); + return consumer; + }); CommitOffsetDoFn doFn = new CommitOffsetDoFn(descriptors); + final TopicPartition topicPartition1 = new TopicPartition("topic", 0); + final TopicPartition topicPartition2 = new TopicPartition("other_topic", 1); doFn.processElement( - KV.of(KafkaSourceDescriptor.of(partition, null, null, null, null, null), 1L)); + KV.of(KafkaSourceDescriptor.of(topicPartition1, null, null, null, null, null), 2L)); + doFn.processElement( + KV.of(KafkaSourceDescriptor.of(topicPartition2, null, null, null, null, null), 200L)); - Assert.assertEquals(2L, consumer.commit.get(partition).offset()); + Assert.assertEquals(3L, (long) consumer.commitOffsets.get(topicPartition1)); + Assert.assertEquals(201L, (long) consumer.commitOffsets.get(topicPartition2)); + + doFn.processElement( + KV.of(KafkaSourceDescriptor.of(topicPartition1, null, null, null, null, null), 3L)); + Assert.assertEquals(4L, (long) consumer.commitOffsets.get(topicPartition1)); + } + + KafkaRecord makeTestRecord(int i) { + return new KafkaRecord<>( + "", 0, i, 0, KafkaTimestampType.NO_TIMESTAMP_TYPE, null, KV.of("key" + i, "value" + i)); + } + + @Test + @Category(NeedsRunner.class) + public void testKafkaOffsetComposite() throws CannotProvideCoderException { + testKafkaOffsetHelper(false); + } + + @Test + @Category(NeedsRunner.class) + public void testKafkaOffsetCompositeLegacy() throws CannotProvideCoderException { + testKafkaOffsetHelper(true); + } + + private void testKafkaOffsetHelper(boolean use259Implementation) + throws CannotProvideCoderException { + COMPOSITE_CONSUMER.commitOffsets.clear(); + COMPOSITE_CONSUMER_BOOTSTRAP.commitOffsets.clear(); + + ReadSourceDescriptors descriptors = + ReadSourceDescriptors.read() + .withBootstrapServers("bootstrap_server") + .withConsumerConfigUpdates(configMap) + .withConsumerFactoryFn( + (SerializableFunction, Consumer>) + input -> { + Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG)); + if (input + .get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) + .equals("bootstrap_server")) 
{ + return COMPOSITE_CONSUMER; + } + Assert.assertEquals( + "bootstrap_overridden", + input.get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + return COMPOSITE_CONSUMER_BOOTSTRAP; + }); + + String topic0 = "topic0_" + (use259Implementation ? "259" : "260"); + String topic1 = "topic1_" + (use259Implementation ? "259" : "260"); + KafkaSourceDescriptor d1 = + KafkaSourceDescriptor.of(new TopicPartition(topic0, 0), null, null, null, null, null); + KafkaSourceDescriptor d2 = + KafkaSourceDescriptor.of(new TopicPartition(topic0, 1), null, null, null, null, null); + KafkaSourceDescriptor d3 = + KafkaSourceDescriptor.of( + new TopicPartition(topic1, 0), + null, + null, + null, + null, + ImmutableList.of("bootstrap_overridden")); + KafkaSourceDescriptor d4 = + KafkaSourceDescriptor.of( + new TopicPartition(topic1, 1), + null, + null, + null, + null, + ImmutableList.of("bootstrap_overridden")); + List>> elements = new ArrayList<>(); + elements.add(KV.of(d1, makeTestRecord(10))); + + elements.add(KV.of(d2, makeTestRecord(20))); + elements.add(KV.of(d3, makeTestRecord(30))); + elements.add(KV.of(d4, makeTestRecord(40))); + elements.add(KV.of(d2, makeTestRecord(10))); + elements.add(KV.of(d1, makeTestRecord(100))); + PCollection>> input = + pipeline.apply( + Create.of(elements) + .withCoder( + KvCoder.of( + pipeline.getCoderRegistry().getCoder(KafkaSourceDescriptor.class), + KafkaRecordCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())))); + input.apply(new KafkaCommitOffset<>(descriptors, use259Implementation)); + pipeline.run(); + + HashMap expectedOffsets = new HashMap<>(); + expectedOffsets.put(d1.getTopicPartition(), 101L); + expectedOffsets.put(d2.getTopicPartition(), 21L); + Assert.assertEquals(expectedOffsets, COMPOSITE_CONSUMER.commitOffsets); + expectedOffsets.clear(); + expectedOffsets.put(d3.getTopicPartition(), 31L); + expectedOffsets.put(d4.getTopicPartition(), 41L); + Assert.assertEquals(expectedOffsets, COMPOSITE_CONSUMER_BOOTSTRAP.commitOffsets); } @Test @@ -83,25 +198,25 @@ public void testCommitOffsetError() { .withBootstrapServers("bootstrap_server") .withConsumerConfigUpdates(configMap) .withConsumerFactoryFn( - new SerializableFunction, Consumer>() { - @Override - public Consumer apply(Map input) { - Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG)); - return errorConsumer; - } - }); + (SerializableFunction, Consumer>) + input -> { + Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG)); + return errorConsumer; + }); CommitOffsetDoFn doFn = new CommitOffsetDoFn(descriptors); + final TopicPartition partition = new TopicPartition("topic", 0); doFn.processElement( KV.of(KafkaSourceDescriptor.of(partition, null, null, null, null, null), 1L)); expectedLogs.verifyWarn("Getting exception when committing offset: Test Exception"); + Assert.assertTrue(errorConsumer.commitOffsets.isEmpty()); } private static class KafkaCommitOffsetMockConsumer extends MockConsumer { - public Map commit; - private boolean throwException; + public final HashMap commitOffsets = new HashMap<>(); + private final boolean throwException; public KafkaCommitOffsetMockConsumer( OffsetResetStrategy offsetResetStrategy, boolean throwException) { @@ -115,8 +230,14 @@ public synchronized void commitSync(Map offse throw new RuntimeException("Test Exception"); } else { commitAsync(offsets, null); - commit = offsets; + offsets.forEach( + (topic, offsetMetadata) -> commitOffsets.put(topic, offsetMetadata.offset())); } } + + @Override + public synchronized void close(long timeout, 
TimeUnit unit) {
+      // Ignore closing since we're using a single consumer.
+    }
   }
 }
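For context, the offset-commit read path touched by this patch is the one enabled with commitOffsetsInFinalize() on the SDF-based KafkaIO.read() (or commitOffsets() on readSourceDescriptors()). Below is a minimal, illustrative sketch of such a pipeline; the class name, bootstrap server, topic, and group id are placeholders and are not taken from this patch.

import java.util.Collections;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.io.kafka.KafkaRecord;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.PCollection;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.serialization.StringDeserializer;

public class CommitOffsetsReadSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // commitOffsetsInFinalize() makes the SDF expansion branch into KafkaCommitOffset.
    // With this change that branch consumes the read output directly instead of a
    // reshuffled copy of it, relying on @RequiresStableInput in MaxOffsetFn for consistency.
    PCollection<KafkaRecord<String, String>> records =
        pipeline.apply(
            KafkaIO.<String, String>read()
                .withBootstrapServers("broker1:9092") // placeholder
                .withTopic("my_topic") // placeholder
                .withKeyDeserializer(StringDeserializer.class)
                .withValueDeserializer(StringDeserializer.class)
                .withConsumerConfigUpdates(
                    Collections.<String, Object>singletonMap(
                        ConsumerConfig.GROUP_ID_CONFIG, "group1"))
                .commitOffsetsInFinalize());

    pipeline.run();
  }
}

Pipelines that still need the pre-2.60.0 expansion (the Reshuffle-based path preserved here as expand259Commits) can keep it by setting the update compatibility option, e.g. --updateCompatibilityVersion=2.59.0.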