diff --git a/proto/tunnel_v1.proto b/proto/tunnel_v1.proto index cd73d48..1680642 100644 --- a/proto/tunnel_v1.proto +++ b/proto/tunnel_v1.proto @@ -6,7 +6,8 @@ enum TunnelCommand { TUNNEL_COMMAND_CONNECT = 1; TUNNEL_COMMAND_CONNECT_RESP = 2; TUNNEL_COMMAND_PACKAGE = 3; - TUNNEL_COMMAND_BREAK = 4; + TUNNEL_COMMAND_PACKAGE_RESP = 4; + TUNNEL_COMMAND_BREAK = 5; } message Tunnel { diff --git a/src/error.rs b/src/error.rs index 22414a6..a6c6421 100644 --- a/src/error.rs +++ b/src/error.rs @@ -29,6 +29,8 @@ pub enum Error { TunnelNotWaiting(String), #[error("Tunnel dial failed: {0}")] TunnelDialFailed(String), + #[error("")] + TunnelContextNotFound(String), #[error("Tunnel error: {0:?}")] Tunnel(TunnelError), #[error("Protobuf decode error: {0}")] diff --git a/src/lib.rs b/src/lib.rs index e9772da..4947436 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,11 @@ pub enum PProxyCommandResponse { ExpirePeerAccess {}, } +pub struct TunnelContext { + tx: mpsc::Sender, + outbound_sent_notifier: Option, +} + pub struct PProxy { command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>, command_rx: mpsc::Receiver<(PProxyCommand, CommandNotifier)>, @@ -99,7 +104,7 @@ pub struct PProxy { proxy_addr: Option, outbound_ready_notifiers: HashMap, inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>, - tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender>, + tunnel_ctx: HashMap<(PeerId, TunnelId), TunnelContext>, access_client: Option, } @@ -129,7 +134,7 @@ impl PProxy { proxy_addr, outbound_ready_notifiers: HashMap::new(), inbound_tunnels: HashMap::new(), - tunnel_txs: HashMap::new(), + tunnel_ctx: HashMap::new(), access_client, }, PProxyHandle { @@ -175,7 +180,10 @@ impl PProxy { tunnel.listen(stream, tunnel_rx).await?; self.inbound_tunnels.insert((peer_id, tunnel_id), tunnel); - self.tunnel_txs.insert((peer_id, tunnel_id), tunnel_tx); + self.tunnel_ctx.insert((peer_id, tunnel_id), TunnelContext { + tx: tunnel_tx, + outbound_sent_notifier: None, + }); Ok(()) } @@ -191,7 +199,7 @@ impl PProxy { &mut self, peer_id: PeerId, request: proto::Tunnel, - ) -> Result> { + ) -> Result { let tunnel_id = request .tunnel_id .parse() @@ -216,11 +224,11 @@ impl PProxy { } }; - Ok(Some(proto::Tunnel { + Ok(proto::Tunnel { tunnel_id: tunnel_id.to_string(), command: proto::TunnelCommand::ConnectResp.into(), data, - })) + }) } proto::TunnelCommand::Package => { @@ -231,16 +239,20 @@ impl PProxy { return Err(Error::AccessDenied(peer_id.to_string())); } - let Some(tx) = self.tunnel_txs.get(&(peer_id, tunnel_id)) else { + let Some(ctx) = self.tunnel_ctx.get(&(peer_id, tunnel_id)) else { return Err(Error::ProtocolNotSupport( "No tunnel for Package".to_string(), )); }; - tx.send(Ok(request.data.unwrap_or_default())).await?; + ctx.tx.send(Ok(request.data.unwrap_or_default())).await?; // Have to do this to close the response waiter in remote. - Ok(None) + Ok(proto::Tunnel { + tunnel_id: tunnel_id.to_string(), + command: proto::TunnelCommand::PackageResp.into(), + data: None, + }) } _ => Err(Error::ProtocolNotSupport( @@ -263,7 +275,7 @@ impl PProxy { SwarmEvent::ConnectionClosed { peer_id, .. } => { self.inbound_tunnels.retain(|(p, _), _| p != &peer_id); - self.tunnel_txs.retain(|(p, _), _| p != &peer_id); + self.tunnel_ctx.retain(|(p, _), _| p != &peer_id); } SwarmEvent::Behaviour(PProxyNetworkBehaviourEvent::RequestResponse( @@ -278,13 +290,13 @@ impl PProxy { Err(e) => { if let Ok(tunnel_id) = tunnel_id.parse() { self.inbound_tunnels.remove(&(peer, tunnel_id)); - self.tunnel_txs.remove(&(peer, tunnel_id)); + self.tunnel_ctx.remove(&(peer, tunnel_id)); } - Some(proto::Tunnel { + proto::Tunnel { tunnel_id, command: proto::TunnelCommand::Break.into(), data: Some(e.to_string().as_bytes().to_vec()), - }) + } } }; @@ -299,11 +311,6 @@ impl PProxy { request_id, response, } => { - // This is response of TunnelCommand::Package - let Some(response) = response else { - return Ok(()); - }; - match response.command() { proto::TunnelCommand::ConnectResp => { let tx = self @@ -326,6 +333,28 @@ impl PProxy { .map_err(|_| Error::EssentialTaskClosed)?; } + proto::TunnelCommand::PackageResp => { + let tunnel_id = response.tunnel_id.parse().map_err(|_| { + Error::TunnelIdParseError(response.tunnel_id.clone()) + })?; + + let Some(ctx) = self.tunnel_ctx.get_mut(&(peer, tunnel_id)) else { + return Err(Error::ProtocolNotSupport( + "No ctx for Package".to_string(), + )); + }; + + let Some(notifier) = ctx.outbound_sent_notifier.take() else { + return Err(Error::ProtocolNotSupport( + "No notifier for Package".to_string(), + )); + }; + + notifier + .send(Ok(PProxyCommandResponse::SendOutboundPackageCommand {})) + .map_err(|_| Error::EssentialTaskClosed)?; + } + proto::TunnelCommand::Break => { let tunnel_id = response.tunnel_id.parse().map_err(|_| { Error::TunnelIdParseError(response.tunnel_id.clone()) @@ -338,8 +367,8 @@ impl PProxy { } // Terminat connected tunnel - if let Some(tx) = self.tunnel_txs.remove(&(peer, tunnel_id)) { - tx.send(Err(TunnelError::ConnectionAborted)).await? + if let Some(ctx) = self.tunnel_ctx.remove(&(peer, tunnel_id)) { + ctx.tx.send(Err(TunnelError::ConnectionAborted)).await? }; } @@ -432,7 +461,10 @@ impl PProxy { tunnel_tx: mpsc::Sender, tx: CommandNotifier, ) -> Result<()> { - self.tunnel_txs.insert((peer_id, tunnel_id), tunnel_tx); + self.tunnel_ctx.insert((peer_id, tunnel_id), TunnelContext { + tx: tunnel_tx, + outbound_sent_notifier: None, + }); let request = proto::Tunnel { tunnel_id: tunnel_id.to_string(), @@ -465,13 +497,23 @@ impl PProxy { data: Some(data), }; + let Some(ctx) = self.tunnel_ctx.get_mut(&(peer_id, tunnel_id)) else { + let err_msg = "No ctx for outbound package"; + + tx.send(Err(Error::TunnelContextNotFound(err_msg.to_string()))) + .map_err(|_| Error::EssentialTaskClosed)?; + + return Err(Error::TunnelContextNotFound(err_msg.to_string())); + }; + + ctx.outbound_sent_notifier = Some(tx); + self.swarm .behaviour_mut() .request_response .send_request(&peer_id, request); - tx.send(Ok(PProxyCommandResponse::SendOutboundPackageCommand {})) - .map_err(|_| Error::EssentialTaskClosed) + Ok(()) } async fn on_expire_peer_access(&mut self, peer_id: PeerId, tx: CommandNotifier) -> Result<()> { diff --git a/src/p2p/codec.rs b/src/p2p/codec.rs index 5fcc2ca..b52e658 100644 --- a/src/p2p/codec.rs +++ b/src/p2p/codec.rs @@ -19,7 +19,7 @@ pub struct Codec; impl libp2p::request_response::Codec for Codec { type Protocol = StreamProtocol; type Request = proto::Tunnel; - type Response = Option; + type Response = proto::Tunnel; async fn read_request( &mut self, @@ -48,13 +48,7 @@ impl libp2p::request_response::Codec for Codec { io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?; - if vec.is_empty() { - return Ok(None); - } - - proto::Tunnel::decode(vec.as_slice()) - .map(Some) - .map_err(decode_into_io_error) + proto::Tunnel::decode(vec.as_slice()).map_err(decode_into_io_error) } async fn write_request( @@ -80,14 +74,8 @@ impl libp2p::request_response::Codec for Codec { where T: AsyncWrite + Unpin + Send, { - let mut data = vec![]; - - if let Some(resp) = resp { - data.extend_from_slice(resp.encode_to_vec().as_slice()); - }; - + let data = resp.encode_to_vec(); io.write_all(data.as_ref()).await?; - Ok(()) } } @@ -137,7 +125,7 @@ mod tests { let (mut a, mut b) = Endpoint::pair(124, 124); Codec - .write_response(&protocol, &mut a, Some(expected_response.clone())) + .write_response(&protocol, &mut a, expected_response.clone()) .await .expect("Should write response"); a.close().await.unwrap(); @@ -148,6 +136,6 @@ mod tests { .expect("Should read response"); b.close().await.unwrap(); - assert_eq!(actual_response, Some(expected_response)); + assert_eq!(actual_response, expected_response); } }