From 44f0de0f5a8e2af672ea25f590ce73cf982e653f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Florkiewicz?= Date: Mon, 18 Sep 2023 15:52:32 +0200 Subject: [PATCH] Implement ExchangeServerHandler --- node/src/exchange.rs | 7 +- node/src/exchange/client.rs | 2 +- node/src/exchange/server.rs | 303 ++++++++++++++++++++++++++++++++++-- node/src/exchange/utils.rs | 32 +++- 4 files changed, 318 insertions(+), 26 deletions(-) diff --git a/node/src/exchange.rs b/node/src/exchange.rs index 8738d7d7a..612420d5e 100644 --- a/node/src/exchange.rs +++ b/node/src/exchange.rs @@ -24,7 +24,7 @@ mod utils; pub use utils::ExtendedHeaderExt; use crate::exchange::client::ExchangeClientHandler; -use crate::exchange::server::ExchangeServerHandler; +use crate::exchange::server::{ExchangeServerHandler, RequestResponseResponder}; use crate::p2p::P2pError; use crate::peer_tracker::PeerTracker; use crate::store::Store; @@ -82,7 +82,7 @@ impl ExchangeBehaviour { request_response::Config::default(), ), client_handler: ExchangeClientHandler::new(config.peer_tracker), - server_handler: ExchangeServerHandler::new(), + server_handler: ExchangeServerHandler::new(config.header_store), } } @@ -144,8 +144,9 @@ impl ExchangeBehaviour { }, peer, } => { + let responder = RequestResponseResponder::new(&mut self.req_resp, channel); self.server_handler - .on_request_received(peer, request_id, request, channel); + .on_request_received(peer, request_id, request, responder); } // Response to inbound request was sent diff --git a/node/src/exchange/client.rs b/node/src/exchange/client.rs index f0e1270a0..51531f6ca 100644 --- a/node/src/exchange/client.rs +++ b/node/src/exchange/client.rs @@ -286,7 +286,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::exchange::utils::ExtendedHeaderExt; + use crate::exchange::utils::ToHeaderResponse; use celestia_proto::p2p::pb::header_request::Data; use celestia_proto::p2p::pb::StatusCode; use celestia_types::consts::HASH_SIZE; diff --git a/node/src/exchange/server.rs b/node/src/exchange/server.rs index 642274870..aebdfa0b8 100644 --- a/node/src/exchange/server.rs +++ b/node/src/exchange/server.rs @@ -1,33 +1,65 @@ -use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse}; -use libp2p::{ - request_response::{InboundFailure, RequestId, ResponseChannel}, - PeerId, -}; +use crate::exchange::utils::{HeaderRequestExt, ToHeaderResponse}; +use crate::exchange::HeaderCodec; +use crate::store::Store; +use celestia_proto::p2p::pb::{header_request, HeaderRequest, HeaderResponse, StatusCode}; +use celestia_types::Hash; +use libp2p::request_response::{Behaviour, InboundFailure, RequestId, ResponseChannel}; +use libp2p::PeerId; +use std::fmt::{Debug, Display}; +use std::sync::Arc; +use tendermint::hash::Algorithm; use tracing::instrument; +use tracing::{error, info, trace, warn}; -pub(super) struct ExchangeServerHandler { - // TODO +const INVALID_REQUEST_MSG: &str = "invalid request"; + +pub(crate) struct ExchangeServerHandler { + store: Arc, +} + +// TODO: how much do we want to augument response with information about failure +fn gen_invalid_request_response() -> Vec { + vec![HeaderResponse { + status_code: StatusCode::Invalid.into(), + body: INVALID_REQUEST_MSG.as_bytes().to_vec(), + }] } impl ExchangeServerHandler { - pub(super) fn new() -> Self { - ExchangeServerHandler {} + pub(super) fn new(store: Arc) -> Self { + ExchangeServerHandler { store } } - #[instrument(level = "trace", skip(self, _respond_to))] - pub(super) fn on_request_received( + #[instrument(level = "trace", skip(self, responder))] + pub(super) fn on_request_received( &mut self, peer: PeerId, - request_id: RequestId, + request_id: ID, request: HeaderRequest, - _respond_to: ResponseChannel>, + responder: R, ) { - // TODO + trace!("request_received; request_id: {request_id}, request: {request:?}"); + let Some((amount, data)) = self.parse_request(request) else { + responder.send_response(gen_invalid_request_response()); + return; + }; + + let response = match data { + header_request::Data::Origin(0) => self.handle_request_current_head(), + header_request::Data::Origin(height) => self.handle_request_by_height(height, amount), + header_request::Data::Hash(hash) => self.handle_request_by_hash(hash), + }; + + responder.send_response(response); + } + + fn handle_request_current_head(&self) -> Vec { + vec![self.store.get_head().to_header_response()] } #[instrument(level = "trace", skip(self))] pub(super) fn on_response_sent(&mut self, peer: PeerId, request_id: RequestId) { - // TODO + info!("response_sent; request_id: {request_id}, peer: {peer}"); } #[instrument(level = "trace", skip(self))] @@ -37,6 +69,245 @@ impl ExchangeServerHandler { request_id: RequestId, error: InboundFailure, ) { - // TODO + info!("on_failure; request_id: {request_id}, peer: {peer}, error: {error:?}"); + } + + fn parse_request(&self, request: HeaderRequest) -> Option<(u64, header_request::Data)> { + if !request.is_valid() { + return None; + } + + let HeaderRequest { + amount, + data: Some(data), + } = request + else { + return None; + }; + + Some((amount, data)) + } + + fn handle_request_by_hash(&self, hash: Vec) -> Vec { + let hash = match Hash::from_bytes(Algorithm::Sha256, &hash) { + Ok(h) => h, + Err(e) => { + error!("error decoding hash: {e}"); + return gen_invalid_request_response(); + } + }; + + let header = self.store.get_by_hash(&hash); + + vec![header.to_header_response()] + } + + fn handle_request_by_height(&self, origin: u64, amount: u64) -> Vec { + info!("get by height {origin} +{amount}"); + let mut r = vec![]; + for i in origin..origin + amount { + r.push(self.store.get_by_height(i).to_header_response()); + } + + r + } +} + +pub(super) trait Responder { + fn send_response(self, response: Vec); +} + +pub(super) struct RequestResponseResponder<'a> { + behaviour: &'a mut Behaviour, + response_channel: ResponseChannel>, +} + +impl<'a> RequestResponseResponder<'a> { + pub fn new( + behaviour: &'a mut Behaviour, + response_channel: ResponseChannel>, + ) -> Self { + Self { + behaviour, + response_channel, + } + } +} + +impl Responder for RequestResponseResponder<'_> { + fn send_response(self, response: Vec) { + let Self { + behaviour, + response_channel, + } = self; + // response was prepared specifically for this request, we can drop it + // we'll get notified about failure via `Event::InboundFailure` + behaviour.send_response(response_channel, response).ok(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::exchange::utils::HeaderRequestExt; + use crate::store::Store; + use celestia_proto::p2p::pb::header_request::Data; + use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse, StatusCode}; + use celestia_types::ExtendedHeader; + use libp2p::PeerId; + use std::sync::Arc; + use tendermint_proto::Protobuf; + use tokio::sync::oneshot; + + struct TestResponder(oneshot::Sender>); + + impl TestResponder { + fn new() -> (Self, oneshot::Receiver>) { + let (tx, rx) = oneshot::channel(); + (Self(tx), rx) + } + } + + impl Responder for TestResponder { + fn send_response(self, response: Vec) { + self.0.send(response).ok(); + } + } + + #[tokio::test] + async fn request_header_test() { + let store = Arc::new(Store::test_filled_store(3)); + let expected_origin = store.get_by_height(1).unwrap(); + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest::with_origin(1, 1); + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, expected_origin); + } + + #[tokio::test] + async fn request_head_test() { + let store = Arc::new(Store::test_filled_store(4)); + let expected_head = store.get_head().unwrap(); + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest::head_request(); + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, expected_head); + } + + #[tokio::test] + async fn invalid_amount_request_test() { + let store = Arc::new(Store::test_filled_store(1)); + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest::with_origin(0, 0); + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid)); + } + + #[tokio::test] + async fn none_data_request_test() { + let store = Arc::new(Store::test_filled_store(1)); + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest { + data: None, + amount: 1, + }; + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid)); + } + + #[tokio::test] + async fn request_hash_test() { + let store = Arc::new(Store::test_filled_store(1)); + let stored_header = store.get_head().unwrap(); + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest::with_hash(stored_header.hash()); + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, stored_header); + } + + #[tokio::test] + async fn request_range_test() { + let store = Arc::new(Store::test_filled_store(10)); + let expected_headers = [ + store.get_by_height(5).unwrap(), + store.get_by_height(6).unwrap(), + store.get_by_height(7).unwrap(), + ]; + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest { + data: Some(Data::Origin(5)), + amount: u64::try_from(expected_headers.len()).unwrap(), + }; + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), expected_headers.len()); + + for (rec, exp) in received.iter().zip(expected_headers.iter()) { + assert_eq!(rec.status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap(); + assert_eq!(&decoded_header, exp); + } + } + + #[tokio::test] + async fn request_range_behond_head_test() { + let store = Arc::new(Store::test_filled_store(5)); + let expected_hashes = [store.get_by_height(5).ok(), None, None]; + let expected_status_codes = [StatusCode::Ok, StatusCode::NotFound, StatusCode::NotFound]; + assert_eq!(expected_hashes.len(), expected_status_codes.len()); + + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest::with_origin(5, u64::try_from(expected_hashes.len()).unwrap()); + let (responder, response) = TestResponder::new(); + handler.on_request_received(PeerId::random(), "test", request, responder); + + let received = response.await.unwrap(); + assert_eq!(received.len(), expected_hashes.len()); + + for (rec, (exp_status, exp_header)) in received + .iter() + .zip(expected_status_codes.iter().zip(expected_hashes.iter())) + { + assert_eq!(rec.status_code, i32::from(*exp_status)); + if let Some(exp_header) = exp_header { + let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap(); + assert_eq!(&decoded_header, exp_header); + } + } } } diff --git a/node/src/exchange/utils.rs b/node/src/exchange/utils.rs index 3661d1379..752eb9d3d 100644 --- a/node/src/exchange/utils.rs +++ b/node/src/exchange/utils.rs @@ -1,14 +1,14 @@ +use celestia_proto::header::pb::ExtendedHeader as RawExtendedHeader; use celestia_proto::p2p::pb::header_request::Data; use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse, StatusCode}; use celestia_types::consts::HASH_SIZE; +use celestia_types::{DataAvailabilityHeader, ValidatorSet}; use celestia_types::{ExtendedHeader, Hash}; -use tendermint_proto::Protobuf; -use tendermint::Time; -use tendermint::{block::header::Version, AppHash}; use tendermint::block::header::Header; use tendermint::block::Commit; -use celestia_types::{DataAvailabilityHeader, ValidatorSet,}; -use celestia_proto::header::pb::ExtendedHeader as RawExtendedHeader; +use tendermint::Time; +use tendermint::{block::header::Version, AppHash}; +use tendermint_proto::Protobuf; use crate::exchange::ExchangeError; use crate::store::ReadError; @@ -16,6 +16,7 @@ use crate::store::ReadError; pub(super) trait HeaderRequestExt { fn with_origin(origin: u64, amount: u64) -> HeaderRequest; fn with_hash(hash: Hash) -> HeaderRequest; + fn head_request() -> HeaderRequest; fn is_valid(&self) -> bool; fn is_head_request(&self) -> bool; } @@ -35,6 +36,10 @@ impl HeaderRequestExt for HeaderRequest { } } + fn head_request() -> HeaderRequest { + HeaderRequest::with_origin(0, 1) + } + fn is_valid(&self) -> bool { match (&self.data, self.amount) { (None, _) | (_, 0) => false, @@ -71,7 +76,6 @@ pub trait ExtendedHeaderExt { impl ExtendedHeaderExt for ExtendedHeader { fn with_height(height: u64) -> ExtendedHeader { - RawExtendedHeader { header: Some( Header { @@ -131,6 +135,22 @@ impl ToHeaderResponse for ExtendedHeader { } } +impl ToHeaderResponse for Result { + fn to_header_response(&self) -> HeaderResponse { + match self { + Ok(h) => h.to_header_response(), + Err(e) => HeaderResponse { + // TODO: how forthcoming should we be with errors and description? + body: vec![], + status_code: match e { + ReadError::NotFound => StatusCode::NotFound.into(), + _ => StatusCode::Invalid.into(), + }, + }, + } + } +} + #[cfg(test)] mod tests { use super::*;