Skip to content

Commit

Permalink
Merge branch 'main' into disable_encode_tests_when_alpha_transform_di…
Browse files Browse the repository at this point in the history
…sabled
  • Loading branch information
rukai authored Mar 5, 2024
2 parents ced8e46 + b81a092 commit 175dd9e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 39 deletions.
38 changes: 14 additions & 24 deletions shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use kafka_protocol::messages::{
TopicName,
};
use kafka_protocol::protocol::{Builder, StrBytes};
use node::{KafkaAddress, KafkaNode};
use node::{ConnectionFactory, KafkaAddress, KafkaNode};
use rand::rngs::SmallRng;
use rand::seq::{IteratorRandom, SliceRandom};
use rand::SeedableRng;
Expand Down Expand Up @@ -118,15 +118,14 @@ impl TransformBuilder for KafkaSinkClusterBuilder {
first_contact_points: self.first_contact_points.clone(),
shotover_nodes: self.shotover_nodes.clone(),
pushed_messages_tx: None,
connect_timeout: self.connect_timeout,
read_timeout: self.read_timeout,
nodes: vec![],
nodes_shared: self.nodes_shared.clone(),
controller_broker: self.controller_broker.clone(),
group_to_coordinator_broker: self.group_to_coordinator_broker.clone(),
topics: self.topics.clone(),
rng: SmallRng::from_rng(rand::thread_rng()).unwrap(),
tls: self.tls.clone(),
connection_factory: ConnectionFactory::new(self.tls.clone(), self.connect_timeout),
})
}

Expand Down Expand Up @@ -165,15 +164,14 @@ pub struct KafkaSinkCluster {
first_contact_points: Vec<String>,
shotover_nodes: Vec<KafkaAddress>,
pushed_messages_tx: Option<mpsc::UnboundedSender<Messages>>,
connect_timeout: Duration,
read_timeout: Option<Duration>,
nodes: Vec<KafkaNode>,
nodes_shared: Arc<RwLock<Vec<KafkaNode>>>,
controller_broker: Arc<AtomicBrokerId>,
group_to_coordinator_broker: Arc<DashMap<GroupId, BrokerId>>,
topics: Arc<DashMap<TopicName, Topic>>,
rng: SmallRng,
tls: Option<TlsConnector>,
connection_factory: ConnectionFactory,
}

#[async_trait]
Expand Down Expand Up @@ -345,9 +343,7 @@ impl KafkaSinkCluster {
for node in &mut self.nodes {
if node.broker_id == partition.leader_id {
connection = Some(
node.get_connection(self.connect_timeout, &self.tls)
.await?
.clone(),
node.get_connection(&self.connection_factory).await?.clone(),
);
}
}
Expand All @@ -359,7 +355,7 @@ impl KafkaSinkCluster {
self.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?
.clone()
}
Expand Down Expand Up @@ -399,7 +395,7 @@ impl KafkaSinkCluster {
.filter(|node| partition.replica_nodes.contains(&node.broker_id))
.choose(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?
.clone()
} else {
Expand All @@ -408,7 +404,7 @@ impl KafkaSinkCluster {
self.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?
.clone()
};
Expand Down Expand Up @@ -472,7 +468,7 @@ impl KafkaSinkCluster {
.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?;
let (tx, rx) = oneshot::channel();
connection
Expand Down Expand Up @@ -509,7 +505,7 @@ impl KafkaSinkCluster {
.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?;
let (tx, rx) = oneshot::channel();
connection
Expand Down Expand Up @@ -563,7 +559,7 @@ impl KafkaSinkCluster {
.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?;
let (tx, rx) = oneshot::channel();
connection
Expand Down Expand Up @@ -672,15 +668,13 @@ impl KafkaSinkCluster {
let connection = if let Some(node) =
self.nodes.iter_mut().find(|x| x.broker_id == *broker_id)
{
node.get_connection(self.connect_timeout, &self.tls)
.await?
.clone()
node.get_connection(&self.connection_factory).await?.clone()
} else {
tracing::warn!("no known broker with id {broker_id:?}, routing message to a random node so that a NOT_CONTROLLER or similar error is returned to the client");
self.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?
.clone()
};
Expand All @@ -704,11 +698,7 @@ impl KafkaSinkCluster {
for node in &mut self.nodes {
if let Some(broker_id) = self.group_to_coordinator_broker.get(&group_id) {
if node.broker_id == *broker_id {
connection = Some(
node.get_connection(self.connect_timeout, &self.tls)
.await?
.clone(),
);
connection = Some(node.get_connection(&self.connection_factory).await?.clone());
}
}
}
Expand All @@ -719,7 +709,7 @@ impl KafkaSinkCluster {
self.nodes
.choose_mut(&mut self.rng)
.unwrap()
.get_connection(self.connect_timeout, &self.tls)
.get_connection(&self.connection_factory)
.await?
.clone()
}
Expand Down
50 changes: 35 additions & 15 deletions shotover/src/transforms/kafka/sink_cluster/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,36 @@ use kafka_protocol::protocol::StrBytes;
use std::time::Duration;
use tokio::io::split;

pub struct ConnectionFactory {
tls: Option<TlsConnector>,
connect_timeout: Duration,
}

impl ConnectionFactory {
pub fn new(tls: Option<TlsConnector>, connect_timeout: Duration) -> Self {
ConnectionFactory {
tls,
connect_timeout,
}
}

pub async fn create_connection(&self, kafka_address: &KafkaAddress) -> Result<Connection> {
let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned());
let address = (kafka_address.host.to_string(), kafka_address.port as u16);
if let Some(tls) = self.tls.as_ref() {
let tls_stream = tls.connect(self.connect_timeout, address).await?;
let (rx, tx) = split(tls_stream);
let connection = spawn_read_write_tasks(&codec, rx, tx);
Ok(connection)
} else {
let tcp_stream = tcp::tcp_stream(self.connect_timeout, address).await?;
let (rx, tx) = tcp_stream.into_split();
let connection = spawn_read_write_tasks(&codec, rx, tx);
Ok(connection)
}
}
}

#[derive(Clone, PartialEq)]
pub struct KafkaAddress {
pub host: StrBytes,
Expand Down Expand Up @@ -55,24 +85,14 @@ impl KafkaNode {

pub async fn get_connection(
&mut self,
connect_timeout: Duration,
tls: &Option<TlsConnector>,
connection_factory: &ConnectionFactory,
) -> Result<&Connection> {
if self.connection.is_none() {
let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkCluster".to_owned());
let address = (
self.kafka_address.host.to_string(),
self.kafka_address.port as u16,
self.connection = Some(
connection_factory
.create_connection(&self.kafka_address)
.await?,
);
if let Some(tls) = tls.as_ref() {
let tls_stream = tls.connect(connect_timeout, address).await?;
let (rx, tx) = split(tls_stream);
self.connection = Some(spawn_read_write_tasks(&codec, rx, tx));
} else {
let tcp_stream = tcp::tcp_stream(connect_timeout, address).await?;
let (rx, tx) = tcp_stream.into_split();
self.connection = Some(spawn_read_write_tasks(&codec, rx, tx));
}
}
Ok(self.connection.as_ref().unwrap())
}
Expand Down

0 comments on commit 175dd9e

Please sign in to comment.