From aa0d8bba6da994a0b0d03b7de6b9ceb974ab720a Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Wed, 2 Oct 2024 12:03:46 +1000 Subject: [PATCH] KafkaSinkCluster transaction request routing --- .../tests/kafka_int_tests/test_cases.rs | 117 ++++++++ .../src/transforms/kafka/sink_cluster/mod.rs | 268 ++++++++++++++++-- test-helpers/src/connection/kafka/cpp.rs | 38 ++- test-helpers/src/connection/kafka/java.rs | 49 ++++ test-helpers/src/connection/kafka/mod.rs | 50 ++++ 5 files changed, 501 insertions(+), 21 deletions(-) diff --git a/shotover-proxy/tests/kafka_int_tests/test_cases.rs b/shotover-proxy/tests/kafka_int_tests/test_cases.rs index 42414fda4..3562b9ebf 100644 --- a/shotover-proxy/tests/kafka_int_tests/test_cases.rs +++ b/shotover-proxy/tests/kafka_int_tests/test_cases.rs @@ -74,6 +74,16 @@ async fn admin_setup(connection_builder: &KafkaConnectionBuilder) { num_partitions: 3, replication_factor: 1, }, + NewTopic { + name: "transaction_topic1", + num_partitions: 3, + replication_factor: 1, + }, + NewTopic { + name: "transaction_topic2", + num_partitions: 3, + replication_factor: 1, + }, ]) .await; @@ -735,6 +745,111 @@ async fn produce_consume_partitions3( } } +async fn produce_consume_transactions(connection_builder: &KafkaConnectionBuilder) { + let producer = connection_builder.connect_producer("1", 0).await; + for i in 0..5 { + producer + .assert_produce( + Record { + payload: &format!("Message1_{i}"), + topic_name: "transaction_topic1", + key: Some("Key".into()), + }, + // We cant predict the offsets since that will depend on which partition the keyless record ends up in + None, + ) + .await; + producer + .assert_produce( + Record { + payload: &format!("Message2_{i}"), + topic_name: "transaction_topic1", + key: Some("Key".into()), + }, + None, + ) + .await; + } + + let producer = connection_builder + .connect_producer_with_transactions("some_transaction_id".to_owned()) + .await; + let mut consumer_topic1 = connection_builder + .connect_consumer( + ConsumerConfig::consume_from_topic("transaction_topic1".to_owned()) + .with_group("some_group"), + ) + .await; + let mut consumer_topic2 = connection_builder + .connect_consumer( + ConsumerConfig::consume_from_topic("transaction_topic2".to_owned()) + .with_group("some_group"), + ) + .await; + + for i in 0..5 { + consumer_topic1 + .assert_consume_in_any_order(vec![ + ExpectedResponse { + message: format!("Message1_{i}"), + key: Some("Key".to_owned()), + topic_name: "transaction_topic1".to_owned(), + offset: None, + }, + ExpectedResponse { + message: format!("Message2_{i}"), + key: Some("Key".into()), + topic_name: "transaction_topic1".to_owned(), + offset: None, + }, + ]) + .await; + producer.begin_transaction(); + + producer + .assert_produce( + Record { + payload: &format!("Message1_{i}"), + topic_name: "transaction_topic2", + key: Some("Key".into()), + }, + // We cant predict the offsets since that will depend on which partition the keyless record ends up in + None, + ) + .await; + producer + .assert_produce( + Record { + payload: &format!("Message2_{i}"), + topic_name: "transaction_topic2", + key: Some("Key".into()), + }, + None, + ) + .await; + + producer.send_offsets_to_transaction(&consumer_topic1); + producer.commit_transaction(); + + consumer_topic2 + .assert_consume_in_any_order(vec![ + ExpectedResponse { + message: format!("Message1_{i}"), + key: Some("Key".to_owned()), + topic_name: "transaction_topic2".to_owned(), + offset: None, + }, + ExpectedResponse { + message: format!("Message2_{i}"), + key: Some("Key".to_owned()), + topic_name: "transaction_topic2".to_owned(), + offset: None, + }, + ]) + .await; + } +} + async fn produce_consume_acks0(connection_builder: &KafkaConnectionBuilder) { let topic_name = "acks0"; let producer = connection_builder.connect_producer("0", 0).await; @@ -844,6 +959,8 @@ pub async fn standard_test_suite(connection_builder: &KafkaConnectionBuilder) { // set the bytes limit to 1MB so that we will not reach it and will hit the 100ms timeout every time. produce_consume_partitions3(connection_builder, "partitions3_case4", 1_000_000, 100).await; + produce_consume_transactions(connection_builder).await; + // Only run this test case on the java driver, // since even without going through shotover the cpp driver fails this test. #[allow(irrefutable_let_patterns)] diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index 4ba41a4ab..39ac7c8d5 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -20,10 +20,11 @@ use kafka_protocol::messages::metadata_response::MetadataResponseBroker; use kafka_protocol::messages::produce_request::TopicProduceData; use kafka_protocol::messages::produce_response::LeaderIdAndEpoch as ProduceResponseLeaderIdAndEpoch; use kafka_protocol::messages::{ - ApiKey, BrokerId, FetchRequest, FetchResponse, FindCoordinatorRequest, FindCoordinatorResponse, - GroupId, HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, MetadataRequest, - MetadataResponse, ProduceRequest, ProduceResponse, RequestHeader, SaslAuthenticateRequest, - SaslAuthenticateResponse, SaslHandshakeRequest, SyncGroupRequest, TopicName, + ApiKey, BrokerId, EndTxnRequest, FetchRequest, FetchResponse, FindCoordinatorRequest, + FindCoordinatorResponse, GroupId, HeartbeatRequest, InitProducerIdRequest, JoinGroupRequest, + LeaveGroupRequest, MetadataRequest, MetadataResponse, ProduceRequest, ProduceResponse, + RequestHeader, SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest, + SyncGroupRequest, TopicName, TransactionalId, }; use kafka_protocol::protocol::StrBytes; use kafka_protocol::ResponseError; @@ -138,6 +139,7 @@ struct KafkaSinkClusterBuilder { read_timeout: Option, controller_broker: Arc, group_to_coordinator_broker: Arc>, + transaction_to_coordinator_broker: Arc>, topic_by_name: Arc>, topic_by_id: Arc>, nodes_shared: Arc>>, @@ -173,6 +175,7 @@ impl KafkaSinkClusterBuilder { read_timeout, controller_broker: Arc::new(AtomicBrokerId::new()), group_to_coordinator_broker: Arc::new(DashMap::new()), + transaction_to_coordinator_broker: Arc::new(DashMap::new()), topic_by_name: Arc::new(DashMap::new()), topic_by_id: Arc::new(DashMap::new()), nodes_shared: Arc::new(RwLock::new(vec![])), @@ -193,6 +196,7 @@ impl TransformBuilder for KafkaSinkClusterBuilder { nodes_shared: self.nodes_shared.clone(), controller_broker: self.controller_broker.clone(), group_to_coordinator_broker: self.group_to_coordinator_broker.clone(), + transaction_to_coordinator_broker: self.transaction_to_coordinator_broker.clone(), topic_by_name: self.topic_by_name.clone(), topic_by_id: self.topic_by_id.clone(), rng: SmallRng::from_rng(rand::thread_rng()).unwrap(), @@ -254,6 +258,7 @@ struct KafkaSinkCluster { nodes_shared: Arc>>, controller_broker: Arc, group_to_coordinator_broker: Arc>, + transaction_to_coordinator_broker: Arc>, topic_by_name: Arc>, topic_by_id: Arc>, rng: SmallRng, @@ -321,6 +326,8 @@ enum PendingRequestTy { FindCoordinator(FindCoordinator), // Covers multiple request types: JoinGroup, DeleteGroups etc. RoutedToGroup(GroupId), + // Covers multiple request types: InitProducerId, EndTxn etc. + RoutedToTransaction(TransactionalId), Other, } @@ -502,6 +509,28 @@ impl KafkaSinkCluster { } } + fn store_transaction( + &self, + transactions: &mut Vec, + transaction: TransactionalId, + ) { + let cache_is_missing_or_outdated = + match self.transaction_to_coordinator_broker.get(&transaction) { + Some(broker_id) => self + .nodes + .iter() + .find(|node| node.broker_id == *broker_id) + .map(|node| !node.is_up()) + .unwrap_or(true), + None => true, + }; + + if cache_is_missing_or_outdated && !transactions.contains(&transaction) { + debug_assert!(transaction.0.as_str() != ""); + transactions.push(transaction); + } + } + async fn update_local_nodes(&mut self) { self.nodes.clone_from(&*self.nodes_shared.read().await); } @@ -587,6 +616,7 @@ impl KafkaSinkCluster { let mut topic_names = vec![]; let mut topic_ids = vec![]; let mut groups = vec![]; + let mut transactions = vec![]; for request in &mut requests { match request.frame() { Some(Frame::Kafka(KafkaFrame::Request { @@ -619,6 +649,43 @@ impl KafkaSinkCluster { })) => { self.store_group(&mut groups, group_id.clone()); } + Some(Frame::Kafka(KafkaFrame::Request { + body: + // TODO: only keep the ones we actually to route for + // RequestBody::TxnOffsetCommit(TxnOffsetCommitRequest { + // transactional_id, .. + // })| + RequestBody::InitProducerId(InitProducerIdRequest { + transactional_id: Some(transactional_id), .. + }) + | RequestBody::EndTxn(EndTxnRequest { + transactional_id, .. + }), + // | RequestBody::AddOffsetsToTxn(AddOffsetsToTxnRequest { + // transactional_id, .. + // }), + .. + })) => { + self.store_transaction(&mut transactions, transactional_id.clone()); + } + Some(Frame::Kafka(KafkaFrame::Request { + body: + RequestBody::AddPartitionsToTxn(add_partitions_to_txn_request) + , + header, + })) => { + if header.request_api_version <= 3 { + self.store_transaction( + &mut transactions, + add_partitions_to_txn_request.v3_and_below_transactional_id.clone() + ); + } + else { + for transaction in add_partitions_to_txn_request.transactions.keys() { + self.store_transaction(&mut transactions, transaction.clone()); + } + } + } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::OffsetFetch(offset_fetch), header, @@ -636,7 +703,10 @@ impl KafkaSinkCluster { } for group in groups { - match self.find_coordinator_of_group(group.clone()).await { + match self + .find_coordinator(CoordinatorKey::Group(group.clone())) + .await + { Ok(node) => { tracing::debug!( "Storing group_to_coordinator_broker metadata, group {:?} -> broker {}", @@ -657,6 +727,31 @@ impl KafkaSinkCluster { } } + for transaction in transactions { + match self + .find_coordinator(CoordinatorKey::Transaction(transaction.clone())) + .await + { + Ok(node) => { + tracing::debug!( + "Storing transaction_to_coordinator_broker metadata, group {:?} -> broker {}", + transaction.0, + node.broker_id.0 + ); + self.transaction_to_coordinator_broker + .insert(transaction, node.broker_id); + self.add_node_if_new(node).await; + } + Err(FindCoordinatorError::CoordinatorNotAvailable) => { + // We cant find the coordinator so do nothing so that the request will be routed to a random node: + // * If it happens to be the coordinator all is well + // * If its not the coordinator then it will return a NOT_COORDINATOR message to + // the client prompting it to retry the whole process again. + } + Err(FindCoordinatorError::Unrecoverable(err)) => Err(err)?, + } + } + // request and process metadata if we are missing topics or the controller broker id if !topic_names.is_empty() || !topic_ids.is_empty() @@ -714,14 +809,55 @@ impl KafkaSinkCluster { .. })) => { let group_id = heartbeat.group_id.clone(); - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); + } + Some(Frame::Kafka(KafkaFrame::Request { + body: RequestBody::AddPartitionsToTxn(add_partitions_to_txn), + header, + })) => { + if header.request_api_version <= 3 { + let transaction_id = + add_partitions_to_txn.v3_and_below_transactional_id.clone(); + self.route_to_transaction_coordinator(message, transaction_id); + } else { + #[allow(clippy::never_loop)] + for transaction_id in add_partitions_to_txn.transactions.keys() { + let transaction_id = transaction_id.clone(); + self.route_to_transaction_coordinator(message, transaction_id); + break; + } + } + } + Some(Frame::Kafka(KafkaFrame::Request { + body: RequestBody::EndTxn(end_txn), + .. + })) => { + let transaction_id = end_txn.transactional_id.clone(); + self.route_to_transaction_coordinator(message, transaction_id); + } + Some(Frame::Kafka(KafkaFrame::Request { + body: RequestBody::InitProducerId(init_producer_id), + .. + })) => { + if let Some(transaction_id) = init_producer_id.transactional_id.clone() { + self.route_to_transaction_coordinator(message, transaction_id); + } else { + // TODO: dedupe + let destination = random_broker_id(&self.nodes, &mut self.rng); + tracing::debug!("Routing request to random broker {}", destination.0); + self.pending_requests.push_back(PendingRequest { + state: PendingRequestState::routed(destination, message), + ty: PendingRequestTy::Other, + combine_responses: 1, + }); + } } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::SyncGroup(sync_group), .. })) => { let group_id = sync_group.group_id.clone(); - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::OffsetFetch(offset_fetch), @@ -738,28 +874,28 @@ impl KafkaSinkCluster { // For now just pick the first group as that is sufficient for the simple cases. offset_fetch.groups.first().unwrap().group_id.clone() }; - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::OffsetCommit(offset_commit), .. })) => { let group_id = offset_commit.group_id.clone(); - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::JoinGroup(join_group), .. })) => { let group_id = join_group.group_id.clone(); - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::LeaveGroup(leave_group), .. })) => { let group_id = leave_group.group_id.clone(); - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::DeleteGroups(groups), @@ -767,7 +903,7 @@ impl KafkaSinkCluster { })) => { // TODO: we need to split this up into multiple requests so it can be correctly routed to all possible nodes let group_id = groups.groups_names.first().unwrap().clone(); - self.route_to_coordinator(message, group_id); + self.route_to_group_coordinator(message, group_id); } Some(Frame::Kafka(KafkaFrame::Request { body: RequestBody::FindCoordinator(_), @@ -1101,9 +1237,9 @@ impl KafkaSinkCluster { Ok(()) } - async fn find_coordinator_of_group( + async fn find_coordinator( &mut self, - group: GroupId, + key: CoordinatorKey, ) -> Result { let request = Message::from_frame(Frame::Kafka(KafkaFrame::Request { header: RequestHeader::default() @@ -1112,15 +1248,21 @@ impl KafkaSinkCluster { .with_correlation_id(0), body: RequestBody::FindCoordinator( FindCoordinatorRequest::default() - .with_key_type(0) - .with_key(group.0.clone()), + .with_key_type(match key { + CoordinatorKey::Group(_) => 0, + CoordinatorKey::Transaction(_) => 1, + }) + .with_key(match &key { + CoordinatorKey::Group(id) => id.0.clone(), + CoordinatorKey::Transaction(id) => id.0.clone(), + }), ), })); let mut response = self .control_send_receive(request) .await - .with_context(|| format!("Failed to query for coordinator of group {:?}", group.0))?; + .with_context(|| format!("Failed to query for coordinator of {key:?}"))?; match response.frame() { Some(Frame::Kafka(KafkaFrame::Response { body: ResponseBody::FindCoordinator(coordinator), @@ -1218,6 +1360,7 @@ impl KafkaSinkCluster { } } PendingRequestTy::RoutedToGroup(_) => None, + PendingRequestTy::RoutedToTransaction(_) => None, PendingRequestTy::FindCoordinator(_) => None, PendingRequestTy::Other => None, }; @@ -1767,6 +1910,40 @@ impl KafkaSinkCluster { body: ResponseBody::LeaveGroup(leave_group), .. })) => self.handle_group_coordinator_routing_error(&request_ty, leave_group.error_code), + Some(Frame::Kafka(KafkaFrame::Response { + body: ResponseBody::EndTxn(end_txn), + .. + })) => { + self.handle_transaction_coordinator_routing_error(&request_ty, end_txn.error_code) + } + Some(Frame::Kafka(KafkaFrame::Response { + body: ResponseBody::InitProducerId(init_producer_id), + .. + })) => self.handle_transaction_coordinator_routing_error( + &request_ty, + init_producer_id.error_code, + ), + Some(Frame::Kafka(KafkaFrame::Response { + body: ResponseBody::AddPartitionsToTxn(response), + version, + .. + })) => { + if *version <= 3 { + for topic_result in response.results_by_topic_v3_and_below.values() { + for partition_result in topic_result.results_by_partition.values() { + self.handle_transaction_coordinator_routing_error( + &request_ty, + partition_result.partition_error_code, + ) + } + } + } else { + self.handle_transaction_coordinator_routing_error( + &request_ty, + response.error_code, + ) + } + } Some(Frame::Kafka(KafkaFrame::Response { body: ResponseBody::DeleteGroups(delete_groups), .. @@ -1821,7 +1998,7 @@ impl KafkaSinkCluster { Ok(()) } - /// This method must be called for every response to a request that was routed via `route_to_coordinator` + /// This method must be called for every response to a request that was routed via `route_to_group_coordinator` fn handle_group_coordinator_routing_error( &mut self, pending_request_ty: &PendingRequestTy, @@ -1839,6 +2016,26 @@ impl KafkaSinkCluster { } } + /// This method must be called for every response to a request that was routed via `route_to_transaction_coordinator` + fn handle_transaction_coordinator_routing_error( + &mut self, + pending_request_ty: &PendingRequestTy, + error_code: i16, + ) { + if let Some(ResponseError::NotCoordinator) = ResponseError::try_from_code(error_code) { + if let PendingRequestTy::RoutedToTransaction(transaction_id) = pending_request_ty { + let broker_id = self + .transaction_to_coordinator_broker + .remove(transaction_id); + tracing::info!( + "Response was error NOT_COORDINATOR and so cleared transaction id {:?} coordinator mapping to broker {:?}", + transaction_id, + broker_id, + ); + } + } + } + fn process_sasl_authenticate( &mut self, authenticate: &mut SaslAuthenticateResponse, @@ -1922,7 +2119,7 @@ impl KafkaSinkCluster { ); } - fn route_to_coordinator(&mut self, request: Message, group_id: GroupId) { + fn route_to_group_coordinator(&mut self, request: Message, group_id: GroupId) { let destination = self.group_to_coordinator_broker.get(&group_id); let destination = match destination { Some(destination) => *destination, @@ -1966,6 +2163,33 @@ impl KafkaSinkCluster { } } + fn route_to_transaction_coordinator( + &mut self, + request: Message, + transaction_id: TransactionalId, + ) { + let destination = self.transaction_to_coordinator_broker.get(&transaction_id); + let destination = match destination { + Some(destination) => *destination, + None => { + tracing::warn!("no known coordinator for {transaction_id:?}, routing message to a random broker so that a NOT_COORDINATOR or similar error is returned to the client"); + random_broker_id(&self.nodes, &mut self.rng) + } + }; + + tracing::debug!( + "Routing request relating to transaction id {:?} to broker {}", + transaction_id.0, + destination.0 + ); + + self.pending_requests.push_back(PendingRequest { + state: PendingRequestState::routed(destination, request), + ty: PendingRequestTy::RoutedToTransaction(transaction_id), + combine_responses: 1, + }); + } + async fn process_metadata_response(&mut self, metadata: &MetadataResponse) { for (id, broker) in &metadata.brokers { let node = KafkaNode::new( @@ -2350,6 +2574,12 @@ impl KafkaSinkCluster { } } +#[derive(Debug)] +enum CoordinatorKey { + Group(GroupId), + Transaction(TransactionalId), +} + fn hash_partition(topic_id: Uuid, partition_index: i32) -> usize { let mut hasher = xxhash_rust::xxh3::Xxh3::new(); hasher.write(topic_id.as_bytes()); diff --git a/test-helpers/src/connection/kafka/cpp.rs b/test-helpers/src/connection/kafka/cpp.rs index 47a44fe76..5a4dc3182 100644 --- a/test-helpers/src/connection/kafka/cpp.rs +++ b/test-helpers/src/connection/kafka/cpp.rs @@ -13,10 +13,10 @@ use rdkafka::admin::{ use rdkafka::client::DefaultClientContext; use rdkafka::config::ClientConfig; use rdkafka::consumer::{Consumer, StreamConsumer}; -use rdkafka::producer::{FutureProducer, FutureRecord}; +use rdkafka::producer::{FutureProducer, FutureRecord, Producer}; use rdkafka::types::RDKafkaErrorCode; use rdkafka::util::Timeout; -use rdkafka::Message; +use rdkafka::{Message, TopicPartitionList}; use std::time::Duration; pub struct KafkaConnectionBuilderCpp { @@ -63,6 +63,22 @@ impl KafkaConnectionBuilderCpp { } } + pub fn connect_producer_with_transactions(&self, transaction_id: String) -> KafkaProducerCpp { + let producer: FutureProducer = self + .client + .clone() + .set("transactional.id", transaction_id) + .set("message.timeout.ms", "5000") + .set("linger.ms", "0") + .set("acks", "all") + .create() + .unwrap(); + // If the timeout is too low we hit: Transaction error: Failed to initialize Producer ID: Broker: Coordinator load in progress + // 5s seems fine + producer.init_transactions(Duration::from_secs(5)).unwrap(); + KafkaProducerCpp { producer } + } + pub async fn connect_consumer(&self, config: ConsumerConfig) -> KafkaConsumerCpp { let consumer: StreamConsumer = self .client @@ -142,6 +158,24 @@ impl KafkaProducerCpp { assert_eq!(delivery_status.1, offset, "Unexpected offset"); } } + + pub fn begin_transaction(&self) { + self.producer.begin_transaction().unwrap(); + } + + pub fn send_offsets_to_transaction(&self, consumer: &KafkaConsumerCpp) { + let topic_partitions = TopicPartitionList::new(); + let consumer_group = consumer.consumer.group_metadata().unwrap(); + self.producer + .send_offsets_to_transaction(&topic_partitions, &consumer_group, Duration::from_secs(1)) + .unwrap(); + } + + pub fn commit_transaction(&self) { + self.producer + .commit_transaction(Duration::from_secs(1)) + .unwrap(); + } } pub struct KafkaConsumerCpp { diff --git a/test-helpers/src/connection/kafka/java.rs b/test-helpers/src/connection/kafka/java.rs index 21ae4692b..34a66c132 100644 --- a/test-helpers/src/connection/kafka/java.rs +++ b/test-helpers/src/connection/kafka/java.rs @@ -103,6 +103,35 @@ impl KafkaConnectionBuilderJava { KafkaProducerJava { jvm, producer } } + pub async fn connect_producer_with_transactions( + &self, + transaction_id: String, + ) -> KafkaProducerJava { + let mut config = self.base_config.clone(); + config.insert("acks".to_owned(), "all".to_owned()); + config.insert("linger.ms".to_owned(), "0".to_owned()); + config.insert("transactional.id".to_owned(), transaction_id); + config.insert( + "key.serializer".to_owned(), + "org.apache.kafka.common.serialization.StringSerializer".to_owned(), + ); + config.insert( + "value.serializer".to_owned(), + "org.apache.kafka.common.serialization.StringSerializer".to_owned(), + ); + + let properties = properties(&self.jvm, &config); + let producer = self.jvm.construct( + "org.apache.kafka.clients.producer.KafkaProducer", + vec![properties], + ); + + producer.call("initTransactions", vec![]); + + let jvm = self.jvm.clone(); + KafkaProducerJava { jvm, producer } + } + pub async fn connect_consumer(&self, consumer_config: ConsumerConfig) -> KafkaConsumerJava { let mut config = self.base_config.clone(); config.insert("group.id".to_owned(), consumer_config.group); @@ -192,6 +221,26 @@ impl KafkaProducerJava { assert_eq!(expected_offset, actual_offset); } } + + pub fn begin_transaction(&self) { + self.producer.call("beginTransaction", vec![]); + } + + pub fn commit_transaction(&self) { + self.producer.call("commitTransaction", vec![]); + } + + pub fn send_offsets_to_transaction(&self, consumer: &KafkaConsumerJava) { + let offsets = self.jvm.new_map(vec![]); + + let consumer_group_id = consumer + .consumer + .call("groupMetadata", vec![]) + .call("groupId", vec![]); + + self.producer + .call("sendOffsetsToTransaction", vec![offsets, consumer_group_id]); + } } pub struct KafkaConsumerJava { diff --git a/test-helpers/src/connection/kafka/mod.rs b/test-helpers/src/connection/kafka/mod.rs index aaecac4c8..531649ecc 100644 --- a/test-helpers/src/connection/kafka/mod.rs +++ b/test-helpers/src/connection/kafka/mod.rs @@ -75,6 +75,22 @@ impl KafkaConnectionBuilder { } } + pub async fn connect_producer_with_transactions( + &self, + transaction_id: String, + ) -> KafkaProducer { + match self { + #[cfg(feature = "kafka-cpp-driver-tests")] + Self::Cpp(cpp) => { + KafkaProducer::Cpp(cpp.connect_producer_with_transactions(transaction_id)) + } + Self::Java(java) => KafkaProducer::Java( + java.connect_producer_with_transactions(transaction_id) + .await, + ), + } + } + pub async fn connect_consumer(&self, config: ConsumerConfig) -> KafkaConsumer { match self { #[cfg(feature = "kafka-cpp-driver-tests")] @@ -127,6 +143,40 @@ impl KafkaProducer { Self::Java(java) => java.assert_produce(record, expected_offset).await, } } + + pub fn begin_transaction(&self) { + match self { + #[cfg(feature = "kafka-cpp-driver-tests")] + Self::Cpp(cpp) => cpp.begin_transaction(), + Self::Java(java) => java.begin_transaction(), + } + } + + pub fn commit_transaction(&self) { + match self { + #[cfg(feature = "kafka-cpp-driver-tests")] + Self::Cpp(cpp) => cpp.commit_transaction(), + Self::Java(java) => java.commit_transaction(), + } + } + + pub fn send_offsets_to_transaction(&self, consumer: &KafkaConsumer) { + match self { + #[cfg(feature = "kafka-cpp-driver-tests")] + Self::Cpp(cpp) => match consumer { + KafkaConsumer::Cpp(consumer) => cpp.send_offsets_to_transaction(consumer), + KafkaConsumer::Java(_) => { + panic!("Cannot use transactions across java and cpp driver") + } + }, + Self::Java(java) => match consumer { + KafkaConsumer::Java(consumer) => java.send_offsets_to_transaction(consumer), + KafkaConsumer::Cpp(_) => { + panic!("Cannot use transactions across java and cpp driver") + } + }, + } + } } pub struct Record<'a> {