From 11fde0ed595dc77eac3d6b5885ac892bea73d40b Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Wed, 20 Nov 2024 13:25:45 +0100 Subject: [PATCH] Decouple consumer threads from harness threads --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 744 +++++++++++++----- .../sdk/io/kafka/ReadFromKafkaDoFnTest.java | 30 +- 2 files changed, 570 insertions(+), 204 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 1cf4aad34e4e..58c6cfd20fdf 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -17,19 +17,26 @@ */ package org.apache.beam.sdk.io.kafka; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; -import java.math.BigDecimal; -import java.math.MathContext; +import java.time.Duration; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicMarkableReference; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; +import java.util.stream.Collectors; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.coders.Coder; @@ -56,25 +63,37 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.AbstractExecutionThreadService; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.FluentFuture; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.FutureCallback; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Futures; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListenableFuture; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListenableFutureTask; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Service; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; -import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigDef; import org.apache.kafka.common.errors.SerializationException; import org.apache.kafka.common.serialization.Deserializer; +import org.checkerframework.checker.nullness.qual.EnsuresNonNull; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.RequiresNonNull; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -189,7 +208,6 @@ private ReadFromKafkaDoFn( ReadSourceDescriptors transform, TupleTag>> recordTag) { this.consumerConfig = transform.getConsumerConfig(); - this.offsetConsumerConfig = transform.getOffsetConsumerConfig(); this.keyDeserializerProvider = Preconditions.checkArgumentNotNull(transform.getKeyDeserializerProvider()); this.valueDeserializerProvider = @@ -215,20 +233,345 @@ private ReadFromKafkaDoFn( * must run clean up tasks when {@link #teardown()} is called. */ private static final class SharedStateHolder { - - private static final Map> - OFFSET_ESTIMATOR_CACHE = new ConcurrentHashMap<>(); private static final Map> AVG_RECORD_SIZE_CACHE = new ConcurrentHashMap<>(); + private static final Map< + Long, LoadingCache>, ConsumerExecutionContext>> + CONSUMER_EXECUTION_CONTEXT_CACHE = new ConcurrentHashMap<>(); + } + + static final class TopicPartitionPollState implements AutoCloseable { + private final TopicPartition topicPartition; + private final OffsetRange offsetRange; + private final AtomicMarkableReference>>> + recordsFutureReference; + + TopicPartitionPollState(final TopicPartition topicPartition, final OffsetRange offsetRange) { + this.topicPartition = topicPartition; + this.offsetRange = offsetRange; + this.recordsFutureReference = new AtomicMarkableReference<>(SettableFuture.create(), false); + } + + TopicPartition getTopicPartition() { + return this.topicPartition; + } + + OffsetRange getOffsetRange() { + return this.offsetRange; + } + + Future>> getRecords() { + final boolean[] isMarked = new boolean[1]; + final SettableFuture>> future = + this.recordsFutureReference.get(isMarked); + // Add a callback if the future wasn't cancelled. + if (!isMarked[0] && !future.isCancelled()) { + // At this point, multiple listeners may be registered and executed if getRecords is called + // more than once before the future is completed. + // The completed future's state may be observed by all such callers, but only one of the + // callback executions will replace the completed future. 
+ // Note that it is only safe to have a single consumer call this method. + Futures.addCallback( + future, + new FutureCallback>>() { + @Override + public void onSuccess(List> result) { + final SettableFuture>> nextFuture = + SettableFuture.create(); + TopicPartitionPollState.this.recordsFutureReference.compareAndSet( + future, nextFuture, false, false); + } + + @Override + public void onFailure(Throwable t) { + // Completion due to an exception is a terminal state and will be kept observable + // until close is called. + } + }, + MoreExecutors.directExecutor()); + } + return future; + } + + boolean setRecords(ListenableFuture>> recordsFuture) { + final boolean[] isMarked = new boolean[1]; + final SettableFuture>> future = + this.recordsFutureReference.get(isMarked); + return !isMarked[0] && future.setFuture(recordsFuture); + } + + boolean hasCapacity() { + final boolean[] isMarked = new boolean[1]; + final SettableFuture>> future = + this.recordsFutureReference.get(isMarked); + return !isMarked[0] && !future.isCancelled() && !future.isDone(); + } + + boolean isClosed() { + return this.recordsFutureReference.isMarked(); + } + + @Override + public void close() { + final boolean[] isMarked = new boolean[1]; + SettableFuture>> expectedFuture = + this.recordsFutureReference.get(isMarked); + + // Set the mark to true for the expected reference or repeatedly refresh both if the reference + // was stale. + // It doesn't matter if the current future is already cancelled, the mark is used to signal + // that close was called. + while (!isMarked[0] && !this.recordsFutureReference.attemptMark(expectedFuture, true)) { + expectedFuture = this.recordsFutureReference.get(isMarked); + } + // Only cancel if the mark was observed transitioning from false to true. + // Add a callback if cancel was called on a completed future. + if (!isMarked[0] && !expectedFuture.cancel(true) && !expectedFuture.isCancelled()) { + // At this point, the listener will execute immediately and replace the completed future + // with a cancelled future. + // The completed future's state may never be observed. 
+ final SettableFuture>> currentFuture = expectedFuture; + currentFuture.addListener( + () -> { + final SettableFuture>> nextFuture = + SettableFuture.create(); + nextFuture.cancel(false); + TopicPartitionPollState.this.recordsFutureReference.compareAndSet( + currentFuture, nextFuture, true, true); + }, + MoreExecutors.directExecutor()); + } + } + } + + static final class ConsumerExecutionContext extends AbstractExecutionThreadService { + private final Consumer consumer; + private final LinkedBlockingQueue workQueue; + private final AtomicReference> endOffsets; + private final Map offsetsForStartReadTimesArgument; + private ListenableFutureTask> + offsetsForStartReadTimesTask; + private final Map offsetsForStopReadTimesArgument; + private ListenableFutureTask> + offsetsForStopReadTimesTask; + private final List queuedSplits; + + ConsumerExecutionContext(final Consumer consumer) { + this.consumer = consumer; + this.workQueue = new LinkedBlockingQueue<>(); + this.endOffsets = new AtomicReference<>(Collections.emptyMap()); + + this.offsetsForStartReadTimesArgument = new HashMap<>(); + this.offsetsForStartReadTimesTask = + ListenableFutureTask.create( + () -> + consumer.offsetsForTimes( + this.offsetsForStartReadTimesArgument, + Duration.ofSeconds(DEFAULT_KAFKA_POLL_TIMEOUT))); + + this.offsetsForStopReadTimesArgument = new HashMap<>(); + this.offsetsForStopReadTimesTask = + ListenableFutureTask.create( + () -> + consumer.offsetsForTimes( + this.offsetsForStopReadTimesArgument, + Duration.ofSeconds(DEFAULT_KAFKA_POLL_TIMEOUT))); + + this.queuedSplits = new ArrayList<>(); + } + + public ListenableFuture<@Nullable OffsetAndTimestamp> getOffsetForStartReadTime( + final TopicPartition topicPartition, long time) { + // Note: The ordering of statements is deliberate. Transformation callbacks are executed in + // the thread that first observes the completion after the callback's registration. + // Registration depends on inter-thread actions and enqueueing the task guarantees that these + // actions are observed to happen before dequeueing the element and the consequent completion + // of its future. + // The partial ordering of these actions ensures that the future can not be observed as + // fulfilled when the registration happens. + final ListenableFutureTask task = + ListenableFutureTask.create( + () -> this.offsetsForStartReadTimesArgument.put(topicPartition, time), null); + final FluentFuture<@Nullable OffsetAndTimestamp> future = + FluentFuture.from(task) + .transformAsync( + unused -> Futures.nonCancellationPropagating(this.offsetsForStartReadTimesTask), + MoreExecutors.directExecutor()) + .<@Nullable OffsetAndTimestamp>transform( + beginningOffsets -> beginningOffsets.get(topicPartition), + MoreExecutors.directExecutor()); + workQueue.add(task); + return future; + } + + public ListenableFuture<@Nullable OffsetAndTimestamp> getOffsetForStopReadTime( + final TopicPartition topicPartition, long time) { + // Note: The ordering of statements is deliberate. Transformation callbacks are executed in + // the thread that first observes the completion after the callback's registration. + // Registration depends on inter-thread actions and enqueueing the task guarantees that these + // actions are observed to happen before dequeueing the element and the consequent completion + // of its future. + // The partial ordering of these actions ensures that the future can not be observed as + // fulfilled when the registration happens. 
+ final ListenableFutureTask task = + ListenableFutureTask.create( + () -> this.offsetsForStopReadTimesArgument.put(topicPartition, time), null); + final FluentFuture<@Nullable OffsetAndTimestamp> future = + FluentFuture.from(task) + .transformAsync( + unused -> Futures.nonCancellationPropagating(this.offsetsForStopReadTimesTask), + MoreExecutors.directExecutor()) + .<@Nullable OffsetAndTimestamp>transform( + beginningOffsets -> beginningOffsets.get(topicPartition), + MoreExecutors.directExecutor()); + workQueue.add(task); + return future; + } + + public TopicPartitionPollState assign( + final TopicPartition topicPartition, final OffsetRange offsetRange) { + final TopicPartitionPollState pollState = + new TopicPartitionPollState(topicPartition, offsetRange); + workQueue.add(() -> this.queuedSplits.add(pollState)); + return pollState; + } + + @Override + public void run() { + Map splits = new HashMap<>(); + try (Consumer consumer = this.consumer) { + while (true) { + List tasks; + + // If there's no assigned split, sleep until a task becomes available. + // Otherwise drain the queue immediately. + if (splits.isEmpty()) { + tasks = Collections.singletonList(this.workQueue.take()); + } else { + tasks = new ArrayList<>(); + workQueue.drainTo(tasks); + } + tasks.forEach(Runnable::run); + + if (!offsetsForStartReadTimesArgument.isEmpty()) { + final ListenableFutureTask> + currentOffsetsForStartReadTimeTask = this.offsetsForStartReadTimesTask; + currentOffsetsForStartReadTimeTask.addListener( + () -> this.offsetsForStartReadTimesArgument.clear(), + MoreExecutors.directExecutor()); + this.offsetsForStartReadTimesTask = + ListenableFutureTask.create( + () -> + consumer.offsetsForTimes( + this.offsetsForStartReadTimesArgument, + Duration.ofSeconds(DEFAULT_KAFKA_POLL_TIMEOUT))); + + currentOffsetsForStartReadTimeTask.run(); + } + + if (!offsetsForStopReadTimesArgument.isEmpty()) { + final ListenableFutureTask> + currentOffsetsForStopReadTimeTask = this.offsetsForStopReadTimesTask; + currentOffsetsForStopReadTimeTask.addListener( + () -> this.offsetsForStopReadTimesArgument.clear(), MoreExecutors.directExecutor()); + this.offsetsForStopReadTimesTask = + ListenableFutureTask.create( + () -> + consumer.offsetsForTimes( + this.offsetsForStopReadTimesArgument, + Duration.ofSeconds(DEFAULT_KAFKA_POLL_TIMEOUT))); + + currentOffsetsForStopReadTimeTask.run(); + } + + if (!queuedSplits.isEmpty()) { + queuedSplits.forEach( + pollState -> + splits.compute( + pollState.getTopicPartition(), + (topicPartition, currentPollState) -> { + if (currentPollState != null) { + currentPollState.close(); + } + return pollState; + })); + this.consumer.assign(splits.keySet()); + queuedSplits.forEach( + pollState -> { + if (pollState.getOffsetRange().getFrom() >= 0L) { + this.consumer.seek( + pollState.getTopicPartition(), pollState.getOffsetRange().getFrom()); + } + }); + queuedSplits.clear(); + } + + splits.entrySet().removeIf(entry -> entry.getValue().isClosed()); + + if (splits.isEmpty()) { + continue; + } + + this.consumer.pause( + splits.values().stream() + .filter( + ((Predicate) TopicPartitionPollState::hasCapacity) + .negate()) + .map(TopicPartitionPollState::getTopicPartition) + .collect(Collectors.toSet())); + final ListenableFutureTask> currentPollTask = + ListenableFutureTask.create( + () -> this.consumer.poll(Duration.ofSeconds(DEFAULT_KAFKA_POLL_TIMEOUT))); + splits.forEach( + (topicPartition, pollState) -> + pollState.setRecords( + FluentFuture.from(currentPollTask) + .transform( + records -> 
records.records(topicPartition), + MoreExecutors.directExecutor()))); + currentPollTask.run(); + this.endOffsets.set( + consumer.endOffsets( + consumer.assignment(), Duration.ofSeconds(DEFAULT_KAFKA_POLL_TIMEOUT))); + this.consumer.resume( + splits.values().stream() + .filter(TopicPartitionPollState::hasCapacity) + .map(TopicPartitionPollState::getTopicPartition) + .collect(Collectors.toSet())); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } finally { + splits.values().forEach(TopicPartitionPollState::close); + } + } + + // public boolean isRunning() { + // return running.get(); + // } + + // public boolean isDone() { + // return done.get(); + // } + + public GrowableOffsetRangeTracker.RangeEndEstimator getRangeEndEstimator( + TopicPartition topicPartition) { + return new KafkaLatestOffsetEstimator(topicPartition, this.endOffsets); + } + + public void wakeup() { + LOG.info("Waking up consumer"); + this.consumer.wakeup(); + } } private static final AtomicLong FN_ID = new AtomicLong(); + private static final Joiner COMMA_JOINER = Joiner.on(','); + // A unique identifier for the instance. Generally unique unless the ID generator overflows. private final long fnId = FN_ID.getAndIncrement(); - private final @Nullable Map offsetConsumerConfig; - private final @Nullable CheckStopReadingFn checkStopReadingFn; private final SerializableFunction, Consumer> @@ -245,11 +588,23 @@ private static final class SharedStateHolder { // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; - private transient @Nullable LoadingCache - offsetEstimatorCache; + // Only used to retain a strong reference to the consumer execution context until this function + // instance is torn down. + // This ties the lifetime of the consumer execution context to that of the bundle processor (or + // equivalent for non-portable runners). + // The consumer execution context cache stores weak references to consumer execution contexts, + // thus allowing the garbage collector to finalize the consumer execution context when no strong + // references to it are held. 
+ @SuppressWarnings("unused") + private transient @Nullable ConsumerExecutionContext consumerExecutionContextInstance = null; - private transient @Nullable LoadingCache + private transient @MonotonicNonNull LoadingCache avgRecordSizeCache; + + private transient @MonotonicNonNull LoadingCache< + Optional>, ConsumerExecutionContext> + consumerExecutionContextCache; + private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L; @VisibleForTesting final long consumerPollingTimeout; @VisibleForTesting final DeserializerProvider keyDeserializerProvider; @@ -267,80 +622,104 @@ private static final class SharedStateHolder { private static class KafkaLatestOffsetEstimator implements GrowableOffsetRangeTracker.RangeEndEstimator { - private final Consumer offsetConsumer; private final TopicPartition topicPartition; - private final Supplier memoizedBacklog; + private final AtomicReference> endOffsets; KafkaLatestOffsetEstimator( - Consumer offsetConsumer, TopicPartition topicPartition) { - this.offsetConsumer = offsetConsumer; + TopicPartition topicPartition, AtomicReference> endOffsets) { this.topicPartition = topicPartition; - memoizedBacklog = - Suppliers.memoizeWithExpiration( - () -> { - synchronized (offsetConsumer) { - return Preconditions.checkStateNotNull( - offsetConsumer - .endOffsets(Collections.singleton(topicPartition)) - .get(topicPartition), - "No end offset found for partition %s.", - topicPartition); - } - }, - 1, - TimeUnit.SECONDS); - } - - @Override - protected void finalize() { - try { - Closeables.close(offsetConsumer, true); - LOG.info("Offset Estimator consumer was closed for {}", topicPartition); - } catch (Exception anyException) { - LOG.warn("Failed to close offset consumer for {}", topicPartition); - } + this.endOffsets = endOffsets; } @Override public long estimate() { - return memoizedBacklog.get(); + return this.endOffsets.get().getOrDefault(this.topicPartition, -1L); } } @GetInitialRestriction - public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSourceDescriptor) { - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - TopicPartition partition = kafkaSourceDescriptor.getTopicPartition(); - LOG.info("Creating Kafka consumer for initial restriction for {}", kafkaSourceDescriptor); - try (Consumer offsetConsumer = consumerFactoryFn.apply(updatedConsumerConfig)) { - ConsumerSpEL.evaluateAssign(offsetConsumer, ImmutableList.of(partition)); - long startOffset; - @Nullable Instant startReadTime = kafkaSourceDescriptor.getStartReadTime(); - if (kafkaSourceDescriptor.getStartReadOffset() != null) { - startOffset = kafkaSourceDescriptor.getStartReadOffset(); - } else if (startReadTime != null) { - startOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, startReadTime); - } else { - startOffset = offsetConsumer.position(partition); - } + public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSourceDescriptor) + throws Throwable { + LOG.info("Creating initial restriction for {}", kafkaSourceDescriptor); - long endOffset = Long.MAX_VALUE; - @Nullable Instant stopReadTime = kafkaSourceDescriptor.getStopReadTime(); - if (kafkaSourceDescriptor.getStopReadOffset() != null) { - endOffset = kafkaSourceDescriptor.getStopReadOffset(); - } else if (stopReadTime != null) { - endOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, stopReadTime); + // The context may not be used at all, but unconditionally fetching it here may avoid a load + // during processing. 
+ Optional> consumerExecutionContextKey = + Optional.ofNullable(kafkaSourceDescriptor.getBootStrapServers()).map(ImmutableSet::copyOf); + LoadingCache>, ConsumerExecutionContext> + consumerExecutionContextCache = checkNotNull(this.consumerExecutionContextCache); + ConsumerExecutionContext consumerExecutionContext = + consumerExecutionContextCache.get(consumerExecutionContextKey); + try { + consumerExecutionContext.awaitRunning(); + } catch (IllegalStateException ex) { + if (ex.getCause() == null) { + consumerExecutionContextCache.refresh(consumerExecutionContextKey); } - new OffsetRange(startOffset, endOffset); - Lineage.getSources() - .add( - "kafka", - ImmutableList.of( - (String) updatedConsumerConfig.get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG), - MoreObjects.firstNonNull(kafkaSourceDescriptor.getTopic(), partition.topic()))); - return new OffsetRange(startOffset, endOffset); + throw ex; + } + this.consumerExecutionContextInstance = consumerExecutionContext; + + final ListenableFuture startOffset; + @Nullable Long startReadOffset = kafkaSourceDescriptor.getStartReadOffset(); + @Nullable Instant startReadTime = kafkaSourceDescriptor.getStartReadTime(); + if (startReadOffset != null) { + startOffset = Futures.immediateFuture(startReadOffset); + } else if (startReadTime != null) { + startOffset = + Futures.transform( + consumerExecutionContext.getOffsetForStartReadTime( + kafkaSourceDescriptor.getTopicPartition(), startReadTime.getMillis()), + offsetAndTimestamp -> offsetAndTimestamp == null ? -1L : offsetAndTimestamp.offset(), + MoreExecutors.directExecutor()); + } else { + startOffset = Futures.immediateFuture(-1L); + } + + final ListenableFuture endOffset; + @Nullable Long stopReadOffset = kafkaSourceDescriptor.getStopReadOffset(); + @Nullable Instant stopReadTime = kafkaSourceDescriptor.getStopReadTime(); + if (stopReadOffset != null) { + endOffset = Futures.immediateFuture(stopReadOffset); + } else if (stopReadTime != null) { + endOffset = + Futures.transform( + consumerExecutionContext.getOffsetForStopReadTime( + kafkaSourceDescriptor.getTopicPartition(), stopReadTime.getMillis()), + offsetAndTimestamp -> + offsetAndTimestamp == null ? 
Long.MAX_VALUE : offsetAndTimestamp.offset(), + MoreExecutors.directExecutor()); + } else { + endOffset = Futures.immediateFuture(Long.MAX_VALUE); + } + + OffsetRange initialRestriction; + try { + initialRestriction = new OffsetRange(startOffset.get(), endOffset.get()); + } catch (ExecutionException ex) { + throw MoreObjects.firstNonNull(ex.getCause(), ex); } + + Lineage.getSources() + .add( + "kafka", + ImmutableList.of( + Optional.ofNullable( + (@Nullable List) + ConfigDef.parseType( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + MoreObjects.firstNonNull( + (@Nullable Object) + consumerConfig.get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG), + (@Nullable Object) kafkaSourceDescriptor.getBootStrapServers()), + ConfigDef.Type.LIST)) + .map(ImmutableSet::copyOf) + .map(COMMA_JOINER::join) + .orElse(""), + MoreObjects.firstNonNull( + kafkaSourceDescriptor.getTopic(), + kafkaSourceDescriptor.getTopicPartition().topic()))); + return initialRestriction; } @GetInitialWatermarkEstimatorState @@ -357,13 +736,14 @@ public WatermarkEstimator newWatermarkEstimator( } @GetSize + @RequiresNonNull({"avgRecordSizeCache", "consumerExecutionContextCache"}) public double getSize( @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) throws ExecutionException { + final LoadingCache avgRecordSizeCache = + this.avgRecordSizeCache; // If present, estimates the record size to offset gap ratio. Compacted topics may hold less // records than the estimated offset range due to record deletion within a partition. - final LoadingCache avgRecordSizeCache = - Preconditions.checkStateNotNull(this.avgRecordSizeCache); final @Nullable AverageRecordSize avgRecordSize = avgRecordSizeCache.getIfPresent(kafkaSourceDescriptor); // The tracker estimates the offset range by subtracting the last claimed position from the @@ -381,35 +761,46 @@ public double getSize( } @NewTracker + @RequiresNonNull("consumerExecutionContextCache") public OffsetRangeTracker restrictionTracker( @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) throws ExecutionException { + final LoadingCache>, ConsumerExecutionContext> + consumerExecutionContextCache = this.consumerExecutionContextCache; if (restriction.getTo() < Long.MAX_VALUE) { return new OffsetRangeTracker(restriction); } + Optional> consumerExecutionContextKey = + Optional.ofNullable(kafkaSourceDescriptor.getBootStrapServers()).map(ImmutableSet::copyOf); + ConsumerExecutionContext consumerExecutionContext = + consumerExecutionContextCache.get(consumerExecutionContextKey); + try { + consumerExecutionContext.awaitRunning(); + } catch (IllegalStateException ex) { + if (ex.getCause() == null) { + consumerExecutionContextCache.refresh(consumerExecutionContextKey); + } + throw ex; + } + this.consumerExecutionContextInstance = consumerExecutionContext; - // OffsetEstimators are cached for each topic-partition because they hold a stateful connection, - // so we want to minimize the amount of connections that we start and track with Kafka. Another - // point is that it has a memoized backlog, and this should make that more reusable estimations. 
- final LoadingCache offsetEstimatorCache = - Preconditions.checkStateNotNull(this.offsetEstimatorCache); - final KafkaLatestOffsetEstimator offsetEstimator = - offsetEstimatorCache.get(kafkaSourceDescriptor); - - return new GrowableOffsetRangeTracker(restriction.getFrom(), offsetEstimator); + return new GrowableOffsetRangeTracker( + restriction.getFrom(), + consumerExecutionContext.getRangeEndEstimator(kafkaSourceDescriptor.getTopicPartition())); } @ProcessElement + @RequiresNonNull({"avgRecordSizeCache", "consumerExecutionContextCache"}) public ProcessContinuation processElement( @Element KafkaSourceDescriptor kafkaSourceDescriptor, RestrictionTracker tracker, WatermarkEstimator watermarkEstimator, MultiOutputReceiver receiver) - throws Exception { + throws Throwable { final LoadingCache avgRecordSizeCache = - Preconditions.checkStateNotNull(this.avgRecordSizeCache); - final LoadingCache offsetEstimatorCache = - Preconditions.checkStateNotNull(this.offsetEstimatorCache); + this.avgRecordSizeCache; + final LoadingCache>, ConsumerExecutionContext> + consumerExecutionContextCache = this.consumerExecutionContextCache; final Deserializer keyDeserializerInstance = Preconditions.checkStateNotNull(this.keyDeserializerInstance); final Deserializer valueDeserializerInstance = @@ -431,8 +822,21 @@ public ProcessContinuation processElement( tracker.tryClaim(tracker.currentRestriction().getTo() - 1); return ProcessContinuation.stop(); } - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); + + Optional> consumerExecutionContextKey = + Optional.ofNullable(kafkaSourceDescriptor.getBootStrapServers()).map(ImmutableSet::copyOf); + ConsumerExecutionContext consumerExecutionContext = + consumerExecutionContextCache.get(consumerExecutionContextKey); + try { + consumerExecutionContext.awaitRunning(); + } catch (IllegalStateException ex) { + if (ex.getCause() == null) { + consumerExecutionContextCache.refresh(consumerExecutionContextKey); + } + throw ex; + } + this.consumerExecutionContextInstance = consumerExecutionContext; + // If there is a timestampPolicyFactory, create the TimestampPolicy for current // TopicPartition. TimestampPolicy timestampPolicy = null; @@ -443,32 +847,37 @@ public ProcessContinuation processElement( } LOG.info("Creating Kafka consumer for process continuation for {}", kafkaSourceDescriptor); - try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { - ConsumerSpEL.evaluateAssign( - consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); + + try (TopicPartitionPollState pollState = + consumerExecutionContext.assign( + kafkaSourceDescriptor.getTopicPartition(), tracker.currentRestriction())) { long startOffset = tracker.currentRestriction().getFrom(); long expectedOffset = startOffset; - consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset); - ConsumerRecords rawRecords = ConsumerRecords.empty(); long skippedRecords = 0L; final Stopwatch sw = Stopwatch.createStarted(); - while (true) { - rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition()); + Future>> recordsFuture; + while (!(recordsFuture = pollState.getRecords()).isCancelled()) { + List> rawRecords; + try { + rawRecords = recordsFuture.get(); + } catch (CancellationException ex) { + break; + } catch (ExecutionException ex) { + throw MoreObjects.firstNonNull(ex.getCause(), ex); + } // When there are no records available for the current TopicPartition, self-checkpoint // and move to process the next element. 
if (rawRecords.isEmpty()) { - if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), - consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { - return ProcessContinuation.stop(); - } if (timestampPolicy != null) { updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); } return ProcessContinuation.resume(); } for (ConsumerRecord rawRecord : rawRecords) { + if (startOffset < 0L) { + expectedOffset = rawRecord.offset(); + } // If the Kafka consumer returns a record with an offset that is already processed // the record can be safely skipped. This is needed because there is a possibility // that the seek() above fails to move the offset to the desired position. In which @@ -552,49 +961,12 @@ public ProcessContinuation processElement( backlogBytes.set( (long) - (BigDecimal.valueOf( - Preconditions.checkStateNotNull( - offsetEstimatorCache.get(kafkaSourceDescriptor).estimate())) - .subtract(BigDecimal.valueOf(expectedOffset), MathContext.DECIMAL128) - .doubleValue() + (((HasProgress) tracker).getProgress().getWorkRemaining() * avgRecordSize.estimateRecordByteSizeToOffsetCountRatio())); } } - } - private boolean topicPartitionExists( - TopicPartition topicPartition, List partitionInfos) { - // Check if the current TopicPartition still exists. - return partitionInfos.stream() - .anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition())); - } - - // see https://github.com/apache/beam/issues/25962 - private ConsumerRecords poll( - Consumer consumer, TopicPartition topicPartition) { - final Stopwatch sw = Stopwatch.createStarted(); - long previousPosition = -1; - java.time.Duration elapsed = java.time.Duration.ZERO; - java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout); - while (true) { - final ConsumerRecords rawRecords = consumer.poll(timeout.minus(elapsed)); - if (!rawRecords.isEmpty()) { - // return as we have found some entries - return rawRecords; - } - if (previousPosition == (previousPosition = consumer.position(topicPartition))) { - // there was no progress on the offset/position, which indicates end of stream - return rawRecords; - } - elapsed = sw.elapsed(); - if (elapsed.toMillis() >= timeout.toMillis()) { - // timeout is over - LOG.warn( - "No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.", - consumerPollingTimeout); - return rawRecords; - } - } + return ProcessContinuation.stop(); } private TimestampPolicyContext updateWatermarkManually( @@ -616,9 +988,10 @@ public Coder restrictionCoder() { } @Setup + @EnsuresNonNull({"avgRecordSizeCache", "consumerExecutionContextCache"}) public void setup() throws Exception { // Start to track record size and offset gap per bundle. 
- avgRecordSizeCache = + this.avgRecordSizeCache = SharedStateHolder.AVG_RECORD_SIZE_CACHE.computeIfAbsent( fnId, k -> { @@ -633,53 +1006,60 @@ public AverageRecordSize load(KafkaSourceDescriptor kafkaSourceDescriptor) } }); }); - keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true); - valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false); - offsetEstimatorCache = - SharedStateHolder.OFFSET_ESTIMATOR_CACHE.computeIfAbsent( + this.consumerExecutionContextCache = + SharedStateHolder.CONSUMER_EXECUTION_CONTEXT_CACHE.computeIfAbsent( fnId, k -> { - final Map consumerConfig = ImmutableMap.copyOf(this.consumerConfig); - final @Nullable Map offsetConsumerConfig = - this.offsetConsumerConfig == null - ? null - : ImmutableMap.copyOf(this.offsetConsumerConfig); return CacheBuilder.newBuilder() .weakValues() - .expireAfterAccess(1, TimeUnit.MINUTES) .build( - new CacheLoader() { + new CacheLoader>, ConsumerExecutionContext>() { @Override - public KafkaLatestOffsetEstimator load( - KafkaSourceDescriptor kafkaSourceDescriptor) throws Exception { - LOG.info( - "Creating Kafka consumer for offset estimation for {}", - kafkaSourceDescriptor); - - TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - Consumer offsetConsumer = - consumerFactoryFn.apply( - KafkaIOUtils.getOffsetConsumerConfig( - "tracker-" + topicPartition, - offsetConsumerConfig, - updatedConsumerConfig)); - return new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); + public ConsumerExecutionContext load( + Optional> optionalBootstrapServers) + throws Exception { + final Map consumerConfig = + new HashMap<>(ReadFromKafkaDoFn.this.consumerConfig); + ImmutableSet bootstrapServers; + if (optionalBootstrapServers.isPresent() + && (bootstrapServers = optionalBootstrapServers.get()).size() > 0) { + consumerConfig.put( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + COMMA_JOINER.join(bootstrapServers)); + } + checkState( + consumerConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)); + final ConsumerExecutionContext context = + new ConsumerExecutionContext( + ReadFromKafkaDoFn.this.consumerFactoryFn.apply(consumerConfig)); + context.addListener( + new Service.Listener() { + @Override + public void terminated(Service.State from) {} + + @Override + public void failed(Service.State from, Throwable cause) {} + }, + MoreExecutors.directExecutor()); + context.startAsync(); + return context; } }); }); + keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true); + valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false); if (checkStopReadingFn != null) { checkStopReadingFn.setup(); } } @Teardown + @RequiresNonNull({"avgRecordSizeCache", "consumerExecutionContextCache"}) public void teardown() throws Exception { final LoadingCache avgRecordSizeCache = - Preconditions.checkStateNotNull(this.avgRecordSizeCache); - final LoadingCache offsetEstimatorCache = - Preconditions.checkStateNotNull(this.offsetEstimatorCache); + this.avgRecordSizeCache; + final LoadingCache>, ConsumerExecutionContext> + consumerExecutionContextCache = this.consumerExecutionContextCache; try { if (valueDeserializerInstance != null) { Closeables.close(valueDeserializerInstance, true); @@ -698,21 +1078,7 @@ public void teardown() throws Exception { // Allow the cache to perform clean up tasks when this instance 
is about to be deleted. avgRecordSizeCache.cleanUp(); - offsetEstimatorCache.cleanUp(); - } - - private Map overrideBootstrapServersConfig( - Map currentConfig, KafkaSourceDescriptor description) { - checkState( - currentConfig.containsKey(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) - || description.getBootStrapServers() != null); - Map config = new HashMap<>(currentConfig); - if (description.getBootStrapServers() != null && description.getBootStrapServers().size() > 0) { - config.put( - ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, - String.join(",", description.getBootStrapServers())); - } - return config; + consumerExecutionContextCache.cleanUp(); } // TODO: Collapse the two moving average trackers into a single accumulator using a single Guava diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index cbff0f896619..f3850bcfa280 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -419,7 +419,7 @@ public void setUp() throws Exception { } @Test - public void testInitialRestrictionWhenHasStartOffset() throws Exception { + public void testInitialRestrictionWhenHasStartOffset() throws Throwable { long expectedStartOffset = 10L; consumer.setStartOffsetForTime(15L, Instant.now()); consumer.setCurrentPos(5L); @@ -431,7 +431,7 @@ public void testInitialRestrictionWhenHasStartOffset() throws Exception { } @Test - public void testInitialRestrictionWhenHasStopOffset() throws Exception { + public void testInitialRestrictionWhenHasStopOffset() throws Throwable { long expectedStartOffset = 10L; long expectedStopOffset = 20L; consumer.setStartOffsetForTime(15L, Instant.now()); @@ -450,7 +450,7 @@ public void testInitialRestrictionWhenHasStopOffset() throws Exception { } @Test - public void testInitialRestrictionWhenHasStartTime() throws Exception { + public void testInitialRestrictionWhenHasStartTime() throws Throwable { long expectedStartOffset = 10L; Instant startReadTime = Instant.now(); consumer.setStartOffsetForTime(expectedStartOffset, startReadTime); @@ -463,7 +463,7 @@ public void testInitialRestrictionWhenHasStartTime() throws Exception { } @Test - public void testInitialRestrictionWhenHasStopTime() throws Exception { + public void testInitialRestrictionWhenHasStopTime() throws Throwable { long expectedStartOffset = 10L; Instant startReadTime = Instant.now(); long expectedStopOffset = 100L; @@ -479,7 +479,7 @@ public void testInitialRestrictionWhenHasStopTime() throws Exception { } @Test - public void testInitialRestrictionWithConsumerPosition() throws Exception { + public void testInitialRestrictionWithConsumerPosition() throws Throwable { long expectedStartOffset = 5L; consumer.setCurrentPos(5L); OffsetRange result = @@ -489,7 +489,7 @@ public void testInitialRestrictionWithConsumerPosition() throws Exception { } @Test - public void testInitialRestrictionWithException() throws Exception { + public void testInitialRestrictionWithException() throws Throwable { thrown.expect(KafkaException.class); thrown.expectMessage("PositionException"); @@ -498,7 +498,7 @@ public void testInitialRestrictionWithException() throws Exception { } @Test - public void testProcessElement() throws Exception { + public void testProcessElement() throws Throwable { MockMultiOutputReceiver receiver = new MockMultiOutputReceiver(); 
consumer.setNumOfRecordsPerPoll(3L); long startOffset = 5L; @@ -514,7 +514,7 @@ public void testProcessElement() throws Exception { } @Test - public void testProcessElementWithEarlierOffset() throws Exception { + public void testProcessElementWithEarlierOffset() throws Throwable { MockMultiOutputReceiver receiver = new MockMultiOutputReceiver(); consumerWithBrokenSeek.setNumOfRecordsPerPoll(6L); consumerWithBrokenSeek.setCurrentPos(0L); @@ -532,7 +532,7 @@ public void testProcessElementWithEarlierOffset() throws Exception { } @Test - public void testRawSizeMetric() throws Exception { + public void testRawSizeMetric() throws Throwable { final int numElements = 1000; final int recordSize = 8; // The size of key and value is defined in SimpleMockKafkaConsumer. MetricsContainerImpl container = new MetricsContainerImpl("any"); @@ -558,7 +558,7 @@ public void testRawSizeMetric() throws Exception { } @Test - public void testProcessElementWithEmptyPoll() throws Exception { + public void testProcessElementWithEmptyPoll() throws Throwable { MockMultiOutputReceiver receiver = new MockMultiOutputReceiver(); consumer.setNumOfRecordsPerPoll(-1); OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE)); @@ -573,7 +573,7 @@ public void testProcessElementWithEmptyPoll() throws Exception { } @Test - public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception { + public void testProcessElementWhenTopicPartitionIsRemoved() throws Throwable { MockMultiOutputReceiver receiver = new MockMultiOutputReceiver(); consumer.setRemoved(); consumer.setNumOfRecordsPerPoll(-1); @@ -600,7 +600,7 @@ public void testSDFCommitOffsetNotEnabled() { } @Test - public void testProcessElementWhenTopicPartitionIsStopped() throws Exception { + public void testProcessElementWhenTopicPartitionIsStopped() throws Throwable { MockMultiOutputReceiver receiver = new MockMultiOutputReceiver(); ReadFromKafkaDoFn instance = ReadFromKafkaDoFn.create( @@ -630,7 +630,7 @@ public Boolean apply(TopicPartition input) { } @Test - public void testProcessElementWithException() throws Exception { + public void testProcessElementWithException() throws Throwable { thrown.expect(KafkaException.class); thrown.expectMessage("SeekException"); @@ -646,7 +646,7 @@ public void testProcessElementWithException() throws Exception { @Test public void testProcessElementWithDeserializationExceptionDefaultRecordHandler() - throws Exception { + throws Throwable { thrown.expect(SerializationException.class); thrown.expectMessage("Intentional serialization exception"); @@ -672,7 +672,7 @@ public void testProcessElementWithDeserializationExceptionDefaultRecordHandler() @Test public void testProcessElementWithDeserializationExceptionRecordingRecordHandler() - throws Exception { + throws Throwable { MockMultiOutputReceiver receiver = new MockMultiOutputReceiver(); OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, 1L));
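
The handoff introduced by this patch can be summarized as: a single ConsumerExecutionContext thread owns the Kafka Consumer, polls in a loop, and publishes each partition's batch into a pending future held by TopicPartitionPollState; a harness thread blocks on that future and hands capacity back once it has drained the batch, which is what drives the pause()/resume() backpressure in run(). Below is a minimal, self-contained sketch of that handoff in plain JDK terms (CompletableFuture and AtomicReference instead of the vendored Guava SettableFuture and AtomicMarkableReference). The class name PollHandoffSketch and the takeRecords helper are illustrative inventions, not part of the patch: takeRecords collapses getRecords plus the completion callback into one blocking call for brevity, and the sketch omits the close/cancellation mark and per-partition bookkeeping.

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;

public final class PollHandoffSketch {
  // Pending batch for one partition; replaced with a fresh future once the harness
  // side has drained the completed one (same role as the AtomicMarkableReference
  // in TopicPartitionPollState, minus the close mark).
  private final AtomicReference<CompletableFuture<List<String>>> pending =
      new AtomicReference<>(new CompletableFuture<>());

  // Consumer side: publish one polled batch; returns false if the previous batch
  // has not been picked up yet.
  boolean setRecords(List<String> batch) {
    return pending.get().complete(batch);
  }

  // Consumer side: analogue of the pause()/resume() capacity check in the patch.
  // Only the consumer thread completes futures and only the harness thread swaps
  // in fresh ones, so a true result stays valid until this thread publishes.
  boolean hasCapacity() {
    return !pending.get().isDone();
  }

  // Harness side: block for the next batch, then install a fresh future so the
  // consumer thread regains capacity and each batch is consumed exactly once.
  List<String> takeRecords() throws Exception {
    CompletableFuture<List<String>> current = pending.get();
    List<String> batch = current.get(); // wait for the consumer thread
    pending.compareAndSet(current, new CompletableFuture<>()); // hand capacity back
    return batch;
  }

  public static void main(String[] args) throws Exception {
    PollHandoffSketch state = new PollHandoffSketch();

    // Stand-in for the ConsumerExecutionContext poll loop.
    Thread consumerThread = new Thread(() -> {
      for (int i = 0; i < 3; i++) {
        while (!state.hasCapacity()) {
          Thread.yield(); // backpressure: the harness has not drained the last batch yet
        }
        state.setRecords(Arrays.asList("record-" + i));
      }
    }, "consumer-thread");
    consumerThread.start();

    // Harness thread: blocks on futures and never touches the consumer itself.
    for (int i = 0; i < 3; i++) {
      System.out.println(state.takeRecords());
    }
    consumerThread.join();
  }
}

In the patch itself the swap of a completed future happens in a callback registered by getRecords rather than in a blocking helper, which is why the comments in TopicPartitionPollState stress that multiple callbacks may observe the completed future but only one replaces it, and that only a single harness caller may safely consume a given poll state.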