Skip to content

Commit

Permalink
feat: close tunnel if remote broken
Browse files Browse the repository at this point in the history
  • Loading branch information
Ma233 committed Aug 20, 2024
1 parent 7c1d52a commit cd77eeb
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 83 deletions.
1 change: 1 addition & 0 deletions proto/tunnel_v1.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ enum TunnelCommand {
TUNNEL_COMMAND_CONNECT = 1;
TUNNEL_COMMAND_CONNECT_RESP = 2;
TUNNEL_COMMAND_PACKAGE = 3;
TUNNEL_COMMAND_BREAK = 4;
}

message Tunnel {
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub enum Error {
Tunnel(TunnelError),
#[error("Protobuf decode error: {0}")]
ProtobufDecode(#[from] prost::DecodeError),
#[error("Access denied, peer: {0}")]
AccessDenied(String),
}

/// A list specifying general categories of Tunnel error like [std::io::ErrorKind].
Expand Down
200 changes: 125 additions & 75 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::command::proto::CreateTunnelServerRequest;
use crate::command::proto::CreateTunnelServerResponse;
use crate::command::proto::ExpirePeerAccessRequest;
use crate::command::proto::ExpirePeerAccessResponse;
use crate::error::TunnelError;
use crate::p2p::PProxyNetworkBehaviour;
use crate::p2p::PProxyNetworkBehaviourEvent;
use crate::tunnel::proto;
Expand Down Expand Up @@ -57,6 +58,7 @@ pub type Result<T> = std::result::Result<T, error::Error>;

type CommandNotification = Result<PProxyCommandResponse>;
type CommandNotifier = oneshot::Sender<CommandNotification>;
type ChannelPackage = std::result::Result<Vec<u8>, TunnelError>;

#[derive(Debug)]
pub enum PProxyCommand {
Expand All @@ -70,7 +72,7 @@ pub enum PProxyCommand {
SendConnectCommand {
peer_id: PeerId,
tunnel_id: TunnelId,
tunnel_tx: mpsc::Sender<Vec<u8>>,
tunnel_tx: mpsc::Sender<ChannelPackage>,
},
SendOutboundPackageCommand {
peer_id: PeerId,
Expand All @@ -97,7 +99,7 @@ pub struct PProxy {
proxy_addr: Option<SocketAddr>,
outbound_ready_notifiers: HashMap<request_response::OutboundRequestId, CommandNotifier>,
inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>,
tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender<Vec<u8>>>,
tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender<ChannelPackage>>,
access_client: Option<AccessClient>,
}

Expand Down Expand Up @@ -179,89 +181,110 @@ impl PProxy {
Ok(())
}

async fn is_tunnel_valid(&mut self, peer_id: &PeerId) -> bool {
let Some(ref mut ac) = self.access_client else {
return true;
};
ac.is_valid(peer_id).await
}

async fn handle_tunnel_request(
&mut self,
peer_id: PeerId,
request: proto::Tunnel,
) -> Result<Option<proto::Tunnel>> {
let tunnel_id = request
.tunnel_id
.parse()
.map_err(|_| Error::TunnelIdParseError(request.tunnel_id.clone()))?;

match request.command() {
proto::TunnelCommand::Connect => {
tracing::info!("received connect command from peer: {:?}", peer_id);
if !self.is_tunnel_valid(&peer_id).await {
return Err(Error::AccessDenied(peer_id.to_string()));
}

let Some(proxy_addr) = self.proxy_addr else {
return Err(Error::ProtocolNotSupport("No proxy_addr".to_string()));
};

let data = match self.dial_tunnel(proxy_addr, peer_id, tunnel_id).await {
Ok(_) => None,
Err(e) => {
tracing::warn!("failed to dial tunnel: {:?}", e);
Some(e.to_string().into_bytes())
}
};

Ok(Some(proto::Tunnel {
tunnel_id: tunnel_id.to_string(),
command: proto::TunnelCommand::ConnectResp.into(),
data,
}))
}

proto::TunnelCommand::Package => {
// Note that only inbound package need access check.
if self.inbound_tunnels.contains_key(&(peer_id, tunnel_id))
&& !self.is_tunnel_valid(&peer_id).await
{
return Err(Error::AccessDenied(peer_id.to_string()));
}

let Some(tx) = self.tunnel_txs.get(&(peer_id, tunnel_id)) else {
return Err(Error::ProtocolNotSupport(
"No tunnel for Package".to_string(),
));
};

tx.send(Ok(request.data.unwrap_or_default())).await?;

// Have to do this to close the response waiter in remote.
Ok(None)
}

_ => Err(Error::ProtocolNotSupport(
"Wrong tunnel request command".to_string(),
)),
}
}

async fn handle_p2p_server_event(
&mut self,
event: SwarmEvent<PProxyNetworkBehaviourEvent>,
) -> Result<()> {
tracing::debug!("received SwarmEvent: {:?}", event);

#[allow(clippy::single_match)]
match event {
SwarmEvent::NewListenAddr { mut address, .. } => {
address.push(Protocol::P2p(*self.swarm.local_peer_id()));
println!("Local node is listening on {address}");
}

SwarmEvent::Behaviour(PProxyNetworkBehaviourEvent::RequestResponse(
request_response::Event::Message { peer, message },
)) => match message {
request_response::Message::Request {
request, channel, ..
} => {
if let Some(ac) = &mut self.access_client {
if !ac.is_valid(&peer).await {
// TODO: Manage tunnel lifecycle
return Err(Error::Tunnel(error::TunnelError::ConnectionClosed));
}
}

match request.command() {
proto::TunnelCommand::Connect => {
tracing::info!("received connect command from peer: {:?}", peer);
let Some(proxy_addr) = self.proxy_addr else {
return Err(Error::ProtocolNotSupport("No proxy_addr".to_string()));
};

let tunnel_id = request
.tunnel_id
.parse()
.map_err(|_| Error::TunnelIdParseError(request.tunnel_id))?;

let data = match self.dial_tunnel(proxy_addr, peer, tunnel_id).await {
Ok(_) => None,
Err(e) => {
tracing::warn!("failed to dial tunnel: {:?}", e);
Some(e.to_string().into_bytes())
}
};

let response = proto::Tunnel {
tunnel_id: tunnel_id.to_string(),
command: proto::TunnelCommand::ConnectResp.into(),
data,
};

self.swarm
.behaviour_mut()
.request_response
.send_response(channel, Some(response))
.map_err(|_| Error::EssentialTaskClosed)?;
}

proto::TunnelCommand::Package => {
let tunnel_id = request
.tunnel_id
.parse()
.map_err(|_| Error::TunnelIdParseError(request.tunnel_id))?;

let Some(tx) = self.tunnel_txs.get(&(peer, tunnel_id)) else {
return Err(Error::ProtocolNotSupport(
"No tunnel for Package".to_string(),
));
};

tx.send(request.data.unwrap_or_default()).await?;

// Have to do this to close the response waiter in remote.
self.swarm
.behaviour_mut()
.request_response
.send_response(channel, None)
.map_err(|_| Error::EssentialTaskClosed)?;
}
let tunnel_id = request.tunnel_id.clone();
let resp = match self.handle_tunnel_request(peer, request).await {
Ok(resp) => resp,
Err(e) => Some(proto::Tunnel {
tunnel_id,
command: proto::TunnelCommand::Break.into(),
data: Some(e.to_string().as_bytes().to_vec()),
}),
};

_ => {
return Err(Error::ProtocolNotSupport(
"Wrong tunnel request command".to_string(),
));
}
}
self.swarm
.behaviour_mut()
.request_response
.send_response(channel, resp)
.map_err(|_| Error::EssentialTaskClosed)?;
}

request_response::Message::Response {
request_id,
response,
Expand All @@ -278,7 +301,7 @@ impl PProxy {
.remove(&request_id)
.ok_or_else(|| {
Error::TunnelNotWaiting(format!(
"peer {}, tunnel {}",
"peer {}, tunnel {} ready but not waiting",
peer, response.tunnel_id
))
})?;
Expand All @@ -293,6 +316,24 @@ impl PProxy {
.map_err(|_| Error::EssentialTaskClosed)?;
}

proto::TunnelCommand::Break => {
let tunnel_id = response.tunnel_id.parse().map_err(|_| {
Error::TunnelIdParseError(response.tunnel_id.clone())
})?;

// Terminat connecting tunnel
if let Some(tx) = self.outbound_ready_notifiers.remove(&request_id) {
tx.send(Err(Error::Tunnel(TunnelError::ConnectionAborted)))
.map_err(|_| Error::EssentialTaskClosed)?;
return Ok(());
}

// Terminat connected tunnel
if let Some(tx) = self.tunnel_txs.remove(&(peer, tunnel_id)) {
tx.send(Err(TunnelError::ConnectionAborted)).await?
};
}

_ => {
return Err(Error::ProtocolNotSupport(
"Wrong tunnel response command".to_string(),
Expand All @@ -302,9 +343,18 @@ impl PProxy {
}
},

SwarmEvent::NewListenAddr { mut address, .. } => {
address.push(Protocol::P2p(*self.swarm.local_peer_id()));
println!("Local node is listening on {address}");
SwarmEvent::Behaviour(PProxyNetworkBehaviourEvent::RequestResponse(
request_response::Event::OutboundFailure {
request_id, error, ..
},
)) => {
// Tell tunnel dial failed
if let Some(tx) = self.outbound_ready_notifiers.remove(&request_id) {
tx.send(Err(Error::TunnelDialFailed(error.to_string())))
.map_err(|_| Error::EssentialTaskClosed)?
}

// TODO: Should tell tunnel sent failed
}

_ => {}
Expand Down Expand Up @@ -370,7 +420,7 @@ impl PProxy {
&mut self,
peer_id: PeerId,
tunnel_id: TunnelId,
tunnel_tx: mpsc::Sender<Vec<u8>>,
tunnel_tx: mpsc::Sender<ChannelPackage>,
tx: CommandNotifier,
) -> Result<()> {
self.tunnel_txs.insert((peer_id, tunnel_id), tunnel_tx);
Expand Down
25 changes: 17 additions & 8 deletions src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct TunnelListener {
peer_id: PeerId,
tunnel_id: TunnelId,
local_stream: TcpStream,
remote_stream_rx: mpsc::Receiver<Vec<u8>>,
remote_stream_rx: mpsc::Receiver<Result<Vec<u8>, TunnelError>>,
pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>,
cancel_token: CancellationToken,
}
Expand Down Expand Up @@ -225,7 +225,7 @@ impl Tunnel {
pub async fn listen(
&mut self,
local_stream: TcpStream,
remote_stream_rx: mpsc::Receiver<Vec<u8>>,
remote_stream_rx: mpsc::Receiver<Result<Vec<u8>, TunnelError>>,
) -> Result<(), TunnelError> {
if self.listener.is_some() {
return Err(TunnelError::AlreadyListened);
Expand Down Expand Up @@ -254,7 +254,7 @@ impl TunnelListener {
peer_id: PeerId,
tunnel_id: TunnelId,
local_stream: TcpStream,
remote_stream_rx: mpsc::Receiver<Vec<u8>>,
remote_stream_rx: mpsc::Receiver<Result<Vec<u8>, TunnelError>>,
pproxy_command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>,
) -> Self {
Self {
Expand Down Expand Up @@ -324,11 +324,20 @@ impl TunnelListener {
break TunnelError::ConnectionClosed;
}

if let Some(body) = self.remote_stream_rx.recv().await {
tracing::debug!("Received {} bytes from local stream", body.len());
if let Err(e) = local_write.write_all(&body).await {
tracing::error!("Write to local stream failed: {e:?}");
break e.kind().into();
let Some(data) = self.remote_stream_rx.recv().await else {
break TunnelError::ConnectionClosed;
};

match data {
Err(e) => {
break e;
}
Ok(body) => {
tracing::debug!("Received {} bytes from local stream", body.len());
if let Err(e) = local_write.write_all(&body).await {
tracing::error!("Write to local stream failed: {e:?}");
break e.kind().into();
}
}
}
}
Expand Down

0 comments on commit cd77eeb

Please sign in to comment.