Skip to content

Commit

Permalink
join existing raft cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
mikkeldenker committed Mar 15, 2024
1 parent 14e749d commit 555010f
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 13 deletions.
11 changes: 10 additions & 1 deletion crates/core/src/distributed/sonic/replication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,21 @@ use super::Result;
use crate::distributed::{retry_strategy::ExponentialBackoff, sonic};
use std::{net::SocketAddr, time::Duration};

/// Handle to a remote endpoint of the sonic service `S`.
///
/// Holds only the endpoint's socket address plus a zero-sized `PhantomData` marker that
/// ties the value to the service type `S`.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct RemoteClient<S: sonic::service::Service> {
/// Socket address of the remote endpoint.
addr: SocketAddr,
/// Marker for the service type; carries no data.
_phantom: std::marker::PhantomData<S>,
}

/// Manual `Clone`: produces a new client for the same address via `Self::create`.
///
/// Presumably hand-written because `#[derive(Clone)]` on a generic struct also requires
/// `S: Clone`, whereas this impl only needs `S: Service`.
impl<S: sonic::service::Service> Clone for RemoteClient<S> {
    fn clone(&self) -> Self {
        let addr = self.addr;
        Self::create(addr)
    }
}

impl<S> RemoteClient<S>
where
S: sonic::service::Service,
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/mapreduce/dht/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ mod tests {

#[tokio::test]
#[traced_test]
#[ignore = "[WIP] need to figure out how to add a new node to the cluster"]
async fn test_member_join() -> anyhow::Result<()> {
let (raft1, server1, addr1) = server(1).await?;
let (raft2, server2, addr2) = server(2).await?;
Expand Down Expand Up @@ -220,7 +219,8 @@ mod tests {
.collect();

raft3.initialize(members.clone()).await?;
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
let rc1 = network::raft::RemoteClient::new(1, BasicNode::new(addr1));
rc1.join(3, addr3, members.clone()).await?;

let c3 = RemoteClient::new(addr3);
let res = c3.get("hello".to_string()).await?;
Expand Down
19 changes: 16 additions & 3 deletions crates/core/src/mapreduce/dht/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>

pub mod api;
mod raft;
pub mod raft;

use api::{Get, Set};
use std::sync::Arc;
use std::{collections::BTreeMap, net::SocketAddr, sync::Arc};

use openraft::{BasicNode, Raft, RaftNetworkFactory};

Expand Down Expand Up @@ -48,14 +48,27 @@ pub type InstallSnapshotResponse = openraft::raft::InstallSnapshotResponse<NodeI
/// `openraft`'s vote request, specialised to `NodeId`.
pub type VoteRequest = openraft::raft::VoteRequest<NodeId>;
/// `openraft`'s vote response, specialised to `NodeId`.
pub type VoteResponse = openraft::raft::VoteResponse<NodeId>;

/// Request asking the receiving node to register node `id`, reachable at `addr`, as a
/// raft learner.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AddLearnerRequest {
    /// Id of the node to add.
    pub id: NodeId,
    /// Socket address of the node to add.
    pub addr: SocketAddr,
}

/// Request asking the receiving node to apply a membership change that adds `members`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AddNodesRequest {
    /// Nodes to add, keyed by node id.
    members: BTreeMap<NodeId, BasicNode>,
}

// Invokes `sonic_service!` for `Server` with the request types it accepts.
// NOTE(review): presumably every type listed here needs a matching
// `sonic::service::Message<Server>` impl — confirm against the macro definition.
sonic_service!(
Server,
[
AppendEntriesRequest,
InstallSnapshotRequest,
VoteRequest,
AddLearnerRequest,
AddNodesRequest,
Get,
Set
Set,
]
);

Expand Down
168 changes: 161 additions & 7 deletions crates/core/src/mapreduce/dht/network/raft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>

use std::net::SocketAddr;
use std::{collections::BTreeMap, net::SocketAddr, time::Duration};

use openraft::{
error::{InstallSnapshotError, RaftError},
error::{ClientWriteError, ForwardToLeader, InstallSnapshotError, RaftError},
network::RPCOption,
BasicNode, RaftNetwork,
BasicNode, ChangeMembers, RaftNetwork,
};
use tokio::sync::RwLock;

use crate::{
distributed::sonic::{self, service::ResilientConnection},
distributed::{
retry_strategy::ExponentialBackoff,
sonic::{self, service::ResilientConnection},
},
mapreduce::dht::{NodeId, TypeConfig},
Result,
};

use super::{
AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
Server, VoteRequest, VoteResponse,
AddLearnerRequest, AddNodesRequest, AppendEntriesRequest, AppendEntriesResponse,
InstallSnapshotRequest, InstallSnapshotResponse, Server, VoteRequest, VoteResponse,
};

impl sonic::service::Message<Server> for AppendEntriesRequest {
Expand Down Expand Up @@ -59,24 +64,53 @@ impl sonic::service::Message<Server> for VoteRequest {
}
}

/// Server-side handling of [`AddLearnerRequest`].
impl sonic::service::Message<Server> for AddLearnerRequest {
    type Response = Result<(), RaftError<NodeId, ClientWriteError<NodeId, BasicNode>>>;

    /// Adds the requested node to the local raft instance as a learner.
    ///
    /// `add_learner` is called with its third argument set to `false`; any raft error is
    /// returned to the caller unchanged.
    async fn handle(self, server: &Server) -> Self::Response {
        tracing::debug!("received add learner request: {:?}", self);

        let Self { id, addr } = self;
        server
            .raft
            .add_learner(id, BasicNode::new(addr), false)
            .await?;

        Ok(())
    }
}

/// Server-side handling of [`AddNodesRequest`].
impl sonic::service::Message<Server> for AddNodesRequest {
    type Response = Result<(), RaftError<NodeId, ClientWriteError<NodeId, BasicNode>>>;

    /// Submits a `ChangeMembers::AddNodes` membership change to the local raft instance.
    ///
    /// `change_membership` is called with its second argument set to `true`; any raft
    /// error is returned to the caller unchanged.
    async fn handle(self, server: &Server) -> Self::Response {
        tracing::debug!("received add nodes request: {:?}", self);

        // NOTE(review): openraft documents `AddNodes` as adding nodes as learners, not
        // voters — confirm this is the intended membership change.
        let change = ChangeMembers::AddNodes(self.members);
        server.raft.change_membership(change, true).await?;

        Ok(())
    }
}

/// Shorthand for `openraft`'s `RPCError` over `NodeId` / `BasicNode`, wrapping a
/// `RaftError<NodeId, E>`. `E` defaults to `openraft::error::Infallible`.
type RPCError<E = openraft::error::Infallible> =
openraft::error::RPCError<NodeId, BasicNode, RaftError<NodeId, E>>;

/// Client for one raft node, built on sonic remote clients for the `Server` service.
pub struct RemoteClient {
/// Id of the node this client was created for.
target: NodeId,
/// Node description this client was created from.
node: BasicNode,
/// Sonic client for the address parsed from `node.addr` in `new`.
inner: sonic::replication::RemoteClient<Server>,
/// Client that `add_learner` / `add_nodes` send to. Starts as a clone of `inner` and is
/// replaced when a `ForwardToLeader` reply names a leader node.
likely_leader: RwLock<sonic::replication::RemoteClient<Server>>,
}

impl RemoteClient {
/// Creates a client for raft node `target`, described by `node`.
///
/// `likely_leader` initially points at the same address as the node itself.
///
/// # Panics
///
/// Panics if `node.addr` cannot be parsed as a `SocketAddr`.
pub fn new(target: NodeId, node: BasicNode) -> Self {
    let inner = sonic::replication::RemoteClient::new(
        node.addr
            .parse::<SocketAddr>()
            .expect("addr is not a valid address"),
    );

    Self {
        target,
        likely_leader: RwLock::new(inner.clone()),
        inner,
        node,
    }
}
async fn raft_conn<E: std::error::Error>(
Expand All @@ -89,7 +123,7 @@ impl RemoteClient {
}

async fn send_raft_rpc<R, E>(
&mut self,
&self,
rpc: R,
option: RPCOption,
) -> Result<R::Response, RPCError<E>>
Expand All @@ -109,6 +143,126 @@ impl RemoteClient {
}
})
}

