From 600c52dd21afee7ee4918bcf66884f0375b2a99b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Florkiewicz?= Date: Thu, 21 Sep 2023 15:43:41 +0200 Subject: [PATCH] Implement ExchangeServerHandler --- node/Cargo.toml | 1 + node/src/exchange.rs | 15 +- node/src/exchange/server.rs | 352 ++++++++++++++++++++++++++++++++++-- node/src/exchange/utils.rs | 25 +++ 4 files changed, 372 insertions(+), 21 deletions(-) diff --git a/node/Cargo.toml b/node/Cargo.toml index cdf5d2eb..fcde887b 100644 --- a/node/Cargo.toml +++ b/node/Cargo.toml @@ -7,6 +7,7 @@ license = "Apache-2.0" [dependencies] celestia-proto = { workspace = true } celestia-types = { workspace = true } +tendermint = { workspace = true } tendermint-proto = { workspace = true } async-trait = "0.1.73" diff --git a/node/src/exchange.rs b/node/src/exchange.rs index bc273220..67aaa3b8 100644 --- a/node/src/exchange.rs +++ b/node/src/exchange.rs @@ -2,6 +2,7 @@ use std::io; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::exchange::request_response::ResponseChannel; use async_trait::async_trait; use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse}; use celestia_types::ExtendedHeader; @@ -36,9 +37,11 @@ const RESPONSE_SIZE_MAXIMUM: usize = 10 * 1024 * 1024; /// Maximum length of the protobuf length delimiter in bytes const PROTOBUF_MAX_LENGTH_DELIMITER_LEN: usize = 10; +type RequestType = HeaderRequest; +type ResponseType = Vec; type ReqRespBehaviour = request_response::Behaviour; -type ReqRespEvent = request_response::Event>; -type ReqRespMessage = request_response::Message>; +type ReqRespEvent = request_response::Event; +type ReqRespMessage = request_response::Message; pub(crate) struct ExchangeBehaviour where @@ -46,7 +49,7 @@ where { req_resp: ReqRespBehaviour, client_handler: ExchangeClientHandler, - server_handler: ExchangeServerHandler, + server_handler: ExchangeServerHandler>, } pub(crate) struct ExchangeConfig<'a, S> { @@ -232,6 +235,12 @@ where } } + while let Poll::Ready((channel, response)) = self.server_handler.poll(cx) { + // response was prepared specifically for the request, we can drop it + // in case of error we'll get Event::InboundFailure + self.req_resp.send_response(channel, response).ok(); + } + Poll::Pending } } diff --git a/node/src/exchange/server.rs b/node/src/exchange/server.rs index 8ee24424..f6472a03 100644 --- a/node/src/exchange/server.rs +++ b/node/src/exchange/server.rs @@ -1,52 +1,368 @@ +use std::fmt::{Debug, Display}; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse}; +use celestia_proto::p2p::pb::{header_request, HeaderRequest, HeaderResponse}; +use celestia_types::Hash; +use futures::stream::FuturesUnordered; +use futures::{FutureExt, Stream}; use libp2p::{ - request_response::{InboundFailure, RequestId, ResponseChannel}, + request_response::{InboundFailure, RequestId}, PeerId, }; -use tracing::instrument; +use tendermint::hash::Algorithm; +use tracing::{instrument, trace}; +use crate::exchange::utils::{ExtendedHeaderExt, HeaderRequestExt, HeaderResponseExt}; +use crate::exchange::ResponseType; use crate::store::Store; -pub(super) struct ExchangeServerHandler +type StoreJobType = dyn Future + Send; + +pub(crate) struct ExchangeServerHandler where - S: Store + 'static, + S: Store, { - _store: Arc, + store: Arc, + store_jobs: FuturesUnordered>>>, } -impl ExchangeServerHandler +impl ExchangeServerHandler where S: Store + 'static, + C: Send + 'static, { pub(super) fn new(store: Arc) -> Self { - ExchangeServerHandler { _store: store } + ExchangeServerHandler { + store, + store_jobs: FuturesUnordered::new(), + } } - #[instrument(level = "trace", skip(self, _respond_to))] - pub(super) fn on_request_received( + #[instrument(level = "trace", skip(self, response_channel))] + pub(super) fn on_request_received( &mut self, peer: PeerId, - request_id: RequestId, + request_id: ID, request: HeaderRequest, - _respond_to: ResponseChannel>, - ) { - // TODO + response_channel: C, + ) where + ID: Display + Debug, + { + let Some((amount, data)) = self.parse_request(request) else { + self.store_jobs + .push(handle_invalid_request(response_channel).boxed()); + return; + }; + + let store_job = match data { + header_request::Data::Origin(0) => { + handle_request_current_head(self.store.clone(), response_channel).boxed() + } + header_request::Data::Origin(height) => { + handle_request_by_height(self.store.clone(), response_channel, height, amount) + .boxed() + } + header_request::Data::Hash(hash) => { + handle_request_by_hash(self.store.clone(), response_channel, hash).boxed() + } + }; + + self.store_jobs.push(store_job); } - #[instrument(level = "trace", skip(self))] pub(super) fn on_response_sent(&mut self, peer: PeerId, request_id: RequestId) { - // TODO + trace!("response_sent; request_id: {request_id}, peer: {peer}"); } - #[instrument(level = "trace", skip(self))] pub(super) fn on_failure( &mut self, peer: PeerId, request_id: RequestId, error: InboundFailure, ) { - // TODO + // TODO: cancel job if libp2p already failed it? + trace!("on_failure; request_id: {request_id}, peer: {peer}, error: {error:?}"); + } + + fn parse_request(&self, request: HeaderRequest) -> Option<(u64, header_request::Data)> { + if !request.is_valid() { + return None; + } + + let HeaderRequest { + amount, + data: Some(data), + } = request + else { + return None; + }; + + Some((amount, data)) + } + + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<(C, ResponseType)> { + while let Poll::Ready(ev) = Pin::new(&mut self.store_jobs).poll_next(cx) { + if let Some(response) = ev { + return Poll::Ready(response); + } + } + + Poll::Pending + } +} + +async fn handle_request_current_head(store: Arc, channel: C) -> (C, ResponseType) +where + S: Store, +{ + let response = match store.get_head().await { + Ok(head) => head.to_header_response(), + Err(_) => HeaderResponse::not_found(), + }; + + (channel, vec![response]) +} + +async fn handle_request_by_hash(store: Arc, channel: C, hash: Vec) -> (C, ResponseType) +where + S: Store, +{ + let Ok(hash) = Hash::from_bytes(Algorithm::Sha256, &hash) else { + return (channel, vec![HeaderResponse::invalid()]); + }; + + let response = match store.get_by_hash(&hash).await { + Ok(head) => head.to_header_response(), + Err(_) => HeaderResponse::not_found(), + }; + + (channel, vec![response]) +} + +async fn handle_request_by_height( + store: Arc, + channel: C, + origin: u64, + amount: u64, +) -> (C, ResponseType) +where + S: Store, +{ + let mut responses = vec![]; + for i in origin..origin + amount { + let response = match store.get_by_height(i).await { + Ok(head) => head.to_header_response(), + Err(_) => HeaderResponse::not_found(), + }; + responses.push(response); + } + + (channel, responses) +} + +async fn handle_invalid_request(channel: C) -> (C, ResponseType) +where + C: Send, +{ + (channel, vec![HeaderResponse::invalid()]) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::exchange::utils::HeaderRequestExt; + use crate::test_utils::gen_filled_store; + use celestia_proto::p2p::pb::header_request::Data; + use celestia_proto::p2p::pb::{HeaderRequest, StatusCode}; + use celestia_types::ExtendedHeader; + use futures::task::noop_waker_ref; + use libp2p::PeerId; + use std::sync::Arc; + use std::task::Context; + use tendermint_proto::Protobuf; + + #[tokio::test] + async fn request_header_test() { + let store = Arc::new(gen_filled_store(3)); + let expected_genesis = store.get_by_height(1).unwrap(); + let mut handler = ExchangeServerHandler::new(store); + let channel = TestResponseChannel(1); + + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::with_origin(1, 1), + channel, + ); + + let (received_channel, received_response) = wait_for_poll(&mut handler); + + assert_eq!(channel, received_channel); + assert_eq!(received_response.len(), 1); + assert_eq!(received_response[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received_response[0].body[..]).unwrap(); + assert_eq!(decoded_header, expected_genesis); + } + + #[tokio::test] + async fn request_head_test() { + let store = Arc::new(gen_filled_store(4)); + let expected_head = store.get_head().unwrap(); + let mut handler = ExchangeServerHandler::new(store); + let channel = TestResponseChannel(1); + + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::head_request(), + channel, + ); + + let (received_channel, received_response) = wait_for_poll(&mut handler); + + assert_eq!(channel, received_channel); + assert_eq!(received_response.len(), 1); + assert_eq!(received_response[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received_response[0].body[..]).unwrap(); + assert_eq!(decoded_header, expected_head); + } + + #[tokio::test] + async fn invalid_amount_request_test() { + let store = Arc::new(gen_filled_store(1)); + let mut handler = ExchangeServerHandler::new(store); + + let channel = TestResponseChannel(1); + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::with_origin(0, 0), + channel, + ); + + let (_, received) = wait_for_poll(&mut handler); + + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid)); + } + + #[tokio::test] + async fn none_data_request_test() { + let store = Arc::new(gen_filled_store(1)); + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest { + data: None, + amount: 1, + }; + let channel = TestResponseChannel(1); + handler.on_request_received(PeerId::random(), "test", request, channel); + + let (_, received) = wait_for_poll(&mut handler); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Invalid)); + } + + #[tokio::test] + async fn request_hash_test() { + let store = Arc::new(gen_filled_store(1)); + let stored_header = store.get_head().unwrap(); + let mut handler = ExchangeServerHandler::new(store); + + let channel = TestResponseChannel(1); + handler.on_request_received( + PeerId::random(), + "test", + HeaderRequest::with_hash(stored_header.hash()), + channel, + ); + + let (_, received) = wait_for_poll(&mut handler); + assert_eq!(received.len(), 1); + assert_eq!(received[0].status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&received[0].body[..]).unwrap(); + assert_eq!(decoded_header, stored_header); + } + + #[tokio::test] + async fn request_range_test() { + let store = Arc::new(gen_filled_store(10)); + let expected_headers = [ + store.get_by_height(5).unwrap(), + store.get_by_height(6).unwrap(), + store.get_by_height(7).unwrap(), + ]; + let mut handler = ExchangeServerHandler::new(store); + let channel = TestResponseChannel(1); + + let request = HeaderRequest { + data: Some(Data::Origin(5)), + amount: u64::try_from(expected_headers.len()).unwrap(), + }; + handler.on_request_received(PeerId::random(), "test", request, channel); + + let (_, received) = wait_for_poll(&mut handler); + + assert_eq!(received.len(), expected_headers.len()); + + for (rec, exp) in received.iter().zip(expected_headers.iter()) { + assert_eq!(rec.status_code, i32::from(StatusCode::Ok)); + let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap(); + assert_eq!(&decoded_header, exp); + } + } + + #[tokio::test] + async fn request_range_behond_head_test() { + let store = Arc::new(gen_filled_store(5)); + let expected_hashes = [store.get_by_height(5).ok(), None, None]; + let expected_status_codes = [StatusCode::Ok, StatusCode::NotFound, StatusCode::NotFound]; + assert_eq!(expected_hashes.len(), expected_status_codes.len()); + + let mut handler = ExchangeServerHandler::new(store); + + let request = HeaderRequest::with_origin(5, u64::try_from(expected_hashes.len()).unwrap()); + let channel = TestResponseChannel(1); + handler.on_request_received(PeerId::random(), "test", request, channel); + + let (_, received) = wait_for_poll(&mut handler); + + assert_eq!(received.len(), expected_hashes.len()); + + for (rec, (exp_status, exp_header)) in received + .iter() + .zip(expected_status_codes.iter().zip(expected_hashes.iter())) + { + assert_eq!(rec.status_code, i32::from(*exp_status)); + if let Some(exp_header) = exp_header { + let decoded_header = ExtendedHeader::decode(&rec.body[..]).unwrap(); + assert_eq!(&decoded_header, exp_header); + } + } + } + + #[derive(Debug, PartialEq, Clone, Copy)] + struct TestResponseChannel(pub u64); + + fn create_dummy_async_context() -> Context<'static> { + let waker = noop_waker_ref(); + Context::from_waker(waker) + } + + fn wait_for_poll( + handler: &mut ExchangeServerHandler, + ) -> (TestResponseChannel, ResponseType) + where + S: Store + 'static, + { + let mut cx = create_dummy_async_context(); + loop { + if let Poll::Ready(r) = handler.poll(&mut cx) { + return r; + }; + } } } diff --git a/node/src/exchange/utils.rs b/node/src/exchange/utils.rs index a4f9dfd9..d7e5d273 100644 --- a/node/src/exchange/utils.rs +++ b/node/src/exchange/utils.rs @@ -6,9 +6,13 @@ use tendermint_proto::Protobuf; use crate::exchange::ExchangeError; +const INVALID_REQUEST_MSG: &str = "invalid request"; +const NOT_FOUND_MSG: &str = "not found"; + pub(super) trait HeaderRequestExt { fn with_origin(origin: u64, amount: u64) -> HeaderRequest; fn with_hash(hash: Hash) -> HeaderRequest; + fn head_request() -> HeaderRequest; fn is_valid(&self) -> bool; fn is_head_request(&self) -> bool; } @@ -28,6 +32,10 @@ impl HeaderRequestExt for HeaderRequest { } } + fn head_request() -> HeaderRequest { + HeaderRequest::with_origin(0, 1) + } + fn is_valid(&self) -> bool { match (&self.data, self.amount) { (None, _) | (_, 0) => false, @@ -44,6 +52,9 @@ impl HeaderRequestExt for HeaderRequest { pub(super) trait HeaderResponseExt { fn to_extended_header(&self) -> Result; + + fn not_found() -> HeaderResponse; + fn invalid() -> HeaderResponse; } impl HeaderResponseExt for HeaderResponse { @@ -56,6 +67,20 @@ impl HeaderResponseExt for HeaderResponse { } } } + + // TODO: how forthcoming should we be with errors and description? + fn not_found() -> HeaderResponse { + HeaderResponse { + body: NOT_FOUND_MSG.as_bytes().to_vec(), + status_code: StatusCode::NotFound.into(), + } + } + fn invalid() -> HeaderResponse { + HeaderResponse { + status_code: StatusCode::Invalid.into(), + body: INVALID_REQUEST_MSG.as_bytes().to_vec(), + } + } } pub(super) trait ExtendedHeaderExt {