diff --git a/shotover/src/transforms/kafka/sink_cluster.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs similarity index 92% rename from shotover/src/transforms/kafka/sink_cluster.rs rename to shotover/src/transforms/kafka/sink_cluster/mod.rs index 1ac0f8fc9..7763b17d3 100644 --- a/shotover/src/transforms/kafka/sink_cluster.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -1,11 +1,8 @@ use super::common::produce_channel; -use crate::codec::{kafka::KafkaCodecBuilder, CodecBuilder, Direction}; use crate::frame::kafka::{strbytes, KafkaFrame, RequestBody, ResponseBody}; use crate::frame::Frame; use crate::message::{Message, Messages}; -use crate::tcp; use crate::tls::{TlsConnector, TlsConnectorConfig}; -use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; use crate::transforms::util::{Request, Response}; use crate::transforms::{Transform, TransformBuilder, Wrapper}; use crate::transforms::{TransformConfig, TransformContextConfig}; @@ -20,6 +17,7 @@ use kafka_protocol::messages::{ TopicName, }; use kafka_protocol::protocol::{Builder, StrBytes}; +use node::{KafkaAddress, KafkaNode}; use rand::rngs::SmallRng; use rand::seq::{IteratorRandom, SliceRandom}; use rand::SeedableRng; @@ -30,10 +28,11 @@ use std::net::SocketAddr; use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; -use tokio::io::split; use tokio::sync::{mpsc, oneshot, RwLock}; use tokio::time::timeout; +mod node; + #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct KafkaSinkClusterConfig { @@ -193,11 +192,10 @@ impl Transform for KafkaSinkCluster { .first_contact_points .iter() .map(|address| { - Ok(KafkaNode { - connection: None, - kafka_address: KafkaAddress::from_str(address)?, - broker_id: BrokerId(-1), - }) + Ok(KafkaNode::new( + BrokerId(-1), + KafkaAddress::from_str(address)?, + )) }) .collect(); self.nodes = nodes?; @@ -525,14 +523,10 @@ impl KafkaSinkCluster { Some(Frame::Kafka(KafkaFrame::Response { body: ResponseBody::FindCoordinator(coordinator), .. - })) => Ok(KafkaNode { - broker_id: coordinator.node_id, - kafka_address: KafkaAddress { - host: coordinator.host.clone(), - port: coordinator.port, - }, - connection: None, - }), + })) => Ok(KafkaNode::new( + coordinator.node_id, + KafkaAddress::new(coordinator.host.clone(), coordinator.port), + )), other => Err(anyhow!( "Unexpected message returned to findcoordinator request {other:?}" )), @@ -742,14 +736,7 @@ impl KafkaSinkCluster { async fn process_metadata(&mut self, metadata: &MetadataResponse) { for (id, broker) in &metadata.brokers { - let node = KafkaNode { - broker_id: *id, - kafka_address: KafkaAddress { - host: broker.host.clone(), - port: broker.port, - }, - connection: None, - }; + let node = KafkaNode::new(*id, KafkaAddress::new(broker.host.clone(), broker.port)); self.add_node_if_new(node).await; } @@ -884,39 +871,6 @@ fn deduplicate_coordinators(coordinators: &mut Vec) { } } -#[derive(Clone)] -struct KafkaNode { - broker_id: BrokerId, - kafka_address: KafkaAddress, - connection: Option, -} - -impl KafkaNode { - async fn get_connection( - &mut self, - connect_timeout: Duration, - tls: &Option, - ) -> Result<&Connection> { - if self.connection.is_none() { - let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned()); - let address = ( - self.kafka_address.host.to_string(), - self.kafka_address.port as u16, - ); - if let Some(tls) = tls.as_ref() { - let tls_stream = tls.connect(connect_timeout, address).await?; - let (rx, tx) = split(tls_stream); - self.connection = Some(spawn_read_write_tasks(&codec, rx, tx)); - } else { - let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?; - let (rx, tx) = tcp_stream.into_split(); - self.connection = Some(spawn_read_write_tasks(&codec, rx, tx)); - } - } - Ok(self.connection.as_ref().unwrap()) - } -} - #[derive(Debug)] struct Topic { partitions: Vec, @@ -928,30 +882,6 @@ struct Partition { replica_nodes: Vec, } -#[derive(Clone, PartialEq)] -struct KafkaAddress { - host: StrBytes, - port: i32, -} - -impl KafkaAddress { - fn from_str(address: &str) -> Result { - let mut address_iter = address.split(':'); - Ok(KafkaAddress { - host: strbytes( - address_iter - .next() - .ok_or_else(|| anyhow!("Address must include ':' seperator"))?, - ), - port: address_iter - .next() - .ok_or_else(|| anyhow!("Address must include port after ':'"))? - .parse() - .map_err(|_| anyhow!("Failed to parse address port as integer"))?, - }) - } -} - struct FindCoordinator { index: usize, key: StrBytes, diff --git a/shotover/src/transforms/kafka/sink_cluster/node.rs b/shotover/src/transforms/kafka/sink_cluster/node.rs new file mode 100644 index 000000000..037f78e73 --- /dev/null +++ b/shotover/src/transforms/kafka/sink_cluster/node.rs @@ -0,0 +1,79 @@ +use crate::codec::{kafka::KafkaCodecBuilder, CodecBuilder, Direction}; +use crate::frame::kafka::strbytes; +use crate::tcp; +use crate::tls::TlsConnector; +use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; +use anyhow::{anyhow, Result}; +use kafka_protocol::messages::BrokerId; +use kafka_protocol::protocol::StrBytes; +use std::time::Duration; +use tokio::io::split; + +#[derive(Clone, PartialEq)] +pub struct KafkaAddress { + pub host: StrBytes, + pub port: i32, +} + +impl KafkaAddress { + pub fn new(host: StrBytes, port: i32) -> Self { + KafkaAddress { host, port } + } + + pub fn from_str(address: &str) -> Result { + let mut address_iter = address.split(':'); + Ok(KafkaAddress { + host: strbytes( + address_iter + .next() + .ok_or_else(|| anyhow!("Address must include ':' seperator"))?, + ), + port: address_iter + .next() + .ok_or_else(|| anyhow!("Address must include port after ':'"))? + .parse() + .map_err(|_| anyhow!("Failed to parse address port as integer"))?, + }) + } +} + +#[derive(Clone)] +pub struct KafkaNode { + pub broker_id: BrokerId, + pub kafka_address: KafkaAddress, + connection: Option, +} + +impl KafkaNode { + pub fn new(broker_id: BrokerId, kafka_address: KafkaAddress) -> Self { + KafkaNode { + broker_id, + kafka_address, + connection: None, + } + } + + pub async fn get_connection( + &mut self, + connect_timeout: Duration, + tls: &Option, + ) -> Result<&Connection> { + if self.connection.is_none() { + let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned()); + let address = ( + self.kafka_address.host.to_string(), + self.kafka_address.port as u16, + ); + if let Some(tls) = tls.as_ref() { + let tls_stream = tls.connect(connect_timeout, address).await?; + let (rx, tx) = split(tls_stream); + self.connection = Some(spawn_read_write_tasks(&codec, rx, tx)); + } else { + let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?; + let (rx, tx) = tcp_stream.into_split(); + self.connection = Some(spawn_read_write_tasks(&codec, rx, tx)); + } + } + Ok(self.connection.as_ref().unwrap()) + } +}