Skip to content

Commit

Permalink
RedisSinkCluster: remove dependency on redis during startup (#1333)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Sep 19, 2023
1 parent 2ab0748 commit ffe6821
Showing 1 changed file with 126 additions and 69 deletions.
195 changes: 126 additions & 69 deletions shotover/src/transforms/redis/sink_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ use rand::SeedableRng;
use redis_protocol::types::Redirection;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::oneshot;
use tokio::sync::{oneshot, RwLock};
use tokio::time::{timeout, Duration};
use tracing::{debug, error, info, trace, warn};
use tracing::{debug, error, trace, warn};

const SLOT_SIZE: usize = 16384;

Expand All @@ -50,35 +51,73 @@ pub struct RedisSinkClusterConfig {
#[async_trait(?Send)]
impl TransformConfig for RedisSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
let mut cluster = RedisSinkCluster::new(
self.first_contact_points.clone(),
self.direct_destination.clone(),
self.connection_count.unwrap_or(1),
let connection_pool = ConnectionPool::new_with_auth(
Duration::from_millis(self.connect_timeout_ms),
RedisCodecBuilder::new(Direction::Sink),
RedisAuthenticator {},
self.tls.clone(),
self.connect_timeout_ms,
chain_name,
)?;
Ok(Box::new(RedisSinkClusterBuilder {
first_contact_points: self.first_contact_points.clone(),
direct_destination: self.direct_destination.clone(),
connection_count: self.connection_count.unwrap_or(1),
connection_pool,
chain_name,
shared_topology: Arc::new(RwLock::new(Topology::new())),
}))
}
}

match cluster.build_connections(None).await {
Ok(()) => {
info!("connected to upstream");
}
Err(TransformError::Upstream(RedisError::NotAuthenticated)) => {
info!("upstream requires auth");
}
Err(e) => {
return Err(anyhow!(e).context("failed to connect to upstream"));
}
}
pub struct RedisSinkClusterBuilder {
first_contact_points: Vec<String>,
direct_destination: Option<String>,
connection_count: usize,
connection_pool: ConnectionPool<RedisCodecBuilder, RedisAuthenticator, UsernamePasswordToken>,
chain_name: String,
shared_topology: Arc<RwLock<Topology>>,
}

impl TransformBuilder for RedisSinkClusterBuilder {
fn build(&self) -> Transforms {
Transforms::RedisSinkCluster(RedisSinkCluster::new(
self.first_contact_points.clone(),
self.direct_destination.clone(),
self.connection_count,
self.chain_name.clone(),
self.shared_topology.clone(),
self.connection_pool.clone(),
))
}

fn get_name(&self) -> &'static str {
"RedisSinkCluster"
}

fn is_terminating(&self) -> bool {
true
}
}

#[derive(Debug, Clone)]
struct Topology {
slots: SlotMap,
channels: ChannelMap,
}

