Skip to content

Commit

Permalink
fix(citrusframework#1281): introduce separate consumers per subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
bbortt committed Dec 17, 2024
1 parent 607c527 commit 1c3db3d
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentLinkedQueue;

import static java.util.Objects.isNull;
import static java.util.Objects.nonNull;
import static java.util.UUID.randomUUID;
import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG;
import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_OFFSET_RESET_CONFIG;
Expand All @@ -44,28 +46,70 @@ public class KafkaConsumer extends AbstractSelectiveMessageConsumer {
private static final Logger logger = LoggerFactory.getLogger(KafkaConsumer.class);

private org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> consumer;
private final ConcurrentLinkedQueue<org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object>> managedConsumers = new ConcurrentLinkedQueue<>();

/**
* Default constructor using endpoint.
*/
public KafkaConsumer(String name, KafkaEndpointConfiguration endpointConfiguration) {
super(name, endpointConfiguration);
this.consumer = createConsumer();
}

/**
* Initializes and provides a new {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance in a thread-safe manner.
* This method is the preferred way to obtain a consumer instance as it ensures proper lifecycle management and thread-safety.
* <p>
* The created consumer is automatically registered for lifecycle management and cleanup.
* Each call to this method creates a new consumer instance.
*
* @return a new thread-safe {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance
*/
public org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> createManagedConsumer() {
if (nonNull(consumer)) {
return consumer;
}

var managedConsumer = createKafkaConsumer();
managedConsumers.add(managedConsumer);
return managedConsumer;
}

/**
* Returns the current {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance.
*
* @return the current {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance
* @deprecated {@link org.apache.kafka.clients.consumer.KafkaConsumer} is <b>not</b> thread-safe and manual consumer management is error-prone.
* Use {@link #createManagedConsumer()} instead to obtain properly managed consumer instances.
* This method will be removed in a future release.
*/
@Deprecated(forRemoval = true)
public org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> getConsumer() {
if (isNull(consumer)) {
consumer = createKafkaConsumer();
}

return consumer;
}

/**
* Sets the {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance.
*
* @param consumer the KafkaConsumer to set
* @deprecated {@link org.apache.kafka.clients.consumer.KafkaConsumer} is <b>not</b> thread-safe and manual consumer management is error-prone.
* Use {@link #createManagedConsumer()} instead to obtain properly managed consumer instances.
* This method will be removed in a future release.
*/
@Deprecated(forRemoval = true)
public void setConsumer(org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> consumer) {
this.consumer = consumer;
this.managedConsumers.add(consumer);
}

@Override
public Message receive(TestContext testContext, long timeout) {
logger.debug("Receiving single message");
return KafkaMessageSingleConsumer.builder()
.consumer(consumer)
.consumer(createManagedConsumer())
.endpointConfiguration(getEndpointConfiguration())
.build()
.receive(testContext, timeout);
Expand All @@ -75,7 +119,7 @@ public Message receive(TestContext testContext, long timeout) {
public Message receive(String selector, TestContext testContext, long timeout) {
logger.debug("Receiving selected message: {}", selector);
return KafkaMessageFilteringConsumer.builder()
.consumer(consumer)
.consumer(createManagedConsumer())
.endpointConfiguration(getEndpointConfiguration())
.build()
.receive(selector, testContext, timeout);
Expand All @@ -90,19 +134,20 @@ protected KafkaEndpointConfiguration getEndpointConfiguration() {
* Stop message listener container.
*/
public void stop() {
try {
if (consumer.subscription() != null && !consumer.subscription().isEmpty()) {
consumer.unsubscribe();
org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> consumerToDelete;
while (nonNull(consumerToDelete = managedConsumers.poll())) {
try {
consumerToDelete.unsubscribe();
} finally {
consumerToDelete.close();
}
} finally {
consumer.close(Duration.ofMillis(10 * 1000L));
}
}

/**
* Create new Kafka consumer with given endpoint configuration.
*/
private org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> createConsumer() {
private org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> createKafkaConsumer() {
Map<String, Object> consumerProps = new HashMap<>();
consumerProps.put(CLIENT_ID_CONFIG, Optional.ofNullable(getEndpointConfiguration().getClientId()).orElseGet(() -> KAFKA_PREFIX + "consumer_" + randomUUID()));
consumerProps.put(GROUP_ID_CONFIG, getEndpointConfiguration().getConsumerGroup());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public SimpleKafkaEndpointBuilder topic(String topic) {
}

public KafkaEndpoint build() {
return KafkaEndpoint.newKafkaEndpoint(kafkaConsumer, kafkaProducer, randomConsumerGroup, server, timeout, topic);
return newKafkaEndpoint(kafkaConsumer, kafkaProducer, randomConsumerGroup, server, timeout, topic);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
Expand All @@ -51,15 +55,15 @@ public class KafkaConsumerTest extends AbstractTestNGUnitTest {
private final org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> kafkaConsumerMock = mock(KafkaConsumer.class);

@Test
public void testReceiveMessage() {
public void receiveMessage() {
String topic = "default";

KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(kafkaConsumerMock)
.topic(topic)
.build();

TopicPartition partition = new TopicPartition(topic, 0);
var partition = new TopicPartition(topic, 0);

reset(kafkaConsumerMock);

Expand All @@ -83,15 +87,15 @@ public void testReceiveMessage() {
}

@Test
public void testReceiveMessage_inRandomConsumerGroup() {
public void receiveMessage_inRandomConsumerGroup() {
String topic = "default";

KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(kafkaConsumerMock)
.topic(topic)
.build();

TopicPartition partition = new TopicPartition(topic, 0);
var partition = new TopicPartition(topic, 0);

reset(kafkaConsumerMock);

Expand All @@ -115,7 +119,7 @@ public void testReceiveMessage_inRandomConsumerGroup() {
}

@Test
public void testReceiveMessageTimeout() {
public void receiveMessage_runIntoTimeout() {
String topic = "test";

KafkaEndpoint endpoint = KafkaEndpoint.builder()
Expand All @@ -140,7 +144,7 @@ public void testReceiveMessageTimeout() {
}

@Test
public void testWithCustomTimeout() {
public void receiveMessage_customTimeout_runIntoTimeout() {
String topic = "timeout";

KafkaEndpoint endpoint = KafkaEndpoint.builder()
Expand All @@ -149,7 +153,7 @@ public void testWithCustomTimeout() {
.topic(topic)
.build();

TopicPartition partition = new TopicPartition(topic, 0);
var partition = new TopicPartition(topic, 0);

reset(kafkaConsumerMock);
when(kafkaConsumerMock.subscription()).thenReturn(singleton(topic));
Expand All @@ -165,7 +169,7 @@ public void testWithCustomTimeout() {
}

@Test
public void testWithMessageHeaders() {
public void receiveMessage_withMessageHeaders() {
String topic = "headers";

KafkaEndpoint endpoint = KafkaEndpoint.builder()
Expand All @@ -174,7 +178,7 @@ public void testWithMessageHeaders() {
.topic(topic)
.build();

TopicPartition partition = new TopicPartition(topic, 0);
var partition = new TopicPartition(topic, 0);

reset(kafkaConsumerMock);
when(kafkaConsumerMock.subscription()).thenReturn(singleton(topic));
Expand All @@ -193,4 +197,87 @@ public void testWithMessageHeaders() {
assertNotNull(receivedMessage.getHeader("Operation"));
assertEquals(receivedMessage.getHeader("Operation"), "sayHello");
}

@Test
public void getConsumer_returnsSetConsumer() {
var kafkaConsumerMock = mock(KafkaConsumer.class);
KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(kafkaConsumerMock)
.build();

var result = endpoint.createConsumer().getConsumer();
assertThat(result)
.isEqualTo(kafkaConsumerMock);
}

@Test
public void getConsumer_createsConsumerIfNonSet() {
KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(null) // null for explicity
.build();

var result = endpoint.createConsumer().getConsumer();
assertThat(result)
.isNotNull();
}

@Test
public void createManagedConsumer_createsDifferentManagedConsumers() {
KafkaEndpoint endpoint = KafkaEndpoint.builder()
.build();

var managedConsumer1 = endpoint.createConsumer().createManagedConsumer();
assertThat(managedConsumer1)
.isNotNull();

var managedConsumer2 = endpoint.createConsumer().createManagedConsumer();

assertThat(managedConsumer2)
.isNotNull()
.isNotEqualTo(managedConsumer1)
.isNotSameAs(managedConsumer1);
}

@Test
@SuppressWarnings({"unchecked"})
public void createManagedConsumer_returnsConsumerIfOneIsSet() {
var kafkaConsumerMock = mock(KafkaConsumer.class);
KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(kafkaConsumerMock)
.build();

var managedConsumer = endpoint.createConsumer().createManagedConsumer();
assertThat(managedConsumer)
.isEqualTo(kafkaConsumerMock);
}

@Test
@SuppressWarnings({"unchecked"})
public void stop_unsubscribesAndClosesConsumer() {
var kafkaConsumerMock = mock(KafkaConsumer.class);
KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(kafkaConsumerMock)
.build();

endpoint.createConsumer().stop();
verify(kafkaConsumerMock).unsubscribe();
verify(kafkaConsumerMock).close();
}

@Test
@SuppressWarnings({"unchecked"})
public void stop_closesConsumerEvenAfterUnsubscriptionError() {
var kafkaConsumerMock = mock(KafkaConsumer.class);
var unsubscribeException = new RuntimeException();
doThrow(unsubscribeException).when(kafkaConsumerMock).unsubscribe();

KafkaEndpoint endpoint = KafkaEndpoint.builder()
.kafkaConsumer(kafkaConsumerMock)
.build();

assertThatThrownBy(() -> endpoint.createConsumer().stop())
.isEqualTo(unsubscribeException);

verify(kafkaConsumerMock).close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,14 @@ public void newKafkaEndpoint_isAbleToCreateRandomConsumerGroup() {
.startsWith(KAFKA_PREFIX)
.hasSize(23)
.containsPattern(".*[a-z]{10}$")
// Make sure the random group id is propagated to new consumers
.satisfies(
// Additionally make sure that gets passed downstream
groupId -> assertThat(fixture.createConsumer().getConsumer())
.extracting("delegate")
.extracting("groupId")
.asInstanceOf(OPTIONAL)
.hasValue(groupId),
groupId -> assertThat(fixture.createConsumer().createManagedConsumer())
.extracting("delegate")
.extracting("groupId")
.asInstanceOf(OPTIONAL)
Expand Down

0 comments on commit 1c3db3d

Please sign in to comment.