Remove expensive shuffle of read data in KafkaIO when using sdf and commit offsets (apache#31682)
scwhittle authored and reeba212 committed Dec 4, 2024
1 parent ccb2948 commit cfc3379
Showing 3 changed files with 278 additions and 54 deletions.
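
For context, the code path this commit optimizes is an SDF-based KafkaIO read with offset commits enabled. Below is a minimal sketch of such a pipeline; broker, topic, and group names are illustrative, commitOffsetsInFinalize() requires a consumer group id, and on runners that use the SDF-based read implementation this read should expand through the KafkaCommitOffset transform changed here.

import java.util.Collections;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.io.kafka.KafkaRecord;
import org.apache.beam.sdk.values.PCollection;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.serialization.LongDeserializer;
import org.apache.kafka.common.serialization.StringDeserializer;

public class KafkaCommitOffsetsReadExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    PCollection<KafkaRecord<Long, String>> records =
        pipeline.apply(
            KafkaIO.<Long, String>read()
                .withBootstrapServers("broker-1:9092") // illustrative broker
                .withTopic("events") // illustrative topic
                .withKeyDeserializer(LongDeserializer.class)
                .withValueDeserializer(StringDeserializer.class)
                // A group id is required for committing offsets back to Kafka.
                .withConsumerConfigUpdates(
                    Collections.<String, Object>singletonMap(
                        ConsumerConfig.GROUP_ID_CONFIG, "my-group"))
                // Commit consumed offsets once elements are durably processed; with the
                // SDF-based read this enables the KafkaCommitOffset branch.
                .commitOffsetsInFinalize());

    // ... downstream processing of `records` ...
    pipeline.run();
  }
}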
@@ -33,14 +33,17 @@
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -49,9 +52,12 @@ public class KafkaCommitOffset<K, V>
extends PTransform<
PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>>, PCollection<Void>> {
private final KafkaIO.ReadSourceDescriptors<K, V> readSourceDescriptors;
private final boolean use259implementation;

KafkaCommitOffset(KafkaIO.ReadSourceDescriptors<K, V> readSourceDescriptors) {
KafkaCommitOffset(
KafkaIO.ReadSourceDescriptors<K, V> readSourceDescriptors, boolean use259implementation) {
this.readSourceDescriptors = readSourceDescriptors;
this.use259implementation = use259implementation;
}

static class CommitOffsetDoFn extends DoFn<KV<KafkaSourceDescriptor, Long>, Void> {
@@ -90,7 +96,7 @@ private Map<String, Object> overrideBootstrapServersConfig(
|| description.getBootStrapServers() != null);
Map<String, Object> config = new HashMap<>(currentConfig);
if (description.getBootStrapServers() != null
&& description.getBootStrapServers().size() > 0) {
&& !description.getBootStrapServers().isEmpty()) {
config.put(
ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG,
String.join(",", description.getBootStrapServers()));
@@ -99,13 +105,78 @@ private Map<String, Object> overrideBootstrapServersConfig(
}
}

private static final class MaxOffsetFn<K, V>
extends DoFn<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>, KV<KafkaSourceDescriptor, Long>> {
private static class OffsetAndTimestamp {
OffsetAndTimestamp(long offset, Instant timestamp) {
this.offset = offset;
this.timestamp = timestamp;
}

void merge(long offset, Instant timestamp) {
if (this.offset < offset) {
this.offset = offset;
this.timestamp = timestamp;
}
}

long offset;
Instant timestamp;
}

private transient @MonotonicNonNull Map<KafkaSourceDescriptor, OffsetAndTimestamp> maxObserved;

@StartBundle
public void startBundle() {
if (maxObserved == null) {
maxObserved = new HashMap<>();
} else {
maxObserved.clear();
}
}

@RequiresStableInput
@ProcessElement
@SuppressWarnings("nullness") // startBundle guaranteed to initialize
public void processElement(
@Element KV<KafkaSourceDescriptor, KafkaRecord<K, V>> element,
@Timestamp Instant timestamp) {
maxObserved.compute(
element.getKey(),
(k, v) -> {
long offset = element.getValue().getOffset();
if (v == null) {
return new OffsetAndTimestamp(offset, timestamp);
}
v.merge(offset, timestamp);
return v;
});
}

@FinishBundle
@SuppressWarnings("nullness") // startBundle guaranteed to initialize
public void finishBundle(FinishBundleContext context) {
maxObserved.forEach(
(k, v) -> context.output(KV.of(k, v.offset), v.timestamp, GlobalWindow.INSTANCE));
}
}

@Override
public PCollection<Void> expand(PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> input) {
try {
return input
.apply(
MapElements.into(new TypeDescriptor<KV<KafkaSourceDescriptor, Long>>() {})
.via(element -> KV.of(element.getKey(), element.getValue().getOffset())))
PCollection<KV<KafkaSourceDescriptor, Long>> offsets;
if (use259implementation) {
offsets =
input.apply(
MapElements.into(new TypeDescriptor<KV<KafkaSourceDescriptor, Long>>() {})
.via(element -> KV.of(element.getKey(), element.getValue().getOffset())));
} else {
// Reduce the amount of data to combine by calculating a max within the generally dense
// bundles of reading
// from a Kafka partition.
offsets = input.apply(ParDo.of(new MaxOffsetFn<>()));
}
return offsets
.setCoder(
KvCoder.of(
input
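
The MaxOffsetFn added above pre-reduces each bundle to one (source descriptor, max offset) pair, so far less data reaches the per-key combine on the commit branch. As a rough sketch of the kind of windowed per-key max such output feeds: the actual windowing and commit DoFn inside KafkaCommitOffset.expand are not part of this diff, so the 5-minute window and the String key below are stand-ins, not the real implementation.

import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;

class CommitBranchSketch {
  // Input: per-bundle maxima keyed by a descriptor-like key (String used here for brevity).
  static PCollection<KV<String, Long>> latestOffsets(
      PCollection<KV<String, Long>> perBundleMaxima) {
    return perBundleMaxima
        // Window so a per-key max is emitted periodically for committing.
        .apply(Window.<KV<String, Long>>into(FixedWindows.of(Duration.standardMinutes(5))))
        // Keep only the highest offset seen per key in each window; this is the value a
        // commit DoFn would then write back to Kafka.
        .apply(Max.<String>longsPerKey());
  }
}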
@@ -31,6 +31,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -60,12 +61,14 @@
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverride;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import org.apache.beam.sdk.schemas.transforms.Convert;
@@ -103,6 +106,7 @@
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Comparators;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.kafka.clients.consumer.Consumer;
@@ -2136,9 +2140,9 @@ public void populateDisplayData(DisplayData.Builder builder) {
* the transform will expand to:
*
* <pre>{@code
* PCollection<KafkaSourceDescriptor> --> ParDo(ReadFromKafkaDoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord>>) --> Reshuffle() --> Map(output KafkaRecord)
* |
* --> KafkaCommitOffset
* PCollection<KafkaSourceDescriptor> --> ParDo(ReadFromKafkaDoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord>>) --> Map(output KafkaRecord)
* |
* --> KafkaCommitOffset
* }</pre>
*
* . Note that this expansion is not supported when running with x-lang on Dataflow.
@@ -2682,33 +2686,61 @@ public PCollection<KafkaRecord<K, V>> expand(PCollection<KafkaSourceDescriptor>
.getSchemaRegistry()
.getSchemaCoder(KafkaSourceDescriptor.class),
recordCoder));
if (isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute()) {
outputWithDescriptor =
outputWithDescriptor
.apply(Reshuffle.viaRandomKey())
.setCoder(
KvCoder.of(
input
.getPipeline()
.getSchemaRegistry()
.getSchemaCoder(KafkaSourceDescriptor.class),
recordCoder));

PCollection<Void> unused = outputWithDescriptor.apply(new KafkaCommitOffset<K, V>(this));
unused.setCoder(VoidCoder.of());

boolean applyCommitOffsets =
isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute();
if (!applyCommitOffsets) {
return outputWithDescriptor
.apply(MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {}).via(KV::getValue))
.setCoder(recordCoder);
}

// Add transform for committing offsets to Kafka with consistency with beam pipeline data
// processing.
String requestedVersionString =
input
.getPipeline()
.getOptions()
.as(StreamingOptions.class)
.getUpdateCompatibilityVersion();
if (requestedVersionString != null) {
List<String> requestedVersion = Arrays.asList(requestedVersionString.split("\\."));
List<String> targetVersion = Arrays.asList("2", "60", "0");

if (Comparators.lexicographical(Comparator.<String>naturalOrder())
.compare(requestedVersion, targetVersion)
< 0) {
return expand259Commits(
outputWithDescriptor, recordCoder, input.getPipeline().getSchemaRegistry());
}
}
PCollection<KafkaRecord<K, V>> output =
outputWithDescriptor
.apply(
MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {})
.via(element -> element.getValue()))
.setCoder(recordCoder);
return output;
outputWithDescriptor.apply(new KafkaCommitOffset<>(this, false)).setCoder(VoidCoder.of());
return outputWithDescriptor
.apply(MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {}).via(KV::getValue))
.setCoder(recordCoder);
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e.getMessage());
}
}

private PCollection<KafkaRecord<K, V>> expand259Commits(
PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> outputWithDescriptor,
Coder<KafkaRecord<K, V>> recordCoder,
SchemaRegistry schemaRegistry)
throws NoSuchSchemaException {
// Reshuffles the data and then branches off applying commit offsets.
outputWithDescriptor =
outputWithDescriptor
.apply(Reshuffle.viaRandomKey())
.setCoder(
KvCoder.of(
schemaRegistry.getSchemaCoder(KafkaSourceDescriptor.class), recordCoder));
outputWithDescriptor.apply(new KafkaCommitOffset<>(this, true)).setCoder(VoidCoder.of());
return outputWithDescriptor
.apply(MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {}).via(KV::getValue))
.setCoder(recordCoder);
}

private Coder<K> getKeyCoder(CoderRegistry coderRegistry) {
return (getKeyCoder() != null)
? getKeyCoder()
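
The only remaining use of the reshuffle-based expansion is update compatibility: expand() falls back to expand259Commits when the pipeline requests compatibility with a release before 2.60.0 via the StreamingOptions updateCompatibilityVersion option. A small self-contained sketch of that version comparison, using plain Guava's Comparators rather than Beam's vendored copy, with 2.59.0 as an illustrative requested version:

import com.google.common.collect.Comparators;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class UpdateCompatibilityCheck {
  public static void main(String[] args) {
    // Illustrative value of StreamingOptions.getUpdateCompatibilityVersion().
    List<String> requested = Arrays.asList("2.59.0".split("\\."));
    List<String> target = Arrays.asList("2", "60", "0");

    // Component-wise lexicographic comparison of the dotted version parts,
    // mirroring the check in KafkaIO's expand().
    int cmp =
        Comparators.lexicographical(Comparator.<String>naturalOrder())
            .compare(requested, target);

    // cmp < 0 here: a pipeline pinned to 2.59.x keeps the reshuffle-based
    // expand259Commits expansion; otherwise the shuffle-free expansion is used.
    System.out.println(cmp < 0 ? "use expand259Commits" : "use shuffle-free expansion");
  }
}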