async fn add_learner(&self, id: NodeId, addr: SocketAddr) -> Result<()> {
let rpc = AddLearnerRequest { id, addr };
let retry = ExponentialBackoff::from_millis(500)
.with_limit(Duration::from_secs(60))
.take(5);

for backoff in retry {
let res = self.likely_leader.read().await.send(&rpc).await;

match res {
Ok(res) => match res {
Ok(_) => return Ok(()),
Err(RaftError::APIError(e)) => match e {
ClientWriteError::ForwardToLeader(ForwardToLeader {
leader_id: _,
leader_node,
}) => match leader_node {
Some(leader_node) => {
let mut likely_leader = self.likely_leader.write().await;
*likely_leader = sonic::replication::RemoteClient::new(
leader_node
.addr
.parse()
.expect("node addr should always be valid addr"),
);
}
None => tokio::time::sleep(backoff).await,
},
ClientWriteError::ChangeMembershipError(_) => {
tokio::time::sleep(backoff).await
}
},
Err(RaftError::Fatal(e)) => return Err(e.into()),
},
Err(e) => match e {
sonic::Error::IO(_)
| sonic::Error::Serialization(_)
| sonic::Error::ConnectionTimeout
| sonic::Error::RequestTimeout
| sonic::Error::PoolCreation => {
tokio::time::sleep(backoff).await;
}
sonic::Error::BadRequest
| sonic::Error::BodyTooLarge {
body_size: _,
max_size: _,
}
| sonic::Error::Application(_) => return Err(e.into()),
},
}
}

Err(anyhow::anyhow!("failed to add learner"))
}

