From f783a10b122f16556a33072eb1ffb67a17a378eb Mon Sep 17 00:00:00 2001 From: magine Date: Fri, 28 Jun 2024 20:59:24 +0800 Subject: [PATCH] Replace http server with tcp tunnel --- Cargo.lock | 1 + Cargo.toml | 1 + buf.yaml | 1 + build.rs | 3 +- proto/command_v1.proto | 11 +- proto/tunnel_v1.proto | 16 +++ src/command.rs | 8 +- src/error.rs | 48 +++++++ src/gateway.rs | 279 ----------------------------------------- src/lib.rs | 233 ++++++++++++++-------------------- src/main.rs | 76 +++++++---- src/tunnel.rs | 188 +++++++++++++++++++++++++++ src/types.rs | 16 +++ 13 files changed, 429 insertions(+), 452 deletions(-) create mode 100644 proto/tunnel_v1.proto delete mode 100644 src/gateway.rs create mode 100644 src/tunnel.rs create mode 100644 src/types.rs diff --git a/Cargo.lock b/Cargo.lock index 15d6b07..9c70c89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -571,6 +571,7 @@ dependencies = [ "prost 0.12.4", "thiserror", "tokio", + "tokio-util", "tonic", "tonic-build", "tonic-web", diff --git a/Cargo.toml b/Cargo.toml index 59175c0..48e4e70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ percent-encoding = "2.3.1" prost = "0.12.4" thiserror = "1.0.60" tokio = { version = "1.37.0", features = ["rt-multi-thread"] } +tokio-util = "0.7.11" tonic = "0.11.0" tonic-web = "0.11.0" tracing = "0.1.40" diff --git a/buf.yaml b/buf.yaml index 5a5af8e..5997ce7 100644 --- a/buf.yaml +++ b/buf.yaml @@ -2,3 +2,4 @@ version: v1 lint: except: - PACKAGE_DIRECTORY_MATCH + - DIRECTORY_SAME_PACKAGE diff --git a/build.rs b/build.rs index 1d1dd34..20cb68b 100644 --- a/build.rs +++ b/build.rs @@ -1,6 +1,7 @@ fn main() -> Result<(), Box> { + let protos = ["proto/command_v1.proto", "proto/tunnel_v1.proto"]; tonic_build::configure() .protoc_arg("--experimental_allow_proto3_optional") - .compile(&["proto/command_v1.proto"], &["proto"])?; + .compile(&protos, &["proto"])?; Ok(()) } diff --git a/proto/command_v1.proto b/proto/command_v1.proto index d5fe531..cb74c3a 100644 --- a/proto/command_v1.proto +++ b/proto/command_v1.proto @@ -10,16 +10,17 @@ message AddPeerResponse { string peer_id = 1; } -message RequestHttpServerRequest { +message CreateTunnelRequest { string peer_id = 1; - bytes data = 2; + optional string address = 2; } -message RequestHttpServerResponse { - bytes data = 1; +message CreateTunnelResponse { + string peer_id = 1; + string address = 2; } service CommandService { rpc AddPeer(AddPeerRequest) returns (AddPeerResponse); - rpc RequestHttpServer(RequestHttpServerRequest) returns (RequestHttpServerResponse); + rpc CreateTunnel(CreateTunnelRequest) returns (CreateTunnelResponse); } diff --git a/proto/tunnel_v1.proto b/proto/tunnel_v1.proto new file mode 100644 index 0000000..411e018 --- /dev/null +++ b/proto/tunnel_v1.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; +package tunnel.v1; + +enum TunnelCommand { + TUNNEL_COMMAND_UNSPECIFIED = 0; + TUNNEL_COMMAND_CONNECT = 1; + TUNNEL_COMMAND_CONNECT_RESP = 2; + TUNNEL_COMMAND_INBOUND_PACKAGE = 3; + TUNNEL_COMMAND_OUTBOUND_PACKAGE = 4; +} + +message Tunnel { + string tunnel_id = 1; + TunnelCommand command = 2; + optional bytes data = 3; +} diff --git a/src/command.rs b/src/command.rs index f300117..a4ea192 100644 --- a/src/command.rs +++ b/src/command.rs @@ -38,14 +38,14 @@ impl proto::command_service_server::CommandService for PProxyCommander { .map_err(|e| tonic::Status::internal(format!("{:?}", e))) } - async fn request_http_server( + async fn create_tunnel( &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status> { + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { trace!("handle request: {:?}", request); self.handle - .request_http_server(request.into_inner()) + .create_tunnel(request.into_inner()) .await .map(Response::new) .map_err(|e| tonic::Status::internal(format!("{:?}", e))) diff --git a/src/error.rs b/src/error.rs index 7bfce44..2bb5deb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,7 @@ +use std::io::ErrorKind as IOErrorKind; + #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum Error { #[error("Multiaddr {0} parse error")] MultiaddrParseError(String), @@ -22,6 +25,33 @@ pub enum Error { Io(#[from] std::io::Error), #[error("Unexpected response type")] UnexpectedResponseType, + #[error("Tunnel error: {0:?}")] + Tunnel(TunnelError), + #[error("Protobuf decode error: {0}")] + ProtobufDecode(#[from] prost::DecodeError), +} + +/// A list specifying general categories of Tunnel error like [std::io::ErrorKind]. +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +#[non_exhaustive] +pub enum TunnelError { + /// Failed to send data to peer. + DataSendFailed = 1, + /// The connection timed out when dialing. + ConnectionTimeout = 2, + /// Got [std::io::ErrorKind::ConnectionRefused] error from local stream. + ConnectionRefused = 3, + /// Got [std::io::ErrorKind::ConnectionAborted] error from local stream. + ConnectionAborted = 4, + /// Got [std::io::ErrorKind::ConnectionReset] error from local stream. + ConnectionReset = 5, + /// Got [std::io::ErrorKind::NotConnected] error from local stream. + NotConnected = 6, + /// The connection is closed by peer. + ConnectionClosed = 7, + /// Unknown [std::io::ErrorKind] error. + Unknown = u8::MAX, } impl From> for Error { @@ -41,3 +71,21 @@ impl From for Error { Error::Litep2pRequestResponseError(err) } } + +impl From for Error { + fn from(err: TunnelError) -> Self { + Error::Tunnel(err) + } +} + +impl From for TunnelError { + fn from(kind: IOErrorKind) -> TunnelError { + match kind { + IOErrorKind::ConnectionRefused => TunnelError::ConnectionRefused, + IOErrorKind::ConnectionAborted => TunnelError::ConnectionAborted, + IOErrorKind::ConnectionReset => TunnelError::ConnectionReset, + IOErrorKind::NotConnected => TunnelError::NotConnected, + _ => TunnelError::Unknown, + } + } +} diff --git a/src/gateway.rs b/src/gateway.rs deleted file mode 100644 index f79b97b..0000000 --- a/src/gateway.rs +++ /dev/null @@ -1,279 +0,0 @@ -use std::convert::Infallible; -use std::net::SocketAddr; -use std::str::FromStr; - -use base64::prelude::*; -use bytes::Bytes; -use http_body::Frame; -use http_body_util::combinators::BoxBody; -use http_body_util::BodyExt; -use http_body_util::Full; -use http_body_util::StreamBody; -use hyper::server::conn::http1; -use hyper::service::service_fn; -use hyper::Request; -use hyper::Response; -use hyper_util::rt::TokioIo; -use multiaddr::Multiaddr; -use tokio::net::TcpListener; - -use crate::command::proto::command_service_client::CommandServiceClient; -use crate::command::proto::AddPeerRequest; -use crate::command::proto::RequestHttpServerRequest; - -pub async fn proxy_gateway( - addr: SocketAddr, - commander_addr: SocketAddr, - peer_multiaddr: Multiaddr, -) -> Result<(), Box> { - let listener = TcpListener::bind(addr).await?; - - let mut client = CommandServiceClient::connect(format!("http://{}", commander_addr)).await?; - let pp_response = client - .add_peer(AddPeerRequest { - address: peer_multiaddr.to_string(), - peer_id: None, - }) - .await? - .into_inner(); - - loop { - let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); - let peer_id = pp_response.peer_id.clone(); - - tokio::task::spawn(async move { - let peer_id = peer_id.clone(); - if let Err(err) = http1::Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .serve_connection( - io, - service_fn(move |req| gateway(req, commander_addr, peer_id.clone())), - ) - .await - { - println!("Failed to serve connection: {:?}", err); - } - }); - } -} - -async fn gateway( - req: Request, - commander_addr: SocketAddr, - peer_id: String, -) -> Result>, hyper::Error> { - let Ok(mut data) = request_header_to_vec(&req) else { - return Ok(error_response( - "Failed to dump headers", - http::StatusCode::INTERNAL_SERVER_ERROR, - )); - }; - - let body = req.into_body().collect().await?; - data.extend(body.to_bytes()); - - let Ok(mut client) = CommandServiceClient::connect(format!("http://{}", commander_addr)).await - else { - return Ok(error_response( - "Failed to connect to pproxy", - http::StatusCode::BAD_GATEWAY, - )); - }; - - let pp_request = RequestHttpServerRequest { peer_id, data }; - let Ok(pp_response) = client - .request_http_server(pp_request) - .await - .map(|r| r.into_inner()) - else { - return Ok(error_response( - "Failed to request pproxy", - http::StatusCode::BAD_GATEWAY, - )); - }; - - let Ok(Some((resp, trailer))) = parse_response_header_easy(&pp_response.data) else { - return Ok(error_response( - "Failed to parse response headers from pproxy", - http::StatusCode::INTERNAL_SERVER_ERROR, - )); - }; - - let (parts, _) = resp.into_parts(); - - // Handle chunked encoding - if let Some(true) = parts.headers.get(http::header::TRANSFER_ENCODING).map(|v| { - let Ok(v) = v.to_str() else { return false }; - v.to_lowercase().contains("chunked") - }) { - let chunks = split_chunked_body(trailer); - return Ok(Response::from_parts(parts, stream_body(chunks))); - } - - let t = trailer.to_vec(); - Ok(Response::from_parts(parts, full(t))) -} - -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} - -fn stream_body(chunks: Vec, Infallible>>) -> BoxBody { - StreamBody::new(futures_util::stream::iter(chunks)) - .map_err(|never| match never {}) - .boxed() -} - -fn error_response( - msg: &'static str, - status: http::StatusCode, -) -> Response> { - let mut resp = Response::new(full(msg)); - *resp.status_mut() = status; - resp -} - -fn io_other_error(msg: &'static str) -> std::io::Error { - let e: Box = msg.into(); - std::io::Error::new(std::io::ErrorKind::Other, e) -} - -fn write_request_header( - r: &http::Request, - mut io: impl std::io::Write, -) -> std::io::Result { - let mut len = 0; - let verb = r.method().as_str(); - let path = r - .uri() - .path_and_query() - .ok_or_else(|| io_other_error("Invalid URI"))?; - - let need_to_insert_host = - r.uri().host().is_some() && !r.headers().contains_key(http::header::HOST); - - macro_rules! w { - ($x:expr) => { - io.write_all($x)?; - len += $x.len(); - }; - } - w!(verb.as_bytes()); - w!(b" "); - w!(path.as_str().as_bytes()); - w!(b" HTTP/1.1\r\n"); - - if need_to_insert_host { - w!(b"Host: "); - let host = r.uri().host().unwrap(); - w!(host.as_bytes()); - if let Some(p) = r.uri().port() { - w!(b":"); - w!(p.as_str().as_bytes()); - } - w!(b"\r\n"); - } - - let already_present = r.headers().get(http::header::AUTHORIZATION).is_some(); - let at_sign = r - .uri() - .authority() - .map_or(false, |x| x.as_str().contains('@')); - if !already_present && at_sign { - w!(b"Authorization: Basic "); - let a = r.uri().authority().unwrap().as_str(); - let a = &a[0..(a.find('@').unwrap())]; - let a = a - .as_bytes() - .split(|v| *v == b':') - .map(|v| percent_encoding::percent_decode(v).collect::>()) - .collect::>>() - .join(&b':'); - let a = BASE64_STANDARD.encode(a); - w!(a.as_bytes()); - w!(b"\r\n"); - } - - for (hn, hv) in r.headers() { - w!(hn.as_str().as_bytes()); - w!(b": "); - w!(hv.as_bytes()); - w!(b"\r\n"); - } - - w!(b"\r\n"); - - Ok(len) -} - -fn request_header_to_vec(r: &http::Request) -> std::io::Result> { - let v = Vec::with_capacity(120); - let mut c = std::io::Cursor::new(v); - write_request_header(r, &mut c)?; - Ok(c.into_inner()) -} - -#[allow(clippy::type_complexity)] -fn parse_response_header<'a>( - buf: &'a [u8], - headers_buffer: &mut [httparse::Header<'a>], -) -> Result, &'a [u8])>, Box> { - let mut x = httparse::Response::new(headers_buffer); - let n = match x.parse(buf)? { - httparse::Status::Partial => return Ok(None), - httparse::Status::Complete(size) => size, - }; - let trailer = &buf[n..]; - let mut r = Response::new(()); - *r.status_mut() = http::StatusCode::from_u16(x.code.unwrap())?; - // x.reason goes to nowhere - *r.version_mut() = http::Version::HTTP_11; // FIXME? - for h in x.headers { - let n = http::HeaderName::from_str(h.name)?; - let v = http::HeaderValue::from_bytes(h.value)?; - r.headers_mut().append(n, v); - } - Ok(Some((r, trailer))) -} - -#[allow(clippy::type_complexity)] -fn parse_response_header_easy( - buf: &[u8], -) -> Result, &[u8])>, Box> { - let mut h = [httparse::EMPTY_HEADER; 50]; - parse_response_header(buf, h.as_mut()) -} - -fn split_chunked_body(input: &[u8]) -> Vec, Infallible>> { - use std::io::BufRead; - use std::io::BufReader; - use std::io::Cursor; - use std::io::Read; - - let mut reader = BufReader::new(Cursor::new(input)); - let mut chunks = Vec::new(); - - loop { - let mut chunk_size = String::new(); - let read = reader.read_line(&mut chunk_size).unwrap(); - if read == 0 { - break; - } - - let chunk_size = usize::from_str_radix(chunk_size.trim_end(), 16).unwrap(); - - let mut chunk = vec![0; chunk_size]; - reader.read_exact(&mut chunk).unwrap(); - - let mut crlf = [0; 2]; - reader.read_exact(&mut crlf).unwrap(); - - chunks.push(Ok(Frame::data(Bytes::from(chunk)))); - } - - chunks -} diff --git a/src/lib.rs b/src/lib.rs index a242bf1..4d32298 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::time::Duration; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use futures::channel::oneshot; use litep2p::crypto::ed25519::SecretKey; @@ -10,23 +11,25 @@ use litep2p::types::RequestId; use litep2p::PeerId; use multiaddr::Multiaddr; use multiaddr::Protocol; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::net::TcpStream; +use prost::Message; use tokio::sync::mpsc; -use tokio::time::timeout; use tracing::warn; use crate::command::proto::AddPeerRequest; use crate::command::proto::AddPeerResponse; -use crate::command::proto::RequestHttpServerRequest; -use crate::command::proto::RequestHttpServerResponse; +use crate::command::proto::CreateTunnelRequest; +use crate::command::proto::CreateTunnelResponse; use crate::server::*; +use crate::tunnel::proto; +use crate::tunnel::tcp_connect_with_timeout; +use crate::tunnel::Tunnel; +use crate::types::*; pub mod command; pub mod error; -pub mod gateway; mod server; +mod tunnel; +pub mod types; /// pproxy version pub const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -34,6 +37,9 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Default channel size. const DEFAULT_CHANNEL_SIZE: usize = 4096; +/// Timeout for proxied TCP connections +pub const TCP_SERVER_TIMEOUT: u64 = 30; + /// Public result type error type used by the crate. pub use crate::error::Error; pub type Result = std::result::Result; @@ -44,12 +50,12 @@ type CommandNotifier = oneshot::Sender; #[derive(Debug)] pub enum PProxyCommand { AddPeer(AddPeerRequest), - RequestHttpServer(RequestHttpServerRequest), + CreateTunnel(CreateTunnelRequest), } pub enum PProxyCommandResponse { AddPeer(AddPeerResponse), - RequestHttpServer(RequestHttpServerResponse), + CreateTunnel(CreateTunnelResponse), } pub struct PProxy { @@ -57,6 +63,9 @@ pub struct PProxy { tunnel_notifier: HashMap, p2p_server: P2pServer, proxy_addr: Option, + next_tunnel_id: AtomicUsize, + inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>, + outbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>, } pub struct PProxyHandle { @@ -94,11 +103,18 @@ impl PProxy { tunnel_notifier: HashMap::new(), p2p_server: P2pServer::new(secret_key, server_addr), proxy_addr, + next_tunnel_id: Default::default(), + inbound_tunnels: HashMap::new(), + outbound_tunnels: HashMap::new(), }, PProxyHandle { command_tx }, ) } + fn next_tunnel_id(&mut self) -> TunnelId { + TunnelId::from(self.next_tunnel_id.fetch_add(1usize, Ordering::Relaxed)) + } + pub async fn run(mut self) { loop { tokio::select! { @@ -109,7 +125,6 @@ impl PProxy { None => return, Some(ref event) => if let Err(error) = self.handle_p2p_server_event(event).await { warn!("failed to handle event {:?}: {:?}", event, error); - } }, @@ -126,101 +141,58 @@ impl PProxy { async fn handle_p2p_server_event(&mut self, event: &P2pServerEvent) -> Result<()> { match event { P2pServerEvent::TunnelEvent(RequestResponseEvent::RequestReceived { + peer, request_id, request, .. }) => { - let Some(proxy_addr) = self.proxy_addr else { - return Err(Error::ProtocolNotSupport); - }; - - let mut headers = [httparse::EMPTY_HEADER; 1024]; - let mut req = httparse::Request::new(&mut headers); - if req.parse(request)?.is_partial() { - return Err(Error::IncompleteHttpRequest); - } - - let mut stream = tcp_connect_with_timeout(&proxy_addr, 60).await?; - stream.write_all(request).await?; - - let mut response = Vec::new(); - let mut full_length = FullLength::NotParsed; - - loop { - let mut buf = [0u8; 30000]; - - match timeout(std::time::Duration::from_secs(60), stream.read(&mut buf)).await { - Ok(Ok(0)) => { - warn!("empty response from http server"); - break; - } - Ok(Ok(n)) => { - response.extend_from_slice(&buf[..n]); - } - x => { - warn!("http stream read failed {x:?}"); - break; - } + let msg = proto::Tunnel::decode(request.as_slice())?; + match msg.command() { + proto::TunnelCommand::Connect => { + let Some(proxy_addr) = self.proxy_addr else { + return Err(Error::ProtocolNotSupport); + }; + + let Ok(tunnel_id) = msg.tunnel_id.parse::() else { + return Err(Error::ProtocolNotSupport); + }; + let tunnel_id = TunnelId::from(tunnel_id); + + let stream = tcp_connect_with_timeout(proxy_addr, 60).await?; + let mut tunnel = Tunnel::new(*peer, tunnel_id); + tunnel.listen(stream).await; + + self.inbound_tunnels.insert((*peer, tunnel_id), tunnel); + + let response = proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::ConnectResp.into(), + data: None, + }; + + self.p2p_server + .tunnel_handle + .send_response(*request_id, response.encode_to_vec()); } - if full_length.not_parsed() { - let mut headers = [httparse::EMPTY_HEADER; 1024]; - let mut resp_checker = httparse::Response::new(&mut headers); - let res = resp_checker.parse(&response)?; - if res.is_complete() { - let content_length = resp_checker.headers.iter().find_map(|h| { - if h.name.to_lowercase() != "content-length" { - return None; - } - let Ok(value) = std::str::from_utf8(h.value) else { - return None; - }; - value.parse::().ok() - }); - - let transfor_encoding = resp_checker.headers.iter().find_map(|h| { - if h.name.to_lowercase() != "transfer-encoding" { - return None; - } - let Ok(value) = std::str::from_utf8(h.value) else { - return None; - }; - Some(value) - }); - - match (content_length, transfor_encoding) { - (Some(content_length), _) => { - let header_length = res.unwrap(); - full_length = FullLength::Parsed(header_length + content_length) - } - (None, Some(value)) if value.to_lowercase().contains("chunked") => { - full_length = FullLength::Chunked; - } - _ => { - full_length = FullLength::NotSet; - } - } - } + proto::TunnelCommand::InboundPackage => { + // Have to do this to close the response waiter in remote. + self.p2p_server + .tunnel_handle + .send_response(*request_id, vec![]); } - if let FullLength::Parsed(full_length) = full_length { - if response.len() >= full_length { - break; - } + proto::TunnelCommand::OutboundPackage => { + // Have to do this to close the response waiter in remote. + self.p2p_server + .tunnel_handle + .send_response(*request_id, vec![]); } - if full_length.chunked() && response.ends_with(b"0\r\n\r\n") { - break; + _ => { + return Err(Error::ProtocolNotSupport); } } - - if response.is_empty() { - response = b"HTTP/1.1 500 Internal Server Error\r\n\r\n".to_vec(); - } - - self.p2p_server - .tunnel_handle - .send_response(*request_id, response); } P2pServerEvent::TunnelEvent(RequestResponseEvent::RequestFailed { @@ -237,19 +209,19 @@ impl PProxy { } P2pServerEvent::TunnelEvent(RequestResponseEvent::ResponseReceived { + peer, request_id, - response, .. }) => { - let Some(tx) = self.tunnel_notifier.remove(request_id) else { - todo!(); - }; - tx.send(Ok(PProxyCommandResponse::RequestHttpServer( - RequestHttpServerResponse { - data: response.clone(), - }, - ))) - .map_err(|_| Error::EssentialTaskClosed)?; + if let Some(tx) = self.tunnel_notifier.remove(request_id) { + tx.send(Ok(PProxyCommandResponse::CreateTunnel( + CreateTunnelResponse { + peer_id: peer.to_string(), + address: String::new(), + }, + ))) + .map_err(|_| Error::EssentialTaskClosed)?; + } } _ => {} } @@ -260,8 +232,8 @@ impl PProxy { async fn handle_command(&mut self, command: &PProxyCommand, tx: CommandNotifier) -> Result<()> { match command { PProxyCommand::AddPeer(request) => self.on_add_peer(request.clone(), tx).await, - PProxyCommand::RequestHttpServer(request) => { - self.on_request_http_server(request.clone(), tx).await + PProxyCommand::CreateTunnel(request) => { + self.on_create_tunnel(request.clone(), tx).await } } } @@ -291,9 +263,9 @@ impl PProxy { .map_err(|_| Error::EssentialTaskClosed) } - async fn on_request_http_server( + async fn on_create_tunnel( &mut self, - request: RequestHttpServerRequest, + request: CreateTunnelRequest, rx: CommandNotifier, ) -> Result<()> { let peer_id = request @@ -301,10 +273,18 @@ impl PProxy { .parse() .map_err(|_| Error::PeerIdParseError(request.peer_id.clone()))?; + let tunnel_id = self.next_tunnel_id(); + + let msg = proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::Connect.into(), + data: None, + }; + let request_id = self .p2p_server - .tunnel_handle - .send_request(peer_id, request.data.clone(), DialOptions::Dial) + .control_handle + .send_request(peer_id, msg.encode_to_vec(), DialOptions::Dial) .await?; self.tunnel_notifier.insert(request_id, rx); @@ -329,26 +309,20 @@ impl PProxyHandle { } } - pub async fn request_http_server( + pub async fn create_tunnel( &self, - request: RequestHttpServerRequest, - ) -> Result { - let mut headers = [httparse::EMPTY_HEADER; 1024]; - let mut req = httparse::Request::new(&mut headers); - if req.parse(&request.data)?.is_partial() { - return Err(Error::IncompleteHttpRequest); - } - + request: CreateTunnelRequest, + ) -> Result { let (tx, rx) = oneshot::channel(); self.command_tx - .send((PProxyCommand::RequestHttpServer(request), tx)) + .send((PProxyCommand::CreateTunnel(request), tx)) .await?; let response = rx.await??; match response { - PProxyCommandResponse::RequestHttpServer(response) => Ok(response), + PProxyCommandResponse::CreateTunnel(response) => Ok(response), _ => Err(Error::UnexpectedResponseType), } } @@ -366,20 +340,3 @@ fn extract_peer_id_from_multiaddr(multiaddr: &Multiaddr) -> Result { PeerId::from_multihash(multihash) .map_err(|_| Error::FailedToExtractPeerIdFromMultiaddr(multiaddr.to_string())) } - -pub async fn tcp_connect_with_timeout( - addr: &SocketAddr, - request_timeout_s: u64, -) -> Result { - match timeout( - Duration::from_secs(request_timeout_s), - TcpStream::connect(addr), - ) - .await - { - Ok(result) => result.map_err(Error::Io), - Err(_) => Err(Error::Io(std::io::Error::from( - std::io::ErrorKind::TimedOut, - ))), - } -} diff --git a/src/main.rs b/src/main.rs index 2cb5704..e018ff7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,16 @@ +use std::net::SocketAddr; + use clap::Arg; use clap::ArgAction; use clap::ArgMatches; use clap::Command; +use dephy_pproxy::command::proto::command_service_client::CommandServiceClient; +use dephy_pproxy::command::proto::AddPeerRequest; +use dephy_pproxy::command::proto::CreateTunnelRequest; use dephy_pproxy::command::PProxyCommander; -use dephy_pproxy::gateway::proxy_gateway; use dephy_pproxy::PProxy; use litep2p::crypto::ed25519::SecretKey; +use multiaddr::Multiaddr; use tonic::transport::Server; fn parse_args() -> Command { @@ -46,7 +51,7 @@ fn parse_args() -> Command { .help("Will reverse proxy this address if set"), ); - let gateway = Command::new("gateway") + let create_tunnel = Command::new("create_tunnel") .about("Set up a local server that allows users proxy data to remote peer") .arg( Arg::new("COMMANDER_SERVER_ADDR") @@ -57,23 +62,22 @@ fn parse_args() -> Command { .help("Commander server address"), ) .arg( - Arg::new("PROXY_GATEWAY_ADDR") - .long("proxy-gateway-addr") + Arg::new("TUNNEL_ADDR") + .long("tunnel-addr") .num_args(1) - .default_value("127.0.0.1:10000") .action(ArgAction::Set) - .help("Set up a local HTTP server that allows users use peerid header to proxy requests to remote peer"), + .help("Local server address, if not set a random port will be used"), ) .arg( Arg::new("PEER_MULTIADDR") - .long("peer-multiaddr") - .num_args(1) - .action(ArgAction::Set) - .required(true) - .help("The multiaddr of the remote peer"), + .long("peer-multiaddr") + .num_args(1) + .action(ArgAction::Set) + .required(true) + .help("The multiaddr of the remote peer"), ); - app = app.subcommand(serve).subcommand(gateway); + app = app.subcommand(serve).subcommand(create_tunnel); app } @@ -122,26 +126,48 @@ async fn serve(args: &ArgMatches) { ); } -async fn gateway(args: &ArgMatches) { +async fn create_tunnel(args: &ArgMatches) { let commander_server_addr = args .get_one::("COMMANDER_SERVER_ADDR") .unwrap() - .parse() + .parse::() .expect("Invalid command server address"); - let proxy_gateway_addr = args - .get_one::("PROXY_GATEWAY_ADDR") - .unwrap() - .parse() - .expect("Invalid proxy gateway address"); + let tunnel_addr = args.get_one::("TUNNEL_ADDR").map(|addr| { + addr.parse::() + .expect("Invalid tunnel address") + .to_string() + }); let peer_multiaddr = args .get_one::("PEER_MULTIADDR") .unwrap() - .parse() + .parse::() .expect("Missing peer multiaddr"); - println!("proxy_gateway_addr: {}", proxy_gateway_addr); - proxy_gateway(proxy_gateway_addr, commander_server_addr, peer_multiaddr) + + let mut client = CommandServiceClient::connect(format!("http://{}", commander_server_addr)) .await - .expect("Gateway server failed") + .expect("Connect to commander server failed"); + + let pp_response = client + .add_peer(AddPeerRequest { + address: peer_multiaddr.to_string(), + peer_id: None, + }) + .await + .expect("Add peer failed") + .into_inner(); + + let peer_id = pp_response.peer_id; + + let pp_response = client + .create_tunnel(CreateTunnelRequest { + peer_id, + address: tunnel_addr, + }) + .await + .expect("Create tunnel failed") + .into_inner(); + + println!("tunnel_addr: {}", pp_response.address); } #[tokio::main] @@ -156,7 +182,7 @@ async fn main() { serve(args).await; } - if let Some(args) = args.subcommand_matches("gateway") { - gateway(args).await; + if let Some(args) = args.subcommand_matches("create_tunnel") { + create_tunnel(args).await; } } diff --git a/src/tunnel.rs b/src/tunnel.rs new file mode 100644 index 0000000..ceac1f4 --- /dev/null +++ b/src/tunnel.rs @@ -0,0 +1,188 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use bytes::Bytes; +use litep2p::PeerId; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::time::timeout; +use tokio_util::sync::CancellationToken; + +use crate::error::TunnelError; +use crate::types::TunnelId; + +pub mod proto { + tonic::include_proto!("tunnel.v1"); +} + +/// Abstract Tcp Tunnel +pub struct Tunnel { + peer_id: PeerId, + tunnel_id: TunnelId, + remote_stream_tx: Option>, + listener_cancel_token: Option, + listener: Option>, +} + +/// Listener of Tcp Tunnel, contains a mpsc channel +pub struct TunnelListener { + peer_id: PeerId, + tunnel_id: TunnelId, + local_stream: TcpStream, + remote_stream_tx: mpsc::Sender, + remote_stream_rx: mpsc::Receiver, + cancel_token: CancellationToken, +} + +impl Drop for Tunnel { + fn drop(&mut self) { + if let Some(cancel_token) = self.listener_cancel_token.take() { + cancel_token.cancel(); + } + + if let Some(listener) = self.listener.take() { + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + listener.abort(); + }); + } + + tracing::info!("Tunnel {}-{} dropped", self.peer_id, self.tunnel_id); + } +} + +impl Tunnel { + /// Create a new tunnel with a given tunnel Id + pub fn new(peer_id: PeerId, tunnel_id: TunnelId) -> Self { + Self { + peer_id, + tunnel_id, + remote_stream_tx: None, + listener: None, + listener_cancel_token: None, + } + } + + /// Send bytes to tunnel via channel + pub async fn send(&self, bytes: Bytes) { + if let Some(ref tx) = self.remote_stream_tx { + let _ = tx.send(bytes).await; + } else { + tracing::error!( + "Tunnel {}-{} remote stream tx is none", + self.peer_id, + self.tunnel_id + ); + } + } + + /// Start listen a local stream, this function will spawn a thread which + /// listening the inbound messages + pub async fn listen(&mut self, local_stream: TcpStream) { + if self.listener.is_some() { + return; + } + let mut listener = TunnelListener::new(self.peer_id, self.tunnel_id, local_stream).await; + let listener_cancel_token = listener.cancel_token(); + let remote_stream_tx = listener.remote_stream_tx.clone(); + let listener_handler = tokio::spawn(Box::pin(async move { listener.listen().await })); + + self.remote_stream_tx = Some(remote_stream_tx); + self.listener = Some(listener_handler); + self.listener_cancel_token = Some(listener_cancel_token); + } +} + +impl TunnelListener { + /// Create a new listener instance with TcpStream, tunnel id, and did of a target peer + async fn new(peer_id: PeerId, tunnel_id: TunnelId, local_stream: TcpStream) -> Self { + let (remote_stream_tx, remote_stream_rx) = mpsc::channel(1024); + Self { + peer_id, + tunnel_id, + local_stream, + remote_stream_tx, + remote_stream_rx, + cancel_token: CancellationToken::new(), + } + } + + fn cancel_token(&self) -> CancellationToken { + self.cancel_token.clone() + } + + async fn listen(&mut self) { + let (mut local_read, mut local_write) = self.local_stream.split(); + + let listen_local = async { + loop { + if self.cancel_token.is_cancelled() { + break TunnelError::ConnectionClosed; + } + + let mut buf = [0u8; 30000]; + match local_read.read(&mut buf).await { + Err(e) => { + break e.kind().into(); + } + Ok(0) => { + break TunnelError::ConnectionClosed; + } + Ok(n) => { + let data = buf[..n].to_vec(); + let msg = proto::Tunnel { + tunnel_id: self.tunnel_id.to_string(), + command: proto::TunnelCommand::OutboundPackage.into(), + data: Some(data), + }; + } + } + } + }; + + let listen_remote = async { + loop { + if self.cancel_token.is_cancelled() { + break TunnelError::ConnectionClosed; + } + + if let Some(body) = self.remote_stream_rx.recv().await { + if let Err(e) = local_write.write_all(&body).await { + tracing::error!("Write to local stream failed: {e:?}"); + break e.kind().into(); + } + } + } + }; + + tokio::select! { + defeat = listen_local => { + tracing::info!("Local stream closed: {defeat:?}"); + }, + defeat = listen_remote => { + tracing::info!("Remote stream closed: {defeat:?}"); + } + } + } +} + +/// This function handle a tcp request with timeout +pub async fn tcp_connect_with_timeout( + addr: SocketAddr, + request_timeout_s: u64, +) -> Result { + let fut = tcp_connect(addr); + match timeout(Duration::from_secs(request_timeout_s), fut).await { + Ok(result) => result, + Err(_) => Err(TunnelError::ConnectionTimeout), + } +} + +async fn tcp_connect(addr: SocketAddr) -> Result { + match TcpStream::connect(addr).await { + Ok(o) => Ok(o), + Err(e) => Err(e.kind().into()), + } +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..9c8908c --- /dev/null +++ b/src/types.rs @@ -0,0 +1,16 @@ +use std::fmt::Display; + +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct TunnelId(usize); + +impl TunnelId { + pub fn from>(value: T) -> Self { + TunnelId(value.into()) + } +} + +impl Display for TunnelId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +}