Skip to content

Commit

Permalink
Implement ExchangeServerHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
fl0rek committed Sep 18, 2023
1 parent e435c1f commit 44f0de0
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 26 deletions.
7 changes: 4 additions & 3 deletions node/src/exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mod utils;
pub use utils::ExtendedHeaderExt;

use crate::exchange::client::ExchangeClientHandler;
use crate::exchange::server::ExchangeServerHandler;
use crate::exchange::server::{ExchangeServerHandler, RequestResponseResponder};
use crate::p2p::P2pError;
use crate::peer_tracker::PeerTracker;
use crate::store::Store;
Expand Down Expand Up @@ -82,7 +82,7 @@ impl ExchangeBehaviour {
request_response::Config::default(),
),
client_handler: ExchangeClientHandler::new(config.peer_tracker),
server_handler: ExchangeServerHandler::new(),
server_handler: ExchangeServerHandler::new(config.header_store),
}
}

Expand Down Expand Up @@ -144,8 +144,9 @@ impl ExchangeBehaviour {
},
peer,
} => {
let responder = RequestResponseResponder::new(&mut self.req_resp, channel);
self.server_handler
.on_request_received(peer, request_id, request, channel);
.on_request_received(peer, request_id, request, responder);
}

// Response to inbound request was sent
Expand Down
2 changes: 1 addition & 1 deletion node/src/exchange/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::exchange::utils::ExtendedHeaderExt;
use crate::exchange::utils::ToHeaderResponse;
use celestia_proto::p2p::pb::header_request::Data;
use celestia_proto::p2p::pb::StatusCode;
use celestia_types::consts::HASH_SIZE;
Expand Down
303 changes: 287 additions & 16 deletions node/src/exchange/server.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,65 @@
use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
use libp2p::{
request_response::{InboundFailure, RequestId, ResponseChannel},
PeerId,
};
use crate::exchange::utils::{HeaderRequestExt, ToHeaderResponse};
use crate::exchange::HeaderCodec;
use crate::store::Store;
use celestia_proto::p2p::pb::{header_request, HeaderRequest, HeaderResponse, StatusCode};
use celestia_types::Hash;
use libp2p::request_response::{Behaviour, InboundFailure, RequestId, ResponseChannel};
use libp2p::PeerId;
use std::fmt::{Debug, Display};
use std::sync::Arc;
use tendermint::hash::Algorithm;
use tracing::instrument;
use tracing::{error, info, trace, warn};

pub(super) struct ExchangeServerHandler {
// TODO
const INVALID_REQUEST_MSG: &str = "invalid request";

pub(crate) struct ExchangeServerHandler {
store: Arc<Store>,
}

// TODO: how much do we want to augument response with information about failure
fn gen_invalid_request_response() -> Vec<HeaderResponse> {
vec![HeaderResponse {
status_code: StatusCode::Invalid.into(),
body: INVALID_REQUEST_MSG.as_bytes().to_vec(),
}]
}

impl ExchangeServerHandler {
pub(super) fn new() -> Self {
ExchangeServerHandler {}
pub(super) fn new(store: Arc<Store>) -> Self {
ExchangeServerHandler { store }
}

#[instrument(level = "trace", skip(self, _respond_to))]
pub(super) fn on_request_received(
#[instrument(level = "trace", skip(self, responder))]
pub(super) fn on_request_received<ID: Display + Debug, R: Responder>(
&mut self,
peer: PeerId,
request_id: RequestId,
request_id: ID,
request: HeaderRequest,
_respond_to: ResponseChannel<Vec<HeaderResponse>>,
responder: R,
) {
// TODO
trace!("request_received; request_id: {request_id}, request: {request:?}");
let Some((amount, data)) = self.parse_request(request) else {
responder.send_response(gen_invalid_request_response());
return;
};

let response = match data {
header_request::Data::Origin(0) => self.handle_request_current_head(),
header_request::Data::Origin(height) => self.handle_request_by_height(height, amount),
header_request::Data::Hash(hash) => self.handle_request_by_hash(hash),
};

responder.send_response(response);
}

fn handle_request_current_head(&self) -> Vec<HeaderResponse> {
vec![self.store.get_head().to_header_response()]
}

#[instrument(level = "trace", skip(self))]
pub(super) fn on_response_sent(&mut self, peer: PeerId, request_id: RequestId) {
// TODO
info!("response_sent; request_id: {request_id}, peer: {peer}");
}

#[instrument(level = "trace", skip(self))]
Expand All @@ -37,6 +69,245 @@ impl ExchangeServerHandler {
request_id: RequestId,
error: InboundFailure,
) {
// TODO
info!("on_failure; request_id: {request_id}, peer: {peer}, error: {error:?}");
}

fn parse_request(&self, request: HeaderRequest) -> Option<(u64, header_request::Data)> {
if !request.is_valid() {
return None;
}

let HeaderRequest {
amount,
data: Some(data),
} = request
else {
return None;
};

Some((amount, data))
}

fn handle_request_by_hash(&self, hash: Vec<u8>) -> Vec<HeaderResponse> {
let hash = match Hash::from_bytes(Algorithm::Sha256, &hash) {
Ok(h) => h,
Err(e) => {
error!("error decoding hash: {e}");
return gen_invalid_request_response();
}
};

let header = self.store.get_by_hash(&hash);

vec![header.to_header_response()]
}

fn handle_request_by_height(&self, origin: u64, amount: u64) -> Vec<HeaderResponse> {
info!("get by height {origin} +{amount}");
let mut r = vec![];
for i in origin..origin + amount {
r.push(self.store.get_by_height(i).to_header_response());
}

r
}
}

pub(super) trait Responder {
fn send_response(self, response: Vec<HeaderResponse>);
}

pub(super) struct RequestResponseResponder<'a> {
behaviour: &'a mut Behaviour<HeaderCodec>,
response_channel: ResponseChannel<Vec<HeaderResponse>>,
}

impl<'a> RequestResponseResponder<'a> {
pub fn new(
behaviour: &'a mut Behaviour<HeaderCodec>,
response_channel: ResponseChannel<Vec<HeaderResponse>>,
) -> Self {
Self {
behaviour,
response_channel,
}
}
}

impl Responder for RequestResponseResponder<'_> {
fn send_response(self, response: Vec<HeaderResponse>) {
let Self {
behaviour,
response_channel,
} = self;
// response was prepared specifically for this request, we can drop it
// we'll get notified about failure via `Event::InboundFailure`
behaviour.send_response(response_channel, response).ok();
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::exchange::utils::HeaderRequestExt;
use crate::store::Store;
use celestia_proto::p2p::pb::header_request::Data;
use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse, StatusCode};
use celestia_types::ExtendedHeader;
use libp2p::PeerId;
use std::sync::Arc;
use tendermint_proto::Protobuf;
use tokio::sync::oneshot;

struct TestResponder(oneshot::Sender<Vec<HeaderResponse>>);

impl TestResponder {
fn new() -> (Self, oneshot::Receiver<Vec<HeaderResponse>>) {
let (tx, rx) = oneshot::channel();
(Self(tx), rx)
}
}

impl Responder for TestResponder {
fn send_response(self, response: Vec<HeaderResponse>) {
self.0.send(response).ok();
}
}

#[tokio::test]
async fn request_header_test() {
let store = Arc::new(Store::test_filled_store(3));
let expected_origin = store.get_by_height(1).unwrap();
let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest::with_origin(1, 1);
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), 1);
assert_eq!(received[0].status_code, i32::from(StatusCode::Ok));
let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap();
assert_eq!(decoded_header, expected_origin);
}

