Skip to content

Commit

Permalink
feat: wiat for response before sending next package
Browse files Browse the repository at this point in the history
  • Loading branch information
Ma233 committed Aug 29, 2024
1 parent 57a9746 commit ca254ac
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 41 deletions.
3 changes: 2 additions & 1 deletion proto/tunnel_v1.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ enum TunnelCommand {
TUNNEL_COMMAND_CONNECT = 1;
TUNNEL_COMMAND_CONNECT_RESP = 2;
TUNNEL_COMMAND_PACKAGE = 3;
TUNNEL_COMMAND_BREAK = 4;
TUNNEL_COMMAND_PACKAGE_RESP = 4;
TUNNEL_COMMAND_BREAK = 5;
}

message Tunnel {
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub enum Error {
TunnelNotWaiting(String),
#[error("Tunnel dial failed: {0}")]
TunnelDialFailed(String),
#[error("")]
TunnelContextNotFound(String),
#[error("Tunnel error: {0:?}")]
Tunnel(TunnelError),
#[error("Protobuf decode error: {0}")]
Expand Down
88 changes: 65 additions & 23 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,19 @@ pub enum PProxyCommandResponse {
ExpirePeerAccess {},
}

pub struct TunnelContext {
tx: mpsc::Sender<ChannelPackage>,
outbound_sent_notifier: Option<CommandNotifier>,
}

pub struct PProxy {
command_tx: mpsc::Sender<(PProxyCommand, CommandNotifier)>,
command_rx: mpsc::Receiver<(PProxyCommand, CommandNotifier)>,
swarm: Swarm<PProxyNetworkBehaviour>,
proxy_addr: Option<SocketAddr>,
outbound_ready_notifiers: HashMap<request_response::OutboundRequestId, CommandNotifier>,
inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>,
tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender<ChannelPackage>>,
tunnel_ctx: HashMap<(PeerId, TunnelId), TunnelContext>,
access_client: Option<AccessClient>,
}

Expand Down Expand Up @@ -129,7 +134,7 @@ impl PProxy {
proxy_addr,
outbound_ready_notifiers: HashMap::new(),
inbound_tunnels: HashMap::new(),
tunnel_txs: HashMap::new(),
tunnel_ctx: HashMap::new(),
access_client,
},
PProxyHandle {
Expand Down Expand Up @@ -175,7 +180,10 @@ impl PProxy {
tunnel.listen(stream, tunnel_rx).await?;

self.inbound_tunnels.insert((peer_id, tunnel_id), tunnel);
self.tunnel_txs.insert((peer_id, tunnel_id), tunnel_tx);
self.tunnel_ctx.insert((peer_id, tunnel_id), TunnelContext {
tx: tunnel_tx,
outbound_sent_notifier: None,
});

Ok(())
}
Expand All @@ -191,7 +199,7 @@ impl PProxy {
&mut self,
peer_id: PeerId,
request: proto::Tunnel,
) -> Result<Option<proto::Tunnel>> {
) -> Result<proto::Tunnel> {
let tunnel_id = request
.tunnel_id
.parse()
Expand All @@ -216,11 +224,11 @@ impl PProxy {
}
};

Ok(Some(proto::Tunnel {
Ok(proto::Tunnel {
tunnel_id: tunnel_id.to_string(),
command: proto::TunnelCommand::ConnectResp.into(),
data,
}))
})
}

proto::TunnelCommand::Package => {
Expand All @@ -231,16 +239,20 @@ impl PProxy {
return Err(Error::AccessDenied(peer_id.to_string()));
}

let Some(tx) = self.tunnel_txs.get(&(peer_id, tunnel_id)) else {
let Some(ctx) = self.tunnel_ctx.get(&(peer_id, tunnel_id)) else {
return Err(Error::ProtocolNotSupport(
"No tunnel for Package".to_string(),
));
};

tx.send(Ok(request.data.unwrap_or_default())).await?;
ctx.tx.send(Ok(request.data.unwrap_or_default())).await?;

// Have to do this to close the response waiter in remote.
Ok(None)
Ok(proto::Tunnel {
tunnel_id: tunnel_id.to_string(),
command: proto::TunnelCommand::PackageResp.into(),
data: None,
})
}

_ => Err(Error::ProtocolNotSupport(
Expand All @@ -263,7 +275,7 @@ impl PProxy {

SwarmEvent::ConnectionClosed { peer_id, .. } => {
self.inbound_tunnels.retain(|(p, _), _| p != &peer_id);
self.tunnel_txs.retain(|(p, _), _| p != &peer_id);
self.tunnel_ctx.retain(|(p, _), _| p != &peer_id);
}

SwarmEvent::Behaviour(PProxyNetworkBehaviourEvent::RequestResponse(
Expand All @@ -278,13 +290,13 @@ impl PProxy {
Err(e) => {
if let Ok(tunnel_id) = tunnel_id.parse() {
self.inbound_tunnels.remove(&(peer, tunnel_id));
self.tunnel_txs.remove(&(peer, tunnel_id));
self.tunnel_ctx.remove(&(peer, tunnel_id));
}
Some(proto::Tunnel {
proto::Tunnel {
tunnel_id,
command: proto::TunnelCommand::Break.into(),
data: Some(e.to_string().as_bytes().to_vec()),
})
}
}
};

Expand All @@ -299,11 +311,6 @@ impl PProxy {
request_id,
response,
} => {
// This is response of TunnelCommand::Package
let Some(response) = response else {
return Ok(());
};

match response.command() {
proto::TunnelCommand::ConnectResp => {
let tx = self
Expand All @@ -326,6 +333,28 @@ impl PProxy {
.map_err(|_| Error::EssentialTaskClosed)?;
}

proto::TunnelCommand::PackageResp => {
let tunnel_id = response.tunnel_id.parse().map_err(|_| {
Error::TunnelIdParseError(response.tunnel_id.clone())
})?;

let Some(ctx) = self.tunnel_ctx.get_mut(&(peer, tunnel_id)) else {
return Err(Error::ProtocolNotSupport(
"No ctx for Package".to_string(),
));
};

let Some(notifier) = ctx.outbound_sent_notifier.take() else {
return Err(Error::ProtocolNotSupport(
"No notifier for Package".to_string(),
));
};

notifier
.send(Ok(PProxyCommandResponse::SendOutboundPackageCommand {}))
.map_err(|_| Error::EssentialTaskClosed)?;
}

proto::TunnelCommand::Break => {
let tunnel_id = response.tunnel_id.parse().map_err(|_| {
Error::TunnelIdParseError(response.tunnel_id.clone())
Expand All @@ -338,8 +367,8 @@ impl PProxy {
}

// Terminat connected tunnel
if let Some(tx) = self.tunnel_txs.remove(&(peer, tunnel_id)) {
tx.send(Err(TunnelError::ConnectionAborted)).await?
if let Some(ctx) = self.tunnel_ctx.remove(&(peer, tunnel_id)) {
ctx.tx.send(Err(TunnelError::ConnectionAborted)).await?
};
}

Expand Down Expand Up @@ -432,7 +461,10 @@ impl PProxy {
tunnel_tx: mpsc::Sender<ChannelPackage>,
tx: CommandNotifier,
) -> Result<()> {
self.tunnel_txs.insert((peer_id, tunnel_id), tunnel_tx);
self.tunnel_ctx.insert((peer_id, tunnel_id), TunnelContext {
tx: tunnel_tx,
outbound_sent_notifier: None,
});

let request = proto::Tunnel {
tunnel_id: tunnel_id.to_string(),
Expand Down Expand Up @@ -465,13 +497,23 @@ impl PProxy {
data: Some(data),
};

let Some(ctx) = self.tunnel_ctx.get_mut(&(peer_id, tunnel_id)) else {
let err_msg = "No ctx for outbound package";

tx.send(Err(Error::TunnelContextNotFound(err_msg.to_string())))
.map_err(|_| Error::EssentialTaskClosed)?;

return Err(Error::TunnelContextNotFound(err_msg.to_string()));
};

ctx.outbound_sent_notifier = Some(tx);

self.swarm
.behaviour_mut()
.request_response
.send_request(&peer_id, request);

tx.send(Ok(PProxyCommandResponse::SendOutboundPackageCommand {}))
.map_err(|_| Error::EssentialTaskClosed)
Ok(())
}

async fn on_expire_peer_access(&mut self, peer_id: PeerId, tx: CommandNotifier) -> Result<()> {
Expand Down
22 changes: 5 additions & 17 deletions src/p2p/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Codec;
impl libp2p::request_response::Codec for Codec {
type Protocol = StreamProtocol;
type Request = proto::Tunnel;
type Response = Option<proto::Tunnel>;
type Response = proto::Tunnel;

async fn read_request<T>(
&mut self,
Expand Down Expand Up @@ -48,13 +48,7 @@ impl libp2p::request_response::Codec for Codec {

io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?;

if vec.is_empty() {
return Ok(None);
}

proto::Tunnel::decode(vec.as_slice())
.map(Some)
.map_err(decode_into_io_error)
proto::Tunnel::decode(vec.as_slice()).map_err(decode_into_io_error)
}

async fn write_request<T>(
Expand All @@ -80,14 +74,8 @@ impl libp2p::request_response::Codec for Codec {
where
T: AsyncWrite + Unpin + Send,
{
let mut data = vec![];

if let Some(resp) = resp {
data.extend_from_slice(resp.encode_to_vec().as_slice());
};

let data = resp.encode_to_vec();
io.write_all(data.as_ref()).await?;

Ok(())
}
}
Expand Down Expand Up @@ -137,7 +125,7 @@ mod tests {

let (mut a, mut b) = Endpoint::pair(124, 124);
Codec
.write_response(&protocol, &mut a, Some(expected_response.clone()))
.write_response(&protocol, &mut a, expected_response.clone())
.await
.expect("Should write response");
a.close().await.unwrap();
Expand All @@ -148,6 +136,6 @@ mod tests {
.expect("Should read response");
b.close().await.unwrap();

assert_eq!(actual_response, Some(expected_response));
assert_eq!(actual_response, expected_response);
}
}

0 comments on commit ca254ac

Please sign in to comment.