async fn add_nodes(&self, members: BTreeMap<NodeId, BasicNode>) -> Result<()> {
let rpc = AddNodesRequest { members };
let retry = ExponentialBackoff::from_millis(500).with_limit(Duration::from_secs(10));

for backoff in retry {
let res = self.likely_leader.read().await.send(&rpc).await;

match res {
Ok(res) => match res {
Ok(_) => return Ok(()),
Err(RaftError::APIError(e)) => match e {
ClientWriteError::ForwardToLeader(ForwardToLeader {
leader_id: _,
leader_node,
}) => match leader_node {
Some(leader_node) => {
let mut likely_leader = self.likely_leader.write().await;
*likely_leader = sonic::replication::RemoteClient::new(
leader_node
.addr
.parse()
.expect("node addr should always be valid addr"),
);
}
None => tokio::time::sleep(backoff).await,
},
ClientWriteError::ChangeMembershipError(_) => {
tokio::time::sleep(backoff).await
}
},
Err(RaftError::Fatal(e)) => return Err(e.into()),
},
Err(e) => match e {
sonic::Error::IO(_)
| sonic::Error::Serialization(_)
| sonic::Error::ConnectionTimeout
| sonic::Error::RequestTimeout
| sonic::Error::PoolCreation => {
tokio::time::sleep(backoff).await;
}
sonic::Error::BadRequest
| sonic::Error::BodyTooLarge {
body_size: _,
max_size: _,
}
| sonic::Error::Application(_) => return Err(e.into()),
},
}
}

unreachable!("should continue to retry");
}

/// Joins node `id` (reachable at `addr`) to the cluster.
///
/// Runs two steps in order and stops at the first failure: `add_learner(id, addr)`, then
/// `add_nodes(new_all_nodes)`.
///
/// # Errors
///
/// Returns whichever error the failing step produced.
pub async fn join(
    &self,
    id: NodeId,
    addr: SocketAddr,
    new_all_nodes: BTreeMap<NodeId, BasicNode>,
) -> Result<()> {
    self.add_learner(id, addr).await?;
    self.add_nodes(new_all_nodes).await
}
}

impl RaftNetwork<TypeConfig> for RemoteClient {
Expand Down

0 comments on commit 555010f

Please sign in to comment.