#[tokio::test]
async fn request_head_test() {
let store = Arc::new(Store::test_filled_store(4));
let expected_head = store.get_head().unwrap();
let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest::head_request();
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), 1);
assert_eq!(received[0].status_code, i32::from(StatusCode::Ok));
let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap();
assert_eq!(decoded_header, expected_head);
}

#[tokio::test]
async fn invalid_amount_request_test() {
let store = Arc::new(Store::test_filled_store(1));
let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest::with_origin(0, 0);
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), 1);
assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid));
}

#[tokio::test]
async fn none_data_request_test() {
let store = Arc::new(Store::test_filled_store(1));
let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest {
data: None,
amount: 1,
};
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), 1);
assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid));
}

#[tokio::test]
async fn request_hash_test() {
let store = Arc::new(Store::test_filled_store(1));
let stored_header = store.get_head().unwrap();
let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest::with_hash(stored_header.hash());
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), 1);
assert_eq!(received[0].status_code, i32::from(StatusCode::Ok));
let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap();
assert_eq!(decoded_header, stored_header);
}

#[tokio::test]
async fn request_range_test() {
let store = Arc::new(Store::test_filled_store(10));
let expected_headers = [
store.get_by_height(5).unwrap(),
store.get_by_height(6).unwrap(),
store.get_by_height(7).unwrap(),
];
let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest {
data: Some(Data::Origin(5)),
amount: u64::try_from(expected_headers.len()).unwrap(),
};
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), expected_headers.len());

for (rec, exp) in received.iter().zip(expected_headers.iter()) {
assert_eq!(rec.status_code, i32::from(StatusCode::Ok));
let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap();
assert_eq!(&decoded_header, exp);
}
}

#[tokio::test]
async fn request_range_behond_head_test() {
let store = Arc::new(Store::test_filled_store(5));
let expected_hashes = [store.get_by_height(5).ok(), None, None];
let expected_status_codes = [StatusCode::Ok, StatusCode::NotFound, StatusCode::NotFound];
assert_eq!(expected_hashes.len(), expected_status_codes.len());

let mut handler = ExchangeServerHandler::new(store);

let request = HeaderRequest::with_origin(5, u64::try_from(expected_hashes.len()).unwrap());
let (responder, response) = TestResponder::new();
handler.on_request_received(PeerId::random(), "test", request, responder);

let received = response.await.unwrap();
assert_eq!(received.len(), expected_hashes.len());

for (rec, (exp_status, exp_header)) in received
.iter()
.zip(expected_status_codes.iter().zip(expected_hashes.iter()))
{
assert_eq!(rec.status_code, i32::from(*exp_status));
if let Some(exp_header) = exp_header {
let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap();
assert_eq!(&decoded_header, exp_header);
}
}
}
}
Loading

0 comments on commit 44f0de0

Please sign in to comment.