Ok(Box::new(cluster))
impl Topology {
fn new() -> Self {
Topology {
slots: SlotMap::new(),
channels: ChannelMap::new(),
}
}
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct RedisSinkCluster {
pub slots: SlotMap,
pub channels: ChannelMap,
has_run_init: bool,
topology: Topology,
shared_topology: Arc<RwLock<Topology>>,
direct_connection: Option<UnboundedSender<Request>>,
load_scores: HashMap<(String, usize), usize>,
rng: SmallRng,
Expand All @@ -92,42 +131,37 @@ pub struct RedisSinkCluster {
}

impl RedisSinkCluster {
pub fn new(
fn new(
first_contact_points: Vec<String>,
direct_destination: Option<String>,
connection_count: usize,
tls: Option<TlsConnectorConfig>,
connect_timeout_ms: u64,
chain_name: String,
) -> Result<Self> {
let authenticator = RedisAuthenticator {};

let connect_timeout = Duration::from_millis(connect_timeout_ms);
let connection_pool = ConnectionPool::new_with_auth(
connect_timeout,
RedisCodecBuilder::new(Direction::Sink),
authenticator,
tls,
)?;

shared_topology: Arc<RwLock<Topology>>,
connection_pool: ConnectionPool<
RedisCodecBuilder,
RedisAuthenticator,
UsernamePasswordToken,
>,
) -> Self {
let sink_cluster = RedisSinkCluster {
has_run_init: false,
first_contact_points,
direct_destination,
slots: SlotMap::new(),
channels: ChannelMap::new(),
topology: Topology::new(),
shared_topology,
direct_connection: None,
load_scores: HashMap::new(),
rng: SmallRng::from_rng(rand::thread_rng()).unwrap(),
connection_count,
connection_pool,
reason_for_no_nodes: None,
rebuild_connections: false,
rebuild_connections: true,
token: None,
};

register_counter!("failed_requests", "chain" => chain_name, "transform" => sink_cluster.get_name());

Ok(sink_cluster)
sink_cluster
}

async fn direct_connection(&mut self) -> Result<&UnboundedSender<Request>> {
Expand Down Expand Up @@ -173,7 +207,7 @@ impl RedisSinkCluster {
slot: u16,
message: Message,
) -> Result<ResponseFuture> {
if let Some((_, lookup)) = self.slots.masters.range(&slot..).next() {
if let Some((_, lookup)) = self.topology.slots.masters.range(&slot..).next() {
let lookup = lookup.to_string();
let one_rx = self.choose_and_send(&lookup, message).await?;
Ok(Box::pin(
Expand Down Expand Up @@ -254,9 +288,14 @@ impl RedisSinkCluster {
}

fn latest_contact_points(&self) -> Vec<&str> {
if !self.slots.nodes.is_empty() {
if !self.topology.slots.nodes.is_empty() {
// Use latest node addresses as contact points.
self.slots.nodes.iter().map(|x| x.as_str()).collect()
self.topology
.slots
.nodes
.iter()
.map(|x| x.as_str())
.collect()
} else {
// Fallback to initial contact points.
self.first_contact_points
Expand Down Expand Up @@ -312,10 +351,12 @@ impl RedisSinkCluster {
match self.build_connections_inner(&token).await {
Ok((slots, channels)) => {
debug!("connected to cluster: {:?}", channels.keys());
self.topology = Topology { slots, channels };
if token.is_none() {
// when authentication isnt used we can share topology between connections
*self.shared_topology.write().await = self.topology.clone();
}
self.token = token;
self.slots = slots;
self.channels = channels;

self.reason_for_no_nodes = None;
self.rebuild_connections = false;
Ok(())
Expand Down Expand Up @@ -365,7 +406,7 @@ impl RedisSinkCluster {

#[inline]
async fn choose_and_send(&mut self, host: &str, message: Message) -> Result<ResponseFuture> {
let channel = match self.channels.get_mut(host) {
let channel = match self.topology.channels.get_mut(host) {
Some(channels) if channels.len() == 1 => channels.get_mut(0),
Some(channels) if channels.len() > 1 => {
let candidates = rand::seq::index::sample(&mut self.rng, channels.len(), 2);
Expand Down Expand Up @@ -409,8 +450,13 @@ impl RedisSinkCluster {
{
Ok(Ok(connections)) => {
debug!("Found {} live connections for {}", connections.len(), host);
self.channels.insert(host.to_string(), connections);
self.channels.get_mut(host).unwrap().get_mut(0).unwrap()
self.topology.channels.insert(host.to_string(), connections);
self.topology
.channels
.get_mut(host)
.unwrap()
.get_mut(0)
.unwrap()
}
Ok(Err(e)) => {
debug!("failed to connect to {}: {}", host, e);
Expand All @@ -434,7 +480,7 @@ impl RedisSinkCluster {
.is_err()
{
self.rebuild_connections = true;
self.channels.remove(host);
self.topology.channels.remove(host);
return self.short_circuit_with_error();
}

Expand All @@ -450,22 +496,23 @@ impl RedisSinkCluster {
RoutingInfo::Slot(slot) => self.send_message_to_slot(slot, message).await,
RoutingInfo::AllNodes(_) => {
self.send_message_to_channels(
&self.slots.nodes.iter().cloned().collect_vec(),
&self.topology.slots.nodes.iter().cloned().collect_vec(),
message,
routing_info,
)
.await
}
RoutingInfo::AllMasters(_) => {
self.send_message_to_channels(
&self.slots.masters.values().cloned().collect_vec(),
&self.topology.slots.masters.values().cloned().collect_vec(),
message,
routing_info,
)
.await
}
RoutingInfo::Random => {
let lookup = self
.topology
.slots
.masters
.values()
Expand Down Expand Up @@ -932,26 +979,36 @@ fn short_circuit(frame: RedisFrame) -> Result<ResponseFuture> {
}))
}

impl TransformBuilder for RedisSinkCluster {
fn build(&self) -> Transforms {
Transforms::RedisSinkCluster(self.clone())
}

fn get_name(&self) -> &'static str {
"RedisSinkCluster"
}

fn is_terminating(&self) -> bool {
true
}
}

#[async_trait]
impl Transform for RedisSinkCluster {
async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
if !self.has_run_init {
self.topology = (*self.shared_topology.read().await).clone();
if self.topology.channels.is_empty() {
// The code paths for authenticated and unauthenticated redis are quite different.
// * For unauthenticated redis this initial build_connections should succeed.
// + This is required to process the messages we are about to receive.
// + We also share the results to skip having to run build_connections again for new connection
// * For authenticated redis this initial build_connections always fails
// + The first message to come through should be an AUTH command which will give us the credentials required for us to run build_connections.
// As soon as we receive it we will rerun build_connections so we can process other message types afterwards.
// + It is important we do not share the results of the successful build_connections as that would leak authenticated shotover<->redis connections to other client<->shotover connections.
if let Err(err) = self.build_connections(self.token.clone()).await {
match err {
TransformError::Upstream(RedisError::NotAuthenticated) => {
// Build_connections sent an internal `CLUSTER SLOTS` command to redis and redis refused to respond because it is enforcing authentication.
// When the client sends an AUTH message we will rerun build_connections.
}
_ => tracing::warn!("Error when building connections: {err:?}"),
}
}
}
self.has_run_init = true;
}

if self.rebuild_connections {
if let Err(err) = self.build_connections(self.token.clone()).await {
tracing::warn!("Error when rebuilding connections {err:?}");
tracing::warn!("Error when rebuilding connections: {err:?}");
}
}

Expand Down Expand Up @@ -989,7 +1046,7 @@ impl Transform for RedisSinkCluster {
debug!("Got MOVE {} {}", slot, server);

// The destination of a MOVE should always be a master.
self.slots.masters.insert(slot, server.clone());
self.topology.slots.masters.insert(slot, server.clone());

self.rebuild_connections = true;

Expand Down

0 comments on commit ffe6821

Please sign in to comment.