diff --git a/Cargo.lock b/Cargo.lock index 15d6b07..b42570a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -554,7 +554,6 @@ name = "dephy-pproxy" version = "0.1.0" dependencies = [ "base64 0.22.1", - "bytes", "clap", "futures", "futures-util", @@ -571,6 +570,7 @@ dependencies = [ "prost 0.12.4", "thiserror", "tokio", + "tokio-util", "tonic", "tonic-build", "tonic-web", diff --git a/Cargo.toml b/Cargo.toml index 59175c0..10ee931 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ edition = "2021" [dependencies] base64 = "0.22.1" -bytes = "1.6.0" clap = "4.5.4" futures = "0.3.30" futures-util = "0.3.30" @@ -23,6 +22,7 @@ percent-encoding = "2.3.1" prost = "0.12.4" thiserror = "1.0.60" tokio = { version = "1.37.0", features = ["rt-multi-thread"] } +tokio-util = "0.7.11" tonic = "0.11.0" tonic-web = "0.11.0" tracing = "0.1.40" diff --git a/buf.yaml b/buf.yaml index 5a5af8e..5997ce7 100644 --- a/buf.yaml +++ b/buf.yaml @@ -2,3 +2,4 @@ version: v1 lint: except: - PACKAGE_DIRECTORY_MATCH + - DIRECTORY_SAME_PACKAGE diff --git a/build.rs b/build.rs index 1d1dd34..20cb68b 100644 --- a/build.rs +++ b/build.rs @@ -1,6 +1,7 @@ fn main() -> Result<(), Box> { + let protos = ["proto/command_v1.proto", "proto/tunnel_v1.proto"]; tonic_build::configure() .protoc_arg("--experimental_allow_proto3_optional") - .compile(&["proto/command_v1.proto"], &["proto"])?; + .compile(&protos, &["proto"])?; Ok(()) } diff --git a/proto/command_v1.proto b/proto/command_v1.proto index d5fe531..94d9af3 100644 --- a/proto/command_v1.proto +++ b/proto/command_v1.proto @@ -10,16 +10,17 @@ message AddPeerResponse { string peer_id = 1; } -message RequestHttpServerRequest { +message CreateTunnelServerRequest { string peer_id = 1; - bytes data = 2; + optional string address = 2; } -message RequestHttpServerResponse { - bytes data = 1; +message CreateTunnelServerResponse { + string peer_id = 1; + string address = 2; } service CommandService { rpc AddPeer(AddPeerRequest) returns (AddPeerResponse); - rpc RequestHttpServer(RequestHttpServerRequest) returns (RequestHttpServerResponse); + rpc CreateTunnelServer(CreateTunnelServerRequest) returns (CreateTunnelServerResponse); } diff --git a/proto/tunnel_v1.proto b/proto/tunnel_v1.proto new file mode 100644 index 0000000..411e018 --- /dev/null +++ b/proto/tunnel_v1.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; +package tunnel.v1; + +enum TunnelCommand { + TUNNEL_COMMAND_UNSPECIFIED = 0; + TUNNEL_COMMAND_CONNECT = 1; + TUNNEL_COMMAND_CONNECT_RESP = 2; + TUNNEL_COMMAND_INBOUND_PACKAGE = 3; + TUNNEL_COMMAND_OUTBOUND_PACKAGE = 4; +} + +message Tunnel { + string tunnel_id = 1; + TunnelCommand command = 2; + optional bytes data = 3; +} diff --git a/src/command.rs b/src/command.rs index f300117..6ec4a73 100644 --- a/src/command.rs +++ b/src/command.rs @@ -38,14 +38,15 @@ impl proto::command_service_server::CommandService for PProxyCommander { .map_err(|e| tonic::Status::internal(format!("{:?}", e))) } - async fn request_http_server( + async fn create_tunnel_server( &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status> { + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { trace!("handle request: {:?}", request); self.handle - .request_http_server(request.into_inner()) + .create_tunnel_server(request.into_inner()) .await .map(Response::new) .map_err(|e| tonic::Status::internal(format!("{:?}", e))) diff --git a/src/error.rs b/src/error.rs index 7bfce44..69ef0dd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,18 @@ +use std::io::ErrorKind as 
IOErrorKind; + #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum Error { - #[error("Multiaddr {0} parse error")] + #[error("Multiaddr parse error: {0}")] MultiaddrParseError(String), + #[error("SocketAddr parse error: {0}")] + SocketAddrParseError(String), #[error("Failed to extract peer id from multiaddr: {0}")] FailedToExtractPeerIdFromMultiaddr(String), - #[error("PeerId {0} parse error")] + #[error("PeerId parse error: {0}")] PeerIdParseError(String), + #[error("TunnelId parse error: {0}")] + TunnelIdParseError(String), #[error("Essential task closed")] EssentialTaskClosed, #[error("Litep2p error: {0}")] @@ -16,12 +23,44 @@ pub enum Error { Httparse(#[from] httparse::Error), #[error("Incomplete http request")] IncompleteHttpRequest, - #[error("Protocol not support")] - ProtocolNotSupport, + #[error("Protocol not support: {0}")] + ProtocolNotSupport(String), #[error("Io error: {0}")] Io(#[from] std::io::Error), #[error("Unexpected response type")] UnexpectedResponseType, + #[error("Tunnel error: {0:?}")] + Tunnel(TunnelError), + #[error("Protobuf decode error: {0}")] + ProtobufDecode(#[from] prost::DecodeError), +} + +/// A list specifying general categories of Tunnel error like [std::io::ErrorKind]. +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +#[non_exhaustive] +pub enum TunnelError { + /// Failed to send data to peer. + DataSendFailed = 1, + /// The connection timed out when dialing. + ConnectionTimeout = 2, + /// Got [std::io::ErrorKind::ConnectionRefused] error from local stream. + ConnectionRefused = 3, + /// Got [std::io::ErrorKind::ConnectionAborted] error from local stream. + ConnectionAborted = 4, + /// Got [std::io::ErrorKind::ConnectionReset] error from local stream. + ConnectionReset = 5, + /// Got [std::io::ErrorKind::NotConnected] error from local stream. + NotConnected = 6, + /// The connection is closed by peer. + ConnectionClosed = 7, + /// A socket address could not be bound because the address is already in + /// use elsewhere. + AddrInUse = 8, + /// Tunnel already listened. + TunnelInUse = 9, + /// Unknown [std::io::ErrorKind] error. 
+ Unknown = u8::MAX, } impl From> for Error { @@ -41,3 +80,28 @@ impl From for Error { Error::Litep2pRequestResponseError(err) } } + +impl From for Error { + fn from(error: TunnelError) -> Self { + Error::Tunnel(error) + } +} + +impl From for TunnelError { + fn from(kind: IOErrorKind) -> TunnelError { + match kind { + IOErrorKind::ConnectionRefused => TunnelError::ConnectionRefused, + IOErrorKind::ConnectionAborted => TunnelError::ConnectionAborted, + IOErrorKind::ConnectionReset => TunnelError::ConnectionReset, + IOErrorKind::NotConnected => TunnelError::NotConnected, + IOErrorKind::AddrInUse => TunnelError::AddrInUse, + _ => TunnelError::Unknown, + } + } +} + +impl From for TunnelError { + fn from(error: std::io::Error) -> TunnelError { + error.kind().into() + } +} diff --git a/src/gateway.rs b/src/gateway.rs deleted file mode 100644 index f79b97b..0000000 --- a/src/gateway.rs +++ /dev/null @@ -1,279 +0,0 @@ -use std::convert::Infallible; -use std::net::SocketAddr; -use std::str::FromStr; - -use base64::prelude::*; -use bytes::Bytes; -use http_body::Frame; -use http_body_util::combinators::BoxBody; -use http_body_util::BodyExt; -use http_body_util::Full; -use http_body_util::StreamBody; -use hyper::server::conn::http1; -use hyper::service::service_fn; -use hyper::Request; -use hyper::Response; -use hyper_util::rt::TokioIo; -use multiaddr::Multiaddr; -use tokio::net::TcpListener; - -use crate::command::proto::command_service_client::CommandServiceClient; -use crate::command::proto::AddPeerRequest; -use crate::command::proto::RequestHttpServerRequest; - -pub async fn proxy_gateway( - addr: SocketAddr, - commander_addr: SocketAddr, - peer_multiaddr: Multiaddr, -) -> Result<(), Box> { - let listener = TcpListener::bind(addr).await?; - - let mut client = CommandServiceClient::connect(format!("http://{}", commander_addr)).await?; - let pp_response = client - .add_peer(AddPeerRequest { - address: peer_multiaddr.to_string(), - peer_id: None, - }) - .await? 
- .into_inner(); - - loop { - let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); - let peer_id = pp_response.peer_id.clone(); - - tokio::task::spawn(async move { - let peer_id = peer_id.clone(); - if let Err(err) = http1::Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .serve_connection( - io, - service_fn(move |req| gateway(req, commander_addr, peer_id.clone())), - ) - .await - { - println!("Failed to serve connection: {:?}", err); - } - }); - } -} - -async fn gateway( - req: Request, - commander_addr: SocketAddr, - peer_id: String, -) -> Result>, hyper::Error> { - let Ok(mut data) = request_header_to_vec(&req) else { - return Ok(error_response( - "Failed to dump headers", - http::StatusCode::INTERNAL_SERVER_ERROR, - )); - }; - - let body = req.into_body().collect().await?; - data.extend(body.to_bytes()); - - let Ok(mut client) = CommandServiceClient::connect(format!("http://{}", commander_addr)).await - else { - return Ok(error_response( - "Failed to connect to pproxy", - http::StatusCode::BAD_GATEWAY, - )); - }; - - let pp_request = RequestHttpServerRequest { peer_id, data }; - let Ok(pp_response) = client - .request_http_server(pp_request) - .await - .map(|r| r.into_inner()) - else { - return Ok(error_response( - "Failed to request pproxy", - http::StatusCode::BAD_GATEWAY, - )); - }; - - let Ok(Some((resp, trailer))) = parse_response_header_easy(&pp_response.data) else { - return Ok(error_response( - "Failed to parse response headers from pproxy", - http::StatusCode::INTERNAL_SERVER_ERROR, - )); - }; - - let (parts, _) = resp.into_parts(); - - // Handle chunked encoding - if let Some(true) = parts.headers.get(http::header::TRANSFER_ENCODING).map(|v| { - let Ok(v) = v.to_str() else { return false }; - v.to_lowercase().contains("chunked") - }) { - let chunks = split_chunked_body(trailer); - return Ok(Response::from_parts(parts, stream_body(chunks))); - } - - let t = trailer.to_vec(); - Ok(Response::from_parts(parts, full(t))) -} - -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} - -fn stream_body(chunks: Vec, Infallible>>) -> BoxBody { - StreamBody::new(futures_util::stream::iter(chunks)) - .map_err(|never| match never {}) - .boxed() -} - -fn error_response( - msg: &'static str, - status: http::StatusCode, -) -> Response> { - let mut resp = Response::new(full(msg)); - *resp.status_mut() = status; - resp -} - -fn io_other_error(msg: &'static str) -> std::io::Error { - let e: Box = msg.into(); - std::io::Error::new(std::io::ErrorKind::Other, e) -} - -fn write_request_header( - r: &http::Request, - mut io: impl std::io::Write, -) -> std::io::Result { - let mut len = 0; - let verb = r.method().as_str(); - let path = r - .uri() - .path_and_query() - .ok_or_else(|| io_other_error("Invalid URI"))?; - - let need_to_insert_host = - r.uri().host().is_some() && !r.headers().contains_key(http::header::HOST); - - macro_rules! 
w { - ($x:expr) => { - io.write_all($x)?; - len += $x.len(); - }; - } - w!(verb.as_bytes()); - w!(b" "); - w!(path.as_str().as_bytes()); - w!(b" HTTP/1.1\r\n"); - - if need_to_insert_host { - w!(b"Host: "); - let host = r.uri().host().unwrap(); - w!(host.as_bytes()); - if let Some(p) = r.uri().port() { - w!(b":"); - w!(p.as_str().as_bytes()); - } - w!(b"\r\n"); - } - - let already_present = r.headers().get(http::header::AUTHORIZATION).is_some(); - let at_sign = r - .uri() - .authority() - .map_or(false, |x| x.as_str().contains('@')); - if !already_present && at_sign { - w!(b"Authorization: Basic "); - let a = r.uri().authority().unwrap().as_str(); - let a = &a[0..(a.find('@').unwrap())]; - let a = a - .as_bytes() - .split(|v| *v == b':') - .map(|v| percent_encoding::percent_decode(v).collect::>()) - .collect::>>() - .join(&b':'); - let a = BASE64_STANDARD.encode(a); - w!(a.as_bytes()); - w!(b"\r\n"); - } - - for (hn, hv) in r.headers() { - w!(hn.as_str().as_bytes()); - w!(b": "); - w!(hv.as_bytes()); - w!(b"\r\n"); - } - - w!(b"\r\n"); - - Ok(len) -} - -fn request_header_to_vec(r: &http::Request) -> std::io::Result> { - let v = Vec::with_capacity(120); - let mut c = std::io::Cursor::new(v); - write_request_header(r, &mut c)?; - Ok(c.into_inner()) -} - -#[allow(clippy::type_complexity)] -fn parse_response_header<'a>( - buf: &'a [u8], - headers_buffer: &mut [httparse::Header<'a>], -) -> Result, &'a [u8])>, Box> { - let mut x = httparse::Response::new(headers_buffer); - let n = match x.parse(buf)? { - httparse::Status::Partial => return Ok(None), - httparse::Status::Complete(size) => size, - }; - let trailer = &buf[n..]; - let mut r = Response::new(()); - *r.status_mut() = http::StatusCode::from_u16(x.code.unwrap())?; - // x.reason goes to nowhere - *r.version_mut() = http::Version::HTTP_11; // FIXME? 
- for h in x.headers { - let n = http::HeaderName::from_str(h.name)?; - let v = http::HeaderValue::from_bytes(h.value)?; - r.headers_mut().append(n, v); - } - Ok(Some((r, trailer))) -} - -#[allow(clippy::type_complexity)] -fn parse_response_header_easy( - buf: &[u8], -) -> Result, &[u8])>, Box> { - let mut h = [httparse::EMPTY_HEADER; 50]; - parse_response_header(buf, h.as_mut()) -} - -fn split_chunked_body(input: &[u8]) -> Vec, Infallible>> { - use std::io::BufRead; - use std::io::BufReader; - use std::io::Cursor; - use std::io::Read; - - let mut reader = BufReader::new(Cursor::new(input)); - let mut chunks = Vec::new(); - - loop { - let mut chunk_size = String::new(); - let read = reader.read_line(&mut chunk_size).unwrap(); - if read == 0 { - break; - } - - let chunk_size = usize::from_str_radix(chunk_size.trim_end(), 16).unwrap(); - - let mut chunk = vec![0; chunk_size]; - reader.read_exact(&mut chunk).unwrap(); - - let mut crlf = [0; 2]; - reader.read_exact(&mut crlf).unwrap(); - - chunks.push(Ok(Frame::data(Bytes::from(chunk)))); - } - - chunks -} diff --git a/src/lib.rs b/src/lib.rs index a242bf1..fbe14cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,32 +1,36 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::time::Duration; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use std::sync::Mutex; use futures::channel::oneshot; use litep2p::crypto::ed25519::SecretKey; use litep2p::protocol::request_response::DialOptions; use litep2p::protocol::request_response::RequestResponseEvent; -use litep2p::types::RequestId; use litep2p::PeerId; use multiaddr::Multiaddr; use multiaddr::Protocol; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::net::TcpStream; +use prost::Message; use tokio::sync::mpsc; -use tokio::time::timeout; use tracing::warn; use crate::command::proto::AddPeerRequest; use crate::command::proto::AddPeerResponse; -use crate::command::proto::RequestHttpServerRequest; -use crate::command::proto::RequestHttpServerResponse; +use crate::command::proto::CreateTunnelServerRequest; +use crate::command::proto::CreateTunnelServerResponse; use crate::server::*; +use crate::tunnel::proto; +use crate::tunnel::tcp_connect_with_timeout; +use crate::tunnel::Tunnel; +use crate::tunnel::TunnelServer; +use crate::types::*; pub mod command; pub mod error; -pub mod gateway; mod server; +mod tunnel; +pub mod types; /// pproxy version pub const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -34,6 +38,9 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Default channel size. const DEFAULT_CHANNEL_SIZE: usize = 4096; +/// Timeout for proxied TCP connections +pub const TCP_SERVER_TIMEOUT: u64 = 30; + /// Public result type error type used by the crate. 
pub use crate::error::Error; pub type Result = std::result::Result; @@ -43,24 +50,42 @@ type CommandNotifier = oneshot::Sender; #[derive(Debug)] pub enum PProxyCommand { - AddPeer(AddPeerRequest), - RequestHttpServer(RequestHttpServerRequest), + AddPeer { + address: Multiaddr, + peer_id: PeerId, + }, + SendConnectCommand { + peer_id: PeerId, + tunnel_id: TunnelId, + tunnel_tx: mpsc::Sender>, + }, + SendOutboundPackageCommand { + peer_id: PeerId, + tunnel_id: TunnelId, + data: Vec, + }, } pub enum PProxyCommandResponse { - AddPeer(AddPeerResponse), - RequestHttpServer(RequestHttpServerResponse), + AddPeer { peer_id: PeerId }, + SendConnectCommand {}, + SendOutboundPackageCommand {}, } pub struct PProxy { + command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, command_rx: mpsc::Receiver<(PProxyCommand, CommandNotifier)>, - tunnel_notifier: HashMap, p2p_server: P2pServer, proxy_addr: Option, + inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>, + inbound_tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender>>, + outbound_tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender>>, } pub struct PProxyHandle { command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + next_tunnel_id: Arc, + tunnel_servers: Mutex>, } pub enum FullLength { @@ -90,12 +115,19 @@ impl PProxy { ( Self { + command_tx: command_tx.clone(), command_rx, - tunnel_notifier: HashMap::new(), p2p_server: P2pServer::new(secret_key, server_addr), proxy_addr, + inbound_tunnels: HashMap::new(), + inbound_tunnel_txs: HashMap::new(), + outbound_tunnel_txs: HashMap::new(), + }, + PProxyHandle { + command_tx, + next_tunnel_id: Default::default(), + tunnel_servers: Default::default(), }, - PProxyHandle { command_tx }, ) } @@ -107,149 +139,108 @@ impl PProxy { event = self.p2p_server.next_event() => match event { None => return, - Some(ref event) => if let Err(error) = self.handle_p2p_server_event(event).await { - warn!("failed to handle event {:?}: {:?}", event, error); - + Some(event) => if let Err(error) = self.handle_p2p_server_event(event).await { + warn!("failed to handle event: {:?}", error); } }, command = self.command_rx.recv() => match command { None => return, - Some((ref command, tx)) => if let Err(error) = self.handle_command(command, tx).await { - warn!("failed to handle command {:?}: {:?}", command, error); + Some((command, tx)) => if let Err(error) = self.handle_command(command, tx).await { + warn!("failed to handle command: {:?}", error); } } } } } - async fn handle_p2p_server_event(&mut self, event: &P2pServerEvent) -> Result<()> { + async fn handle_p2p_server_event(&mut self, event: P2pServerEvent) -> Result<()> { + #[allow(clippy::single_match)] match event { P2pServerEvent::TunnelEvent(RequestResponseEvent::RequestReceived { + peer, request_id, request, .. 
}) => { - let Some(proxy_addr) = self.proxy_addr else { - return Err(Error::ProtocolNotSupport); - }; - - let mut headers = [httparse::EMPTY_HEADER; 1024]; - let mut req = httparse::Request::new(&mut headers); - if req.parse(request)?.is_partial() { - return Err(Error::IncompleteHttpRequest); - } - - let mut stream = tcp_connect_with_timeout(&proxy_addr, 60).await?; - stream.write_all(request).await?; - - let mut response = Vec::new(); - let mut full_length = FullLength::NotParsed; - - loop { - let mut buf = [0u8; 30000]; - - match timeout(std::time::Duration::from_secs(60), stream.read(&mut buf)).await { - Ok(Ok(0)) => { - warn!("empty response from http server"); - break; - } - Ok(Ok(n)) => { - response.extend_from_slice(&buf[..n]); - } - x => { - warn!("http stream read failed {x:?}"); - break; - } + let msg = proto::Tunnel::decode(request.as_slice())?; + match msg.command() { + proto::TunnelCommand::Connect => { + tracing::info!("received connect command from peer: {:?}", peer); + let Some(proxy_addr) = self.proxy_addr else { + return Err(Error::ProtocolNotSupport("No proxy_addr".to_string())); + }; + + let tunnel_id = msg + .tunnel_id + .parse() + .map_err(|_| Error::TunnelIdParseError(msg.tunnel_id))?; + + let stream = tcp_connect_with_timeout(proxy_addr, 60).await?; + let mut tunnel = Tunnel::new(peer, tunnel_id, self.command_tx.clone()); + let (tunnel_tx, tunnel_rx) = mpsc::channel(1024); + tunnel.listen(stream, tunnel_rx).await?; + + self.inbound_tunnels.insert((peer, tunnel_id), tunnel); + self.inbound_tunnel_txs.insert((peer, tunnel_id), tunnel_tx); + + let response = proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::ConnectResp.into(), + data: None, + }; + + self.p2p_server + .tunnel_handle + .send_response(request_id, response.encode_to_vec()); } - if full_length.not_parsed() { - let mut headers = [httparse::EMPTY_HEADER; 1024]; - let mut resp_checker = httparse::Response::new(&mut headers); - let res = resp_checker.parse(&response)?; - if res.is_complete() { - let content_length = resp_checker.headers.iter().find_map(|h| { - if h.name.to_lowercase() != "content-length" { - return None; - } - let Ok(value) = std::str::from_utf8(h.value) else { - return None; - }; - value.parse::().ok() - }); - - let transfor_encoding = resp_checker.headers.iter().find_map(|h| { - if h.name.to_lowercase() != "transfer-encoding" { - return None; - } - let Ok(value) = std::str::from_utf8(h.value) else { - return None; - }; - Some(value) - }); - - match (content_length, transfor_encoding) { - (Some(content_length), _) => { - let header_length = res.unwrap(); - full_length = FullLength::Parsed(header_length + content_length) - } - (None, Some(value)) if value.to_lowercase().contains("chunked") => { - full_length = FullLength::Chunked; - } - _ => { - full_length = FullLength::NotSet; - } - } - } - } + proto::TunnelCommand::InboundPackage => { + let tunnel_id = msg + .tunnel_id + .parse() + .map_err(|_| Error::TunnelIdParseError(msg.tunnel_id))?; - if let FullLength::Parsed(full_length) = full_length { - if response.len() >= full_length { - break; - } - } + let Some(tx) = self.outbound_tunnel_txs.get(&(peer, tunnel_id)) else { + return Err(Error::ProtocolNotSupport( + "No tunnel for InboundPackage".to_string(), + )); + }; + + tx.send(msg.data.unwrap_or_default()).await?; - if full_length.chunked() && response.ends_with(b"0\r\n\r\n") { - break; + // Have to do this to close the response waiter in remote. 
+ self.p2p_server + .tunnel_handle + .send_response(request_id, vec![]); } - } - if response.is_empty() { - response = b"HTTP/1.1 500 Internal Server Error\r\n\r\n".to_vec(); - } + proto::TunnelCommand::OutboundPackage => { + let tunnel_id = msg + .tunnel_id + .parse() + .map_err(|_| Error::TunnelIdParseError(msg.tunnel_id))?; - self.p2p_server - .tunnel_handle - .send_response(*request_id, response); - } + let Some(tx) = self.inbound_tunnel_txs.get(&(peer, tunnel_id)) else { + return Err(Error::ProtocolNotSupport( + "No tunnel for OutboundPackage".to_string(), + )); + }; - P2pServerEvent::TunnelEvent(RequestResponseEvent::RequestFailed { - request_id, - error, - .. - }) => { - warn!("request failed: {:?}", error); - let Some(tx) = self.tunnel_notifier.remove(request_id) else { - todo!(); - }; - tx.send(Err(Error::Litep2pRequestResponseError(error.clone()))) - .map_err(|_| Error::EssentialTaskClosed)?; - } + tx.send(msg.data.unwrap_or_default()).await?; - P2pServerEvent::TunnelEvent(RequestResponseEvent::ResponseReceived { - request_id, - response, - .. - }) => { - let Some(tx) = self.tunnel_notifier.remove(request_id) else { - todo!(); - }; - tx.send(Ok(PProxyCommandResponse::RequestHttpServer( - RequestHttpServerResponse { - data: response.clone(), - }, - ))) - .map_err(|_| Error::EssentialTaskClosed)?; + // Have to do this to close the response waiter in remote. + self.p2p_server + .tunnel_handle + .send_response(request_id, vec![]); + } + + _ => { + return Err(Error::ProtocolNotSupport( + "Wrong tunnel command".to_string(), + )); + } + } } _ => {} } @@ -257,59 +248,93 @@ impl PProxy { Ok(()) } - async fn handle_command(&mut self, command: &PProxyCommand, tx: CommandNotifier) -> Result<()> { + async fn handle_command(&mut self, command: PProxyCommand, tx: CommandNotifier) -> Result<()> { match command { - PProxyCommand::AddPeer(request) => self.on_add_peer(request.clone(), tx).await, - PProxyCommand::RequestHttpServer(request) => { - self.on_request_http_server(request.clone(), tx).await + PProxyCommand::AddPeer { address, peer_id } => { + self.on_add_peer(address, peer_id, tx).await + } + PProxyCommand::SendConnectCommand { + peer_id, + tunnel_id, + tunnel_tx, + } => { + self.on_send_connect_command(peer_id, tunnel_id, tunnel_tx, tx) + .await + } + PProxyCommand::SendOutboundPackageCommand { + peer_id, + tunnel_id, + data, + } => { + self.on_send_outbound_package_command(peer_id, tunnel_id, data, tx) + .await } } } - async fn on_add_peer(&mut self, request: AddPeerRequest, tx: CommandNotifier) -> Result<()> { - let addr: Multiaddr = request - .address - .parse() - .map_err(|_| Error::MultiaddrParseError(request.address.clone()))?; - - let peer_id = request.peer_id.as_ref().map_or_else( - || extract_peer_id_from_multiaddr(&addr), - |peer_id| { - peer_id - .parse() - .map_err(|_| Error::PeerIdParseError(peer_id.clone())) - }, - )?; - + async fn on_add_peer( + &mut self, + addr: Multiaddr, + peer_id: PeerId, + tx: CommandNotifier, + ) -> Result<()> { self.p2p_server .litep2p .add_known_address(peer_id, vec![addr].into_iter()); - tx.send(Ok(PProxyCommandResponse::AddPeer(AddPeerResponse { - peer_id: peer_id.to_string(), - }))) - .map_err(|_| Error::EssentialTaskClosed) + tx.send(Ok(PProxyCommandResponse::AddPeer { peer_id })) + .map_err(|_| Error::EssentialTaskClosed) } - async fn on_request_http_server( + async fn on_send_connect_command( &mut self, - request: RequestHttpServerRequest, - rx: CommandNotifier, + peer_id: PeerId, + tunnel_id: TunnelId, + tunnel_tx: mpsc::Sender>, + tx: 
CommandNotifier, ) -> Result<()> { - let peer_id = request - .peer_id - .parse() - .map_err(|_| Error::PeerIdParseError(request.peer_id.clone()))?; + self.outbound_tunnel_txs + .insert((peer_id, tunnel_id), tunnel_tx); - let request_id = self - .p2p_server + let request = proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::Connect.into(), + data: None, + } + .encode_to_vec(); + + self.p2p_server .tunnel_handle - .send_request(peer_id, request.data.clone(), DialOptions::Dial) + .send_request(peer_id, request, DialOptions::Dial) .await?; - self.tunnel_notifier.insert(request_id, rx); + tracing::info!("send connect command to peer_id: {:?}", peer_id); - Ok(()) + tx.send(Ok(PProxyCommandResponse::SendConnectCommand {})) + .map_err(|_| Error::EssentialTaskClosed) + } + + async fn on_send_outbound_package_command( + &mut self, + peer_id: PeerId, + tunnel_id: TunnelId, + data: Vec, + tx: CommandNotifier, + ) -> Result<()> { + let request = proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::OutboundPackage.into(), + data: Some(data), + } + .encode_to_vec(); + + self.p2p_server + .tunnel_handle + .send_request(peer_id, request, DialOptions::Dial) + .await?; + + tx.send(Ok(PProxyCommandResponse::SendOutboundPackageCommand {})) + .map_err(|_| Error::EssentialTaskClosed) } } @@ -317,40 +342,64 @@ impl PProxyHandle { pub async fn add_peer(&self, request: AddPeerRequest) -> Result { let (tx, rx) = oneshot::channel(); + let address: Multiaddr = request + .address + .parse() + .map_err(|_| Error::MultiaddrParseError(request.address.clone()))?; + + let peer_id = request.peer_id.map_or_else( + || extract_peer_id_from_multiaddr(&address), + |peer_id| { + peer_id + .parse() + .map_err(|_| Error::PeerIdParseError(peer_id)) + }, + )?; + self.command_tx - .send((PProxyCommand::AddPeer(request), tx)) + .send((PProxyCommand::AddPeer { address, peer_id }, tx)) .await?; let response = rx.await??; match response { - PProxyCommandResponse::AddPeer(response) => Ok(response), + PProxyCommandResponse::AddPeer { peer_id } => Ok(AddPeerResponse { + peer_id: peer_id.to_string(), + }), _ => Err(Error::UnexpectedResponseType), } } - pub async fn request_http_server( + pub async fn create_tunnel_server( &self, - request: RequestHttpServerRequest, - ) -> Result { - let mut headers = [httparse::EMPTY_HEADER; 1024]; - let mut req = httparse::Request::new(&mut headers); - if req.parse(&request.data)?.is_partial() { - return Err(Error::IncompleteHttpRequest); - } + request: CreateTunnelServerRequest, + ) -> Result { + let peer_id = request + .peer_id + .parse() + .map_err(|_| Error::PeerIdParseError(request.peer_id))?; - let (tx, rx) = oneshot::channel(); + let address = request.address.unwrap_or("127.0.0.1:0".to_string()); + let address = address + .parse() + .map_err(|_| Error::SocketAddrParseError(address))?; - self.command_tx - .send((PProxyCommand::RequestHttpServer(request), tx)) - .await?; + let mut tunnel_server = TunnelServer::new( + peer_id, + self.next_tunnel_id.clone(), + self.command_tx.clone(), + ); + let address = tunnel_server.listen(address).await?; - let response = rx.await??; + self.tunnel_servers + .lock() + .unwrap() + .insert(peer_id, tunnel_server); - match response { - PProxyCommandResponse::RequestHttpServer(response) => Ok(response), - _ => Err(Error::UnexpectedResponseType), - } + Ok(CreateTunnelServerResponse { + peer_id: peer_id.to_string(), + address: address.to_string(), + }) } } @@ -366,20 +415,3 @@ fn 
extract_peer_id_from_multiaddr(multiaddr: &Multiaddr) -> Result { PeerId::from_multihash(multihash) .map_err(|_| Error::FailedToExtractPeerIdFromMultiaddr(multiaddr.to_string())) } - -pub async fn tcp_connect_with_timeout( - addr: &SocketAddr, - request_timeout_s: u64, -) -> Result { - match timeout( - Duration::from_secs(request_timeout_s), - TcpStream::connect(addr), - ) - .await - { - Ok(result) => result.map_err(Error::Io), - Err(_) => Err(Error::Io(std::io::Error::from( - std::io::ErrorKind::TimedOut, - ))), - } -} diff --git a/src/main.rs b/src/main.rs index 2cb5704..eece61b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,16 @@ +use std::net::SocketAddr; + use clap::Arg; use clap::ArgAction; use clap::ArgMatches; use clap::Command; +use dephy_pproxy::command::proto::command_service_client::CommandServiceClient; +use dephy_pproxy::command::proto::AddPeerRequest; +use dephy_pproxy::command::proto::CreateTunnelServerRequest; use dephy_pproxy::command::PProxyCommander; -use dephy_pproxy::gateway::proxy_gateway; use dephy_pproxy::PProxy; use litep2p::crypto::ed25519::SecretKey; +use multiaddr::Multiaddr; use tonic::transport::Server; fn parse_args() -> Command { @@ -46,8 +51,8 @@ fn parse_args() -> Command { .help("Will reverse proxy this address if set"), ); - let gateway = Command::new("gateway") - .about("Set up a local server that allows users proxy data to remote peer") + let create_tunnel_server = Command::new("create_tunnel_server") + .about("Set up a tunnel server that allows users proxy data to remote peer") .arg( Arg::new("COMMANDER_SERVER_ADDR") .long("commander-server-addr") @@ -57,23 +62,22 @@ fn parse_args() -> Command { .help("Commander server address"), ) .arg( - Arg::new("PROXY_GATEWAY_ADDR") - .long("proxy-gateway-addr") + Arg::new("TUNNEL_SERVER_ADDR") + .long("tunnel-server-addr") .num_args(1) - .default_value("127.0.0.1:10000") .action(ArgAction::Set) - .help("Set up a local HTTP server that allows users use peerid header to proxy requests to remote peer"), + .help("Tunnel server address, if not set a random port will be used"), ) .arg( Arg::new("PEER_MULTIADDR") - .long("peer-multiaddr") - .num_args(1) - .action(ArgAction::Set) - .required(true) - .help("The multiaddr of the remote peer"), + .long("peer-multiaddr") + .num_args(1) + .action(ArgAction::Set) + .required(true) + .help("The multiaddr of the remote peer"), ); - app = app.subcommand(serve).subcommand(gateway); + app = app.subcommand(serve).subcommand(create_tunnel_server); app } @@ -122,26 +126,48 @@ async fn serve(args: &ArgMatches) { ); } -async fn gateway(args: &ArgMatches) { +async fn create_tunnel_server(args: &ArgMatches) { let commander_server_addr = args .get_one::("COMMANDER_SERVER_ADDR") .unwrap() - .parse() + .parse::() .expect("Invalid command server address"); - let proxy_gateway_addr = args - .get_one::("PROXY_GATEWAY_ADDR") - .unwrap() - .parse() - .expect("Invalid proxy gateway address"); + let tunnel_server_addr = args.get_one::("TUNNEL_SERVER_ADDR").map(|addr| { + addr.parse::() + .expect("Invalid tunnel server address") + .to_string() + }); let peer_multiaddr = args .get_one::("PEER_MULTIADDR") .unwrap() - .parse() + .parse::() .expect("Missing peer multiaddr"); - println!("proxy_gateway_addr: {}", proxy_gateway_addr); - proxy_gateway(proxy_gateway_addr, commander_server_addr, peer_multiaddr) + + let mut client = CommandServiceClient::connect(format!("http://{}", commander_server_addr)) .await - .expect("Gateway server failed") + .expect("Connect to commander server failed"); + + 
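+    // Two RPCs against the local commander follow: AddPeer registers the remote
+    // peer's multiaddr and returns its canonical peer id, then CreateTunnelServer
+    // asks the commander to bind a local TCP listener (a random port unless
+    // --tunnel-server-addr is given) that tunnels accepted connections to that peer.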
let pp_response = client + .add_peer(AddPeerRequest { + address: peer_multiaddr.to_string(), + peer_id: None, + }) + .await + .expect("Add peer failed") + .into_inner(); + + let peer_id = pp_response.peer_id; + + let pp_response = client + .create_tunnel_server(CreateTunnelServerRequest { + peer_id, + address: tunnel_server_addr, + }) + .await + .expect("Create tunnel failed") + .into_inner(); + + println!("tunnel_server_addr: {}", pp_response.address); } #[tokio::main] @@ -156,7 +182,7 @@ async fn main() { serve(args).await; } - if let Some(args) = args.subcommand_matches("gateway") { - gateway(args).await; + if let Some(args) = args.subcommand_matches("create_tunnel_server") { + create_tunnel_server(args).await; } } diff --git a/src/tunnel.rs b/src/tunnel.rs new file mode 100644 index 0000000..85b1309 --- /dev/null +++ b/src/tunnel.rs @@ -0,0 +1,314 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::time::Duration; + +use futures::channel::oneshot; +use litep2p::PeerId; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::time::timeout; +use tokio_util::sync::CancellationToken; + +use crate::error::TunnelError; +use crate::types::TunnelId; +use crate::CommandNotifier; +use crate::PProxyCommand; + +pub mod proto { + tonic::include_proto!("tunnel.v1"); +} + +pub struct TunnelServer { + peer_id: PeerId, + next_tunnel_id: Arc, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, +} + +pub struct TunnelServerListener { + peer_id: PeerId, + next_tunnel_id: Arc, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + tunnels: HashMap, +} + +pub struct Tunnel { + peer_id: PeerId, + tunnel_id: TunnelId, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + listener_cancel_token: Option, + listener: Option>, +} + +pub struct TunnelListener { + peer_id: PeerId, + tunnel_id: TunnelId, + local_stream: TcpStream, + remote_stream_rx: mpsc::Receiver>, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + cancel_token: CancellationToken, +} + +impl Drop for Tunnel { + fn drop(&mut self) { + if let Some(cancel_token) = self.listener_cancel_token.take() { + cancel_token.cancel(); + } + + if let Some(listener) = self.listener.take() { + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + listener.abort(); + }); + } + + tracing::info!("Tunnel {}-{} dropped", self.peer_id, self.tunnel_id); + } +} + +impl TunnelServer { + pub fn new( + peer_id: PeerId, + next_tunnel_id: Arc, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + ) -> Self { + Self { + peer_id, + next_tunnel_id, + pproxy_command_tx, + } + } + + pub async fn listen(&mut self, address: SocketAddr) -> Result { + let tcp_listener = TcpListener::bind(address).await?; + let local_addr = tcp_listener.local_addr()?; + + let mut listener = TunnelServerListener::new( + self.peer_id, + self.next_tunnel_id.clone(), + self.pproxy_command_tx.clone(), + ); + tokio::spawn(Box::pin(async move { listener.listen(tcp_listener).await })); + + Ok(local_addr) + } +} + +impl TunnelServerListener { + fn new( + peer_id: PeerId, + next_tunnel_id: Arc, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + ) -> Self { + Self { + peer_id, + next_tunnel_id, + pproxy_command_tx, + tunnels: HashMap::new(), + } + } + + fn 
next_tunnel_id(&mut self) -> TunnelId { + TunnelId::from(self.next_tunnel_id.fetch_add(1usize, Ordering::Relaxed)) + } + + async fn listen(&mut self, listener: TcpListener) { + loop { + let Ok((stream, _)) = listener.accept().await else { + continue; + }; + + let tunnel_id = self.next_tunnel_id(); + let mut tunnel = Tunnel::new(self.peer_id, tunnel_id, self.pproxy_command_tx.clone()); + + let (tx, rx) = oneshot::channel(); + let (tunnel_tx, tunnel_rx) = mpsc::channel(1024); + if let Err(e) = self + .pproxy_command_tx + .send(( + PProxyCommand::SendConnectCommand { + peer_id: self.peer_id, + tunnel_id, + tunnel_tx, + }, + tx, + )) + .await + { + tracing::error!("Send connect command channel tx failed: {e:?}"); + continue; + } + + match rx.await { + Err(e) => { + tracing::error!("Send connect command channel rx failed: {e:?}"); + continue; + } + Ok(Err(e)) => { + tracing::error!("Send connect command channel failed: {e:?}"); + continue; + } + Ok(Ok(_resp)) => {} + } + + if let Err(e) = tunnel.listen(stream, tunnel_rx).await { + tracing::error!("Tunnel listen failed: {e:?}"); + continue; + }; + + self.tunnels.insert(tunnel_id, tunnel); + } + } +} + +impl Tunnel { + pub fn new( + peer_id: PeerId, + tunnel_id: TunnelId, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + ) -> Self { + Self { + peer_id, + tunnel_id, + pproxy_command_tx, + listener: None, + listener_cancel_token: None, + } + } + + pub async fn listen( + &mut self, + local_stream: TcpStream, + remote_stream_rx: mpsc::Receiver>, + ) -> Result<(), TunnelError> { + if self.listener.is_some() { + return Err(TunnelError::TunnelInUse); + } + + let mut listener = TunnelListener::new( + self.peer_id, + self.tunnel_id, + local_stream, + remote_stream_rx, + self.pproxy_command_tx.clone(), + ) + .await; + let listener_cancel_token = listener.cancel_token(); + let listener_handler = tokio::spawn(Box::pin(async move { listener.listen().await })); + + self.listener = Some(listener_handler); + self.listener_cancel_token = Some(listener_cancel_token); + + Ok(()) + } +} + +impl TunnelListener { + async fn new( + peer_id: PeerId, + tunnel_id: TunnelId, + local_stream: TcpStream, + remote_stream_rx: mpsc::Receiver>, + pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, + ) -> Self { + Self { + peer_id, + tunnel_id, + local_stream, + remote_stream_rx, + pproxy_command_tx, + cancel_token: CancellationToken::new(), + } + } + + fn cancel_token(&self) -> CancellationToken { + self.cancel_token.clone() + } + + async fn listen(&mut self) { + let (mut local_read, mut local_write) = self.local_stream.split(); + + let listen_local = async { + loop { + if self.cancel_token.is_cancelled() { + break TunnelError::ConnectionClosed; + } + + let mut buf = [0u8; 30000]; + match local_read.read(&mut buf).await { + Err(e) => { + break e.kind().into(); + } + Ok(0) => { + break TunnelError::ConnectionClosed; + } + Ok(n) => { + let (tx, rx) = oneshot::channel(); + let data = buf[..n].to_vec(); + let command = PProxyCommand::SendOutboundPackageCommand { + peer_id: self.peer_id, + tunnel_id: self.tunnel_id, + data, + }; + if let Err(e) = self.pproxy_command_tx.send((command, tx)).await { + tracing::error!("Send tcp package channel tx failed: {e:?}"); + break TunnelError::DataSendFailed; + }; + + match rx.await { + Err(e) => { + tracing::error!("Send tcp package channel rx failed: {e:?}"); + break TunnelError::DataSendFailed; + } + Ok(Err(e)) => { + tracing::error!("Send tcp package channel failed: {e:?}"); + break 
TunnelError::DataSendFailed;
+                            }
+                            Ok(Ok(_resp)) => {}
+                        }
+                    }
+                }
+            }
+        };
+
+        let listen_remote = async {
+            loop {
+                if self.cancel_token.is_cancelled() {
+                    break TunnelError::ConnectionClosed;
+                }
+
+                if let Some(body) = self.remote_stream_rx.recv().await {
+                    if let Err(e) = local_write.write_all(&body).await {
+                        tracing::error!("Write to local stream failed: {e:?}");
+                        break e.kind().into();
+                    }
+                }
+            }
+        };
+
+        tokio::select! {
+            defeat = listen_local => {
+                tracing::info!("Local stream closed: {defeat:?}");
+            },
+            defeat = listen_remote => {
+                tracing::info!("Remote stream closed: {defeat:?}");
+            }
+        }
+    }
+}
+
+pub async fn tcp_connect_with_timeout(
+    addr: SocketAddr,
+    request_timeout_s: u64,
+) -> Result<TcpStream, TunnelError> {
+    let fut = TcpStream::connect(addr);
+    match timeout(Duration::from_secs(request_timeout_s), fut).await {
+        Ok(result) => result.map_err(From::from),
+        Err(_) => Err(TunnelError::ConnectionTimeout),
+    }
+}
diff --git a/src/types.rs b/src/types.rs
new file mode 100644
index 0000000..498302d
--- /dev/null
+++ b/src/types.rs
@@ -0,0 +1,24 @@
+use std::fmt::Display;
+use std::str::FromStr;
+
+#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
+pub struct TunnelId(usize);
+
+impl TunnelId {
+    pub fn from<T: Into<usize>>(value: T) -> Self {
+        TunnelId(value.into())
+    }
+}
+
+impl Display for TunnelId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl FromStr for TunnelId {
+    type Err = std::num::ParseIntError;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        s.parse().map(TunnelId)
+    }
+}
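
A minimal sketch, not part of the patch, showing how the new TunnelId and the tunnel.v1 wire message are expected to round-trip. It assumes an in-crate #[cfg(test)] module (for example at the bottom of src/types.rs), since `mod tunnel` is declared private in lib.rs.

#[cfg(test)]
mod tunnel_types_round_trip {
    use prost::Message;

    use crate::tunnel::proto;
    use crate::types::TunnelId;

    #[test]
    fn tunnel_id_display_and_parse_round_trip() {
        // TunnelId travels as a decimal string in proto::Tunnel.tunnel_id.
        let id = TunnelId::from(42usize);
        let parsed: TunnelId = id.to_string().parse().expect("parse TunnelId back");
        assert_eq!(id, parsed);
    }

    #[test]
    fn tunnel_message_encode_decode_round_trip() {
        // Sketch of a frame as sent over the litep2p request-response protocol.
        let msg = proto::Tunnel {
            tunnel_id: TunnelId::from(7usize).to_string(),
            command: proto::TunnelCommand::OutboundPackage.into(),
            data: Some(b"hello".to_vec()),
        };

        let bytes = msg.encode_to_vec();
        let decoded = proto::Tunnel::decode(bytes.as_slice()).expect("decode Tunnel frame");

        assert_eq!(decoded.command(), proto::TunnelCommand::OutboundPackage);
        assert_eq!(decoded.tunnel_id, "7");
        assert_eq!(decoded.data.as_deref(), Some(&b"hello"[..]));
    }
}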