support batch get/set in dht
mikkeldenker committed Mar 27, 2024
1 parent 37eee39 commit b823d1f
Showing 6 changed files with 300 additions and 27 deletions.
68 changes: 66 additions & 2 deletions crates/core/src/ampc/dht/client.rs
@@ -54,9 +54,17 @@ impl Node {
self.api.get(table, key).await
}

async fn batch_get(&self, table: Table, keys: Vec<Key>) -> Result<Vec<(Key, Value)>> {
self.api.batch_get(table, keys).await
}

async fn set(&self, table: Table, key: Key, value: Value) -> Result<()> {
self.api.set(table, key, value).await
}

async fn batch_set(&self, table: Table, values: Vec<(Key, Value)>) -> Result<()> {
self.api.batch_set(table, values).await
}
}

#[derive(Clone, serde::Serialize, serde::Deserialize)]
@@ -81,9 +89,17 @@ impl Shard {
self.node().get(table, key).await
}

async fn batch_get(&self, table: Table, keys: Vec<Key>) -> Result<Vec<(Key, Value)>> {
self.node().batch_get(table, keys).await
}

async fn set(&self, table: Table, key: Key, value: Value) -> Result<()> {
self.node().set(table, key, value).await
}

async fn batch_set(&self, table: Table, values: Vec<(Key, Value)>) -> Result<()> {
self.node().batch_set(table, values).await
}
}

#[derive(Clone, serde::Serialize, serde::Deserialize)]
@@ -130,28 +146,76 @@ impl Client {
self.ids = self.shards.keys().cloned().collect();
}

fn shard_for_key(&self, key: &[u8]) -> Result<&Shard> {
fn shard_id_for_key(&self, key: &[u8]) -> Result<&ShardId> {
if self.ids.is_empty() {
return Err(anyhow::anyhow!("No shards"));
}

let hash = md5::compute(key);
let hash = u64::from_le_bytes((&hash.0[..(u64::BITS / 8) as usize]).try_into().unwrap());

let shard_id = &self.ids[hash as usize % self.ids.len()];
Ok(&self.ids[hash as usize % self.ids.len()])
}

fn shard_for_key(&self, key: &[u8]) -> Result<&Shard> {
let shard_id = self.shard_id_for_key(key)?;
Ok(self.shards.get(shard_id).unwrap())
}

pub async fn get(&self, table: Table, key: Key) -> Result<Option<Value>> {
self.shard_for_key(key.as_bytes())?.get(table, key).await
}

pub async fn batch_get(&self, table: Table, keys: Vec<Key>) -> Result<Vec<(Key, Value)>> {
let mut shard_keys: BTreeMap<ShardId, Vec<Key>> = BTreeMap::new();

for key in keys {
let shard = self.shard_id_for_key(key.as_bytes())?;
shard_keys.entry(*shard).or_default().push(key);
}

let mut futures = Vec::with_capacity(shard_keys.len());

for (shard_id, keys) in shard_keys {
futures.push(self.shards[&shard_id].batch_get(table.clone(), keys));
}

let mut results: Vec<_> = futures::future::try_join_all(futures)
.await?
.into_iter()
.flatten()
.collect();
results.sort_by(|(a, _), (b, _)| a.cmp(b));
results.dedup_by(|(a, _), (b, _)| a == b);

Ok(results)
}

pub async fn set(&self, table: Table, key: Key, value: Value) -> Result<()> {
self.shard_for_key(key.as_bytes())?
.set(table, key, value)
.await
}

pub async fn batch_set(&self, table: Table, values: Vec<(Key, Value)>) -> Result<()> {
let mut shard_values: BTreeMap<ShardId, Vec<(Key, Value)>> = BTreeMap::new();

for (key, value) in values {
let shard = self.shard_id_for_key(key.as_bytes())?;
shard_values.entry(*shard).or_default().push((key, value));
}

let mut futures = Vec::with_capacity(shard_values.len());

for (shard_id, values) in shard_values {
futures.push(self.shards[&shard_id].batch_set(table.clone(), values));
}

futures::future::try_join_all(futures).await?;

Ok(())
}

pub async fn drop_table(&self, table: Table) -> Result<()> {
for shard in self.shards.values() {
for node in &shard.nodes {
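The Client-level batch methods shard a logical batch by key: each key is hashed with md5, the first eight bytes of the digest are read as a little-endian u64, and that value modulo the number of shard ids selects the shard. A minimal usage sketch of the new API, assuming a Client and Table already constructed elsewhere in the crate (the byte-slice .into() conversions follow the pattern in this commit's tests; the function itself is illustrative):

// Illustrative only: batch-write three entries, then read two back.
// `client` and `table` are assumed to be set up elsewhere.
async fn batch_roundtrip(client: &Client, table: Table) -> anyhow::Result<()> {
    let values: Vec<(Key, Value)> = vec![
        ("a".as_bytes().into(), "1".as_bytes().into()),
        ("b".as_bytes().into(), "2".as_bytes().into()),
        ("c".as_bytes().into(), "3".as_bytes().into()),
    ];

    // One logical batch fans out into at most one RPC per shard; the
    // per-shard futures run concurrently via try_join_all.
    client.batch_set(table.clone(), values).await?;

    // Results come back sorted by key and deduplicated (see batch_get above).
    let entries = client
        .batch_get(table, vec!["a".as_bytes().into(), "b".as_bytes().into()])
        .await?;
    assert_eq!(entries.len(), 2);

    Ok(())
}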
11 changes: 5 additions & 6 deletions crates/core/src/ampc/dht/mod.rs
@@ -31,7 +31,7 @@ pub mod log_store;
mod network;
pub mod store;

use network::api::{AllTables, CloneTable, CreateTable, DropTable, Get, Set};
use network::api::{AllTables, BatchSet, CloneTable, CreateTable, DropTable, Set};

use std::fmt::Debug;
use std::io::Cursor;
@@ -45,7 +45,7 @@ pub use network::api::RemoteClient as ApiClient;
pub use network::raft::RemoteClient as RaftClient;

pub use client::Client;
pub use store::Table;
pub use store::{Key, Table, Value};

pub type NodeId = u64;

@@ -91,7 +91,7 @@ macro_rules! raft_sonic_request_response {

raft_sonic_request_response!(
Server,
[Get, Set, CreateTable, DropTable, AllTables, CloneTable]
[Set, BatchSet, CreateTable, DropTable, AllTables, CloneTable]
);

#[cfg(test)]
@@ -360,9 +360,6 @@ mod tests {
let res = c1.get(table.clone(), "hello".as_bytes().into()).await?;
assert_eq!(res, Some("world".as_bytes().into()));

let res = c2.get(table.clone(), "hello".as_bytes().into()).await?;
assert_eq!(res, Some("world".as_bytes().into()));

// crash node 2
handles[1].abort();
drop(raft2);
@@ -375,6 +372,7 @@
});

rc1.join(2, addr2, members.clone()).await?;
tokio::time::sleep(std::time::Duration::from_secs(1)).await;

let c2 = RemoteClient::new(addr2);

@@ -400,6 +398,7 @@
});
raft2.initialize(members.clone()).await?;
rc1.join(2, addr2, members.clone()).await?;
tokio::time::sleep(std::time::Duration::from_secs(1)).await;

let c2 = RemoteClient::new(addr2);

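Two details here are easy to miss. In the raft_sonic_request_response! list, Get drops out while BatchSet joins: reads (Get, BatchGet) are answered from a node's local state machine, so only mutating operations need to be replicated through Raft as client writes, and the import line changes accordingly. The added one-second sleeps give a freshly joined node time to catch up before the test reads from it. A sketch of a batch round trip in the style of the surrounding tests (addr1 and table are assumed from the earlier test setup):

// Illustrative test-style round trip against a single node.
let client = RemoteClient::new(addr1);

client
    .batch_set(
        table.clone(),
        vec![
            ("k1".as_bytes().into(), "v1".as_bytes().into()),
            ("k2".as_bytes().into(), "v2".as_bytes().into()),
        ],
    )
    .await?;

let res = client
    .batch_get(
        table.clone(),
        vec!["k1".as_bytes().into(), "k2".as_bytes().into()],
    )
    .await?;
assert_eq!(res.len(), 2);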
141 changes: 141 additions & 0 deletions crates/core/src/ampc/dht/network/api.rs
@@ -45,6 +45,18 @@ pub struct Get {
pub key: Key,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BatchGet {
pub table: Table,
pub keys: Vec<Key>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BatchSet {
pub table: Table,
pub values: Vec<(Key, Value)>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DropTable {
pub table: Table,
@@ -77,6 +89,19 @@ impl sonic::service::Message&lt;Server&gt; for Set {
}
}

impl sonic::service::Message<Server> for BatchSet {
type Response = Result<(), RaftError<NodeId, ClientWriteError<NodeId, BasicNode>>>;

async fn handle(self, server: &Server) -> Self::Response {
tracing::debug!("received batch set request: {:?}", self);

match server.raft.client_write(self.into()).await {
Ok(_) => Ok(()),
Err(e) => Err(e),
}
}
}

impl sonic::service::Message<Server> for Get {
type Response = Option<Value>;

@@ -91,6 +116,20 @@ impl sonic::service::Message&lt;Server&gt; for Get {
}
}

impl sonic::service::Message<Server> for BatchGet {
type Response = Vec<(Key, Value)>;

async fn handle(self, server: &Server) -> Self::Response {
server
.state_machine_store
.state_machine
.read()
.await
.db
.batch_get(&self.table, &self.keys)
}
}

impl sonic::service::Message<Server> for DropTable {
type Response = Result<(), RaftError<NodeId, ClientWriteError<NodeId, BasicNode>>>;

@@ -232,6 +271,73 @@ impl RemoteClient {
Err(anyhow!("failed to set key"))
}

pub async fn batch_set(&self, table: Table, values: Vec<(Key, Value)>) -> Result<()> {
for backoff in Self::retry_strat() {
let res = self
.likely_leader
.read()
.await
.as_ref()
.unwrap_or(&self.self_remote)
.send_with_timeout(
&BatchSet {
table: table.clone(),
values: values.clone(),
},
Duration::from_secs(5),
)
.await;

tracing::debug!(".batch_set() got response: {res:?}");

match res {
Ok(res) => match res {
Ok(_) => return Ok(()),
Err(RaftError::APIError(e)) => match e {
ClientWriteError::ForwardToLeader(ForwardToLeader {
leader_id: _,
leader_node,
}) => match leader_node {
Some(leader_node) => {
let mut likely_leader = self.likely_leader.write().await;
*likely_leader = Some(sonic::replication::RemoteClient::new(
leader_node
.addr
.parse()
.expect("node addr should always be valid addr"),
));
}
None => {
tokio::time::sleep(backoff).await;
}
},
ClientWriteError::ChangeMembershipError(_) => {
unreachable!(".batch_set() should not change membership")
}
},
Err(RaftError::Fatal(e)) => return Err(e.into()),
},
Err(e) => match e {
sonic::Error::IO(_)
| sonic::Error::Serialization(_)
| sonic::Error::ConnectionTimeout
| sonic::Error::RequestTimeout
| sonic::Error::PoolCreation => {
tokio::time::sleep(backoff).await;
}
sonic::Error::BadRequest
| sonic::Error::BodyTooLarge {
body_size: _,
max_size: _,
}
| sonic::Error::Application(_) => return Err(e.into()),
},
}
}

Err(anyhow!("failed to batch set values"))
}

pub async fn get(&self, table: Table, key: Key) -> Result<Option<Value>> {
for backoff in Self::retry_strat() {
match self
@@ -267,6 +373,41 @@ impl RemoteClient {
Err(anyhow!("failed to get key"))
}

pub async fn batch_get(&self, table: Table, keys: Vec<Key>) -> Result<Vec<(Key, Value)>> {
for backoff in Self::retry_strat() {
match self
.self_remote
.send_with_timeout(
&BatchGet {
table: table.clone(),
keys: keys.clone(),
},
Duration::from_secs(5),
)
.await
{
Ok(res) => return Ok(res),
Err(e) => match e {
sonic::Error::IO(_)
| sonic::Error::Serialization(_)
| sonic::Error::ConnectionTimeout
| sonic::Error::RequestTimeout
| sonic::Error::PoolCreation => {
tokio::time::sleep(backoff).await;
}
sonic::Error::BadRequest
| sonic::Error::BodyTooLarge {
body_size: _,
max_size: _,
}
| sonic::Error::Application(_) => return Err(e.into()),
},
}
}

Err(anyhow!("failed to batch get keys"))
}

pub async fn drop_table(&self, table: Table) -> Result<()> {
for backoff in Self::retry_strat() {
let res = self
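The two RemoteClient additions mirror their single-key counterparts: batch_set is a write, so it targets the likely leader and, on a ForwardToLeader response, records the advertised leader before retrying, while batch_get is a read served directly by the node itself (self_remote) with no leader involved. Both retry only on transient transport failures. A hypothetical helper, not part of this commit, that distills the policy the match arms above encode inline:

// The retryable subset of sonic::Error, exactly as matched in
// batch_set()/batch_get() above; anything else is surfaced to the caller.
fn is_transient(e: &sonic::Error) -> bool {
    matches!(
        e,
        sonic::Error::IO(_)
            | sonic::Error::Serialization(_)
            | sonic::Error::ConnectionTimeout
            | sonic::Error::RequestTimeout
            | sonic::Error::PoolCreation
    )
}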
4 changes: 3 additions & 1 deletion crates/core/src/ampc/dht/network/mod.rs
@@ -17,7 +17,7 @@
pub mod api;
pub mod raft;

use api::{AllTables, CloneTable, CreateTable, DropTable, Get, Set};
use api::{AllTables, BatchGet, BatchSet, CloneTable, CreateTable, DropTable, Get, Set};
use std::{collections::BTreeMap, net::SocketAddr, sync::Arc};

use openraft::{BasicNode, Raft, RaftNetworkFactory};
@@ -68,7 +68,9 @@ sonic_service!(
AddLearner,
AddNodes,
Get,
BatchGet,
Set,
BatchSet,
DropTable,
CreateTable,
AllTables,
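The sonic_service! registration is the final step of the three-part pattern this commit follows for BatchGet and BatchSet: declare the request struct in network/api.rs, implement sonic::service::Message&lt;Server&gt; for it, and list it here. A schematic sketch of adding another read-only request the same way (NumKeys and the num_keys store method are hypothetical, not part of the codebase):

// Hypothetical request: count the keys in a table.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct NumKeys {
    pub table: Table,
}

impl sonic::service::Message<Server> for NumKeys {
    // Read-only requests return their payload directly, like Get/BatchGet.
    type Response = usize;

    async fn handle(self, server: &Server) -> Self::Response {
        server
            .state_machine_store
            .state_machine
            .read()
            .await
            .db
            .num_keys(&self.table) // hypothetical store method
    }
}

// ...then add NumKeys to the list in sonic_service!(Server, [ ... ]);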