diff --git a/proto/tunnel_v1.proto b/proto/tunnel_v1.proto index a8779dc..cd73d48 100644 --- a/proto/tunnel_v1.proto +++ b/proto/tunnel_v1.proto @@ -6,6 +6,7 @@ enum TunnelCommand { TUNNEL_COMMAND_CONNECT = 1; TUNNEL_COMMAND_CONNECT_RESP = 2; TUNNEL_COMMAND_PACKAGE = 3; + TUNNEL_COMMAND_BREAK = 4; } message Tunnel { diff --git a/src/error.rs b/src/error.rs index cd99e6b..22414a6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -33,6 +33,8 @@ pub enum Error { Tunnel(TunnelError), #[error("Protobuf decode error: {0}")] ProtobufDecode(#[from] prost::DecodeError), + #[error("Access denied, peer: {0}")] + AccessDenied(String), } /// A list specifying general categories of Tunnel error like [std::io::ErrorKind]. diff --git a/src/lib.rs b/src/lib.rs index 8aab8a5..e06f045 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,7 @@ use crate::command::proto::CreateTunnelServerRequest; use crate::command::proto::CreateTunnelServerResponse; use crate::command::proto::ExpirePeerAccessRequest; use crate::command::proto::ExpirePeerAccessResponse; +use crate::error::TunnelError; use crate::p2p::PProxyNetworkBehaviour; use crate::p2p::PProxyNetworkBehaviourEvent; use crate::tunnel::proto; @@ -57,6 +58,7 @@ pub type Result = std::result::Result; type CommandNotification = Result; type CommandNotifier = oneshot::Sender; +type ChannelPackage = std::result::Result, TunnelError>; #[derive(Debug)] pub enum PProxyCommand { @@ -70,7 +72,7 @@ pub enum PProxyCommand { SendConnectCommand { peer_id: PeerId, tunnel_id: TunnelId, - tunnel_tx: mpsc::Sender>, + tunnel_tx: mpsc::Sender, }, SendOutboundPackageCommand { peer_id: PeerId, @@ -97,7 +99,7 @@ pub struct PProxy { proxy_addr: Option, outbound_ready_notifiers: HashMap, inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>, - tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender>>, + tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender>, access_client: Option, } @@ -179,89 +181,110 @@ impl PProxy { Ok(()) } + async fn is_tunnel_valid(&mut self, peer_id: &PeerId) -> bool { + let Some(ref mut ac) = self.access_client else { + return true; + }; + ac.is_valid(peer_id).await + } + + async fn handle_tunnel_request( + &mut self, + peer_id: PeerId, + request: proto::Tunnel, + ) -> Result> { + let tunnel_id = request + .tunnel_id + .parse() + .map_err(|_| Error::TunnelIdParseError(request.tunnel_id.clone()))?; + + match request.command() { + proto::TunnelCommand::Connect => { + tracing::info!("received connect command from peer: {:?}", peer_id); + if !self.is_tunnel_valid(&peer_id).await { + return Err(Error::AccessDenied(peer_id.to_string())); + } + + let Some(proxy_addr) = self.proxy_addr else { + return Err(Error::ProtocolNotSupport("No proxy_addr".to_string())); + }; + + let data = match self.dial_tunnel(proxy_addr, peer_id, tunnel_id).await { + Ok(_) => None, + Err(e) => { + tracing::warn!("failed to dial tunnel: {:?}", e); + Some(e.to_string().into_bytes()) + } + }; + + Ok(Some(proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::ConnectResp.into(), + data, + })) + } + + proto::TunnelCommand::Package => { + // Note that only inbound package need access check. + if self.inbound_tunnels.contains_key(&(peer_id, tunnel_id)) + && !self.is_tunnel_valid(&peer_id).await + { + return Err(Error::AccessDenied(peer_id.to_string())); + } + + let Some(tx) = self.tunnel_txs.get(&(peer_id, tunnel_id)) else { + return Err(Error::ProtocolNotSupport( + "No tunnel for Package".to_string(), + )); + }; + + tx.send(Ok(request.data.unwrap_or_default())).await?; + + // Have to do this to close the response waiter in remote. + Ok(None) + } + + _ => Err(Error::ProtocolNotSupport( + "Wrong tunnel request command".to_string(), + )), + } + } + async fn handle_p2p_server_event( &mut self, event: SwarmEvent, ) -> Result<()> { tracing::debug!("received SwarmEvent: {:?}", event); - #[allow(clippy::single_match)] match event { + SwarmEvent::NewListenAddr { mut address, .. } => { + address.push(Protocol::P2p(*self.swarm.local_peer_id())); + println!("Local node is listening on {address}"); + } + SwarmEvent::Behaviour(PProxyNetworkBehaviourEvent::RequestResponse( request_response::Event::Message { peer, message }, )) => match message { request_response::Message::Request { request, channel, .. } => { - if let Some(ac) = &mut self.access_client { - if !ac.is_valid(&peer).await { - // TODO: Manage tunnel lifecycle - return Err(Error::Tunnel(error::TunnelError::ConnectionClosed)); - } - } - - match request.command() { - proto::TunnelCommand::Connect => { - tracing::info!("received connect command from peer: {:?}", peer); - let Some(proxy_addr) = self.proxy_addr else { - return Err(Error::ProtocolNotSupport("No proxy_addr".to_string())); - }; - - let tunnel_id = request - .tunnel_id - .parse() - .map_err(|_| Error::TunnelIdParseError(request.tunnel_id))?; - - let data = match self.dial_tunnel(proxy_addr, peer, tunnel_id).await { - Ok(_) => None, - Err(e) => { - tracing::warn!("failed to dial tunnel: {:?}", e); - Some(e.to_string().into_bytes()) - } - }; - - let response = proto::Tunnel { - tunnel_id: tunnel_id.to_string(), - command: proto::TunnelCommand::ConnectResp.into(), - data, - }; - - self.swarm - .behaviour_mut() - .request_response - .send_response(channel, Some(response)) - .map_err(|_| Error::EssentialTaskClosed)?; - } - - proto::TunnelCommand::Package => { - let tunnel_id = request - .tunnel_id - .parse() - .map_err(|_| Error::TunnelIdParseError(request.tunnel_id))?; - - let Some(tx) = self.tunnel_txs.get(&(peer, tunnel_id)) else { - return Err(Error::ProtocolNotSupport( - "No tunnel for Package".to_string(), - )); - }; - - tx.send(request.data.unwrap_or_default()).await?; - - // Have to do this to close the response waiter in remote. - self.swarm - .behaviour_mut() - .request_response - .send_response(channel, None) - .map_err(|_| Error::EssentialTaskClosed)?; - } + let tunnel_id = request.tunnel_id.clone(); + let resp = match self.handle_tunnel_request(peer, request).await { + Ok(resp) => resp, + Err(e) => Some(proto::Tunnel { + tunnel_id, + command: proto::TunnelCommand::Break.into(), + data: Some(e.to_string().as_bytes().to_vec()), + }), + }; - _ => { - return Err(Error::ProtocolNotSupport( - "Wrong tunnel request command".to_string(), - )); - } - } + self.swarm + .behaviour_mut() + .request_response + .send_response(channel, resp) + .map_err(|_| Error::EssentialTaskClosed)?; } + request_response::Message::Response { request_id, response, @@ -278,7 +301,7 @@ impl PProxy { .remove(&request_id) .ok_or_else(|| { Error::TunnelNotWaiting(format!( - "peer {}, tunnel {}", + "peer {}, tunnel {} ready but not waiting", peer, response.tunnel_id )) })?; @@ -293,6 +316,24 @@ impl PProxy { .map_err(|_| Error::EssentialTaskClosed)?; } + proto::TunnelCommand::Break => { + let tunnel_id = response.tunnel_id.parse().map_err(|_| { + Error::TunnelIdParseError(response.tunnel_id.clone()) + })?; + + // Terminat connecting tunnel + if let Some(tx) = self.outbound_ready_notifiers.remove(&request_id) { + tx.send(Err(Error::Tunnel(TunnelError::ConnectionAborted))) + .map_err(|_| Error::EssentialTaskClosed)?; + return Ok(()); + } + + // Terminat connected tunnel + if let Some(tx) = self.tunnel_txs.remove(&(peer, tunnel_id)) { + tx.send(Err(TunnelError::ConnectionAborted)).await? + }; + } + _ => { return Err(Error::ProtocolNotSupport( "Wrong tunnel response command".to_string(), @@ -302,9 +343,18 @@ impl PProxy { } }, - SwarmEvent::NewListenAddr { mut address, .. } => { - address.push(Protocol::P2p(*self.swarm.local_peer_id())); - println!("Local node is listening on {address}"); + SwarmEvent::Behaviour(PProxyNetworkBehaviourEvent::RequestResponse( + request_response::Event::OutboundFailure { + request_id, error, .. + }, + )) => { + // Tell tunnel dial failed + if let Some(tx) = self.outbound_ready_notifiers.remove(&request_id) { + tx.send(Err(Error::TunnelDialFailed(error.to_string()))) + .map_err(|_| Error::EssentialTaskClosed)? + } + + // TODO: Should tell tunnel sent failed } _ => {} @@ -370,7 +420,7 @@ impl PProxy { &mut self, peer_id: PeerId, tunnel_id: TunnelId, - tunnel_tx: mpsc::Sender>, + tunnel_tx: mpsc::Sender, tx: CommandNotifier, ) -> Result<()> { self.tunnel_txs.insert((peer_id, tunnel_id), tunnel_tx); diff --git a/src/tunnel.rs b/src/tunnel.rs index 33d484b..184ef50 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -52,7 +52,7 @@ pub struct TunnelListener { peer_id: PeerId, tunnel_id: TunnelId, local_stream: TcpStream, - remote_stream_rx: mpsc::Receiver>, + remote_stream_rx: mpsc::Receiver, TunnelError>>, pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, cancel_token: CancellationToken, } @@ -225,7 +225,7 @@ impl Tunnel { pub async fn listen( &mut self, local_stream: TcpStream, - remote_stream_rx: mpsc::Receiver>, + remote_stream_rx: mpsc::Receiver, TunnelError>>, ) -> Result<(), TunnelError> { if self.listener.is_some() { return Err(TunnelError::AlreadyListened); @@ -254,7 +254,7 @@ impl TunnelListener { peer_id: PeerId, tunnel_id: TunnelId, local_stream: TcpStream, - remote_stream_rx: mpsc::Receiver>, + remote_stream_rx: mpsc::Receiver, TunnelError>>, pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, ) -> Self { Self { @@ -324,11 +324,20 @@ impl TunnelListener { break TunnelError::ConnectionClosed; } - if let Some(body) = self.remote_stream_rx.recv().await { - tracing::debug!("Received {} bytes from local stream", body.len()); - if let Err(e) = local_write.write_all(&body).await { - tracing::error!("Write to local stream failed: {e:?}"); - break e.kind().into(); + let Some(data) = self.remote_stream_rx.recv().await else { + break TunnelError::ConnectionClosed; + }; + + match data { + Err(e) => { + break e; + } + Ok(body) => { + tracing::debug!("Received {} bytes from local stream", body.len()); + if let Err(e) = local_write.write_all(&body).await { + tracing::error!("Write to local stream failed: {e:?}"); + break e.kind().into(); + } } } }