Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug30870]: make consumer polling timeout configurable for KafkaIO.Read #30877

Merged
merged 9 commits into from
Apr 9, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ public static <K, V> Read<K, V> read() {
.setCommitOffsetsInFinalizeEnabled(false)
.setDynamicRead(false)
.setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
.setConsumerPollingTimeout(Duration.standardSeconds(1L))
xianhualiu marked this conversation as resolved.
Show resolved Hide resolved
.build();
}

Expand Down Expand Up @@ -706,6 +707,9 @@ public abstract static class Read<K, V>
@Pure
public abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();

/**
 * Returns the timeout applied to Kafka consumer polling requests, or {@code null} if no timeout
 * has been set on this transform.
 */
@Pure
public abstract @Nullable Duration getConsumerPollingTimeout();

abstract Builder<K, V> toBuilder();

@AutoValue.Builder
Expand Down Expand Up @@ -762,6 +766,8 @@ Builder<K, V> setCheckStopReadingFn(
return setCheckStopReadingFn(CheckStopReadingFnWrapper.of(checkStopReadingFn));
}

/**
 * Sets the timeout applied to Kafka consumer polling requests.
 *
 * <p>The parameter is {@code @Nullable} to match the {@code @Nullable} getter and because {@code
 * withConsumerPollingTimeout} explicitly permits a {@code null} duration; without the annotation
 * the generated builder is expected to reject {@code null} with a {@link NullPointerException}.
 *
 * @param consumerPollingTimeout the polling timeout, or {@code null} to leave it unset
 */
abstract Builder<K, V> setConsumerPollingTimeout(@Nullable Duration consumerPollingTimeout);

abstract Read<K, V> build();

static <K, V> void setupExternalBuilder(
Expand Down Expand Up @@ -1334,6 +1340,17 @@ public Read<K, V> withBadRecordErrorHandler(ErrorHandler<BadRecord, ?> badRecord
return toBuilder().setBadRecordErrorHandler(badRecordErrorHandler).build();
}

/**
 * Configures how long a single Kafka consumer polling request issued by {@link
 * ReadFromKafkaDoFn} may take. When not overridden, the timeout is 1 second.
 *
 * <p>A non-null {@code duration} must be strictly positive; otherwise an {@link
 * IllegalStateException} is thrown.
 */
public Read<K, V> withConsumerPollingTimeout(Duration duration) {
  // A null duration is accepted as-is; any concrete value has to be longer than zero.
  boolean unsetOrPositive = duration == null || duration.isLongerThan(Duration.ZERO);
  checkState(unsetOrPositive, "Consumer polling timeout must be greater than 0.");
  Builder<K, V> builder = toBuilder();
  return builder.setConsumerPollingTimeout(duration).build();
}

/** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metadata. */
public PTransform<PBegin, PCollection<KV<K, V>>> withoutMetadata() {
return new TypedWithoutMetadata<>(this);
Expand Down Expand Up @@ -1596,7 +1613,8 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
.withValueDeserializerProvider(kafkaRead.getValueDeserializerProvider())
.withManualWatermarkEstimator()
.withTimestampPolicyFactory(kafkaRead.getTimestampPolicyFactory())
.withCheckStopReadingFn(kafkaRead.getCheckStopReadingFn());
.withCheckStopReadingFn(kafkaRead.getCheckStopReadingFn())
.withConsumerPollingTimeout(kafkaRead.getConsumerPollingTimeout());
if (kafkaRead.isCommitOffsetsInFinalizeEnabled()) {
readTransform = readTransform.commitOffsets();
}
Expand Down Expand Up @@ -2036,6 +2054,9 @@ public abstract static class ReadSourceDescriptors<K, V>
@Pure
abstract ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();

/**
 * Returns the timeout for Kafka consumer polling requests, or {@code null} if none is set. When
 * this is {@code null}, {@code ReadFromKafkaDoFn} falls back to its built-in 1-second poll
 * timeout.
 */
@Pure
abstract @Nullable Duration getConsumerPollingTimeout();

abstract boolean isBounded();

abstract ReadSourceDescriptors.Builder<K, V> toBuilder();
Expand Down Expand Up @@ -2086,6 +2107,9 @@ abstract ReadSourceDescriptors.Builder<K, V> setBadRecordRouter(
abstract ReadSourceDescriptors.Builder<K, V> setBadRecordErrorHandler(
ErrorHandler<BadRecord, ?> badRecordErrorHandler);

abstract ReadSourceDescriptors.Builder<K, V> setConsumerPollingTimeout(
@Nullable Duration duration);

abstract ReadSourceDescriptors.Builder<K, V> setBounded(boolean bounded);

abstract ReadSourceDescriptors<K, V> build();
Expand All @@ -2099,6 +2123,7 @@ public static <K, V> ReadSourceDescriptors<K, V> read() {
.setBounded(false)
.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
.setBadRecordErrorHandler(new ErrorHandler.DefaultErrorHandler<>())
.setConsumerPollingTimeout(Duration.standardSeconds(1L))
.build()
.withProcessingTime()
.withMonotonicallyIncreasingWatermarkEstimator();
Expand Down Expand Up @@ -2360,6 +2385,14 @@ public ReadSourceDescriptors<K, V> withBadRecordErrorHandler(
.build();
}

/**
 * Sets the timeout for Kafka consumer polling requests issued by {@link ReadFromKafkaDoFn}. The
 * default is 1 second.
 *
 * <p>The value is validated here, in the same way as {@code Read#withConsumerPollingTimeout}, so
 * that a zero or negative timeout is rejected at pipeline construction time instead of being
 * handed to the Kafka consumer's {@code poll} call at runtime.
 *
 * @param duration the polling timeout; {@code null} makes {@link ReadFromKafkaDoFn} fall back to
 *     its built-in 1-second timeout
 * @return a copy of this transform with the polling timeout applied
 * @throws IllegalStateException if {@code duration} is non-null and not greater than zero
 */
public ReadSourceDescriptors<K, V> withConsumerPollingTimeout(@Nullable Duration duration) {
  checkState(
      duration == null || duration.compareTo(Duration.ZERO) > 0,
      "Consumer polling timeout must be greater than 0.");
  return toBuilder().setConsumerPollingTimeout(duration).build();
}

/** Wraps this transform in a {@code ReadAllFromRow} and returns the wrapper. */
ReadAllFromRow<K, V> forExternalBuild() {
  final ReadAllFromRow<K, V> rowBasedRead = new ReadAllFromRow<>(this);
  return rowBasedRead;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Object getDefaultValue() {
VALUE_DESERIALIZER_PROVIDER,
CHECK_STOP_READING_FN(SDF),
BAD_RECORD_ERROR_HANDLER(SDF),
CONSUMER_POLLING_TIMEOUT,
;

@Nonnull private final ImmutableSet<KafkaIOReadImplementation> supportedImplementations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ private ReadFromKafkaDoFn(
this.checkStopReadingFn = transform.getCheckStopReadingFn();
this.badRecordRouter = transform.getBadRecordRouter();
this.recordTag = recordTag;
if (transform.getConsumerPollingTimeout() != null) {
this.consumerPollingTimeout =
java.time.Duration.ofMillis(transform.getConsumerPollingTimeout().getMillis());
} else {
this.consumerPollingTimeout = KAFKA_POLL_TIMEOUT;
}
xianhualiu marked this conversation as resolved.
Show resolved Hide resolved
}

private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
Expand Down Expand Up @@ -219,6 +225,7 @@ private ReadFromKafkaDoFn(

private static final java.time.Duration KAFKA_POLL_TIMEOUT = java.time.Duration.ofSeconds(1);

@VisibleForTesting final java.time.Duration consumerPollingTimeout;
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
@VisibleForTesting final DeserializerProvider<V> valueDeserializerProvider;
@VisibleForTesting final Map<String, Object> consumerConfig;
Expand Down Expand Up @@ -508,7 +515,7 @@ private ConsumerRecords<byte[], byte[]> poll(
java.time.Duration elapsed = java.time.Duration.ZERO;
while (true) {
final ConsumerRecords<byte[], byte[]> rawRecords =
consumer.poll(KAFKA_POLL_TIMEOUT.minus(elapsed));
consumer.poll(consumerPollingTimeout.minus(elapsed));
if (!rawRecords.isEmpty()) {
// return as we have found some entries
return rawRecords;
Expand All @@ -518,7 +525,7 @@ private ConsumerRecords<byte[], byte[]> poll(
return rawRecords;
}
elapsed = sw.elapsed();
if (elapsed.toMillis() >= KAFKA_POLL_TIMEOUT.toMillis()) {
if (elapsed.toMillis() >= consumerPollingTimeout.toMillis()) {
// timeout is over
return rawRecords;
xianhualiu marked this conversation as resolved.
Show resolved Hide resolved
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,18 @@ public void testSinkMetrics() throws Exception {
}
}

/** A negative polling timeout must be rejected with an {@link IllegalStateException}. */
@Test(expected = IllegalStateException.class)
public void testWithInvalidConsumerPollingTimeout() {
  Duration negativeTimeout = Duration.standardSeconds(-5);
  KafkaIO.Read<Integer, Long> baseRead = KafkaIO.<Integer, Long>read();
  baseRead.withConsumerPollingTimeout(negativeTimeout);
}

/** A positive polling timeout is stored on the transform and returned by its getter. */
@Test
public void testWithValidConsumerPollingTimeout() {
  Duration requestedTimeout = Duration.standardSeconds(15);
  KafkaIO.Read<Integer, Long> reader =
      KafkaIO.<Integer, Long>read().withConsumerPollingTimeout(requestedTimeout);
  Duration storedTimeout = reader.getConsumerPollingTimeout();
  assertEquals(15, storedTimeout.getStandardSeconds());
}

private static void verifyProducerRecords(
MockProducer<Integer, Long> mockProducer,
String topic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,20 @@ public void testUnbounded() {
Assert.assertNotEquals(0, visitor.unboundedPCollections.size());
}

/**
 * Checks that {@code ReadFromKafkaDoFn} picks up the polling timeout from its source
 * descriptors: 1 second by default, and the overridden value once one is configured.
 */
@Test
public void testConstructorWithPollTimeout() {
  // With no override, the DoFn carries the default poll timeout of 1 second.
  ReadSourceDescriptors<String, String> defaultDescriptors = makeReadSourceDescriptor(consumer);
  ReadFromKafkaDoFn<String, String> defaultDoFn =
      ReadFromKafkaDoFn.create(defaultDescriptors, RECORDS);
  Assert.assertEquals(Duration.ofSeconds(1L), defaultDoFn.consumerPollingTimeout);

  // With a 5-second override, the DoFn carries the configured poll timeout.
  ReadSourceDescriptors<String, String> overriddenDescriptors =
      defaultDescriptors.withConsumerPollingTimeout(org.joda.time.Duration.standardSeconds(5L));
  ReadFromKafkaDoFn<String, String> overriddenDoFn =
      ReadFromKafkaDoFn.create(overriddenDescriptors, RECORDS);
  Assert.assertEquals(Duration.ofSeconds(5L), overriddenDoFn.consumerPollingTimeout);
}

private BoundednessVisitor testBoundedness(
Function<ReadSourceDescriptors<String, String>, ReadSourceDescriptors<String, String>>
readSourceDescriptorsDecorator) {
Expand Down
Loading