From c54be33f6924e32f78718b6ea765defa230762b5 Mon Sep 17 00:00:00 2001 From: magine Date: Mon, 19 Aug 2024 11:57:52 +0800 Subject: [PATCH] feat: convert auth client to access client --- Cargo.lock | 22 ++++++------ proto/command_v1.proto | 7 ++++ src/access.rs | 80 ++++++++++++++++++++++++++++++++++++++++++ src/auth.rs | 41 ---------------------- src/command.rs | 13 +++++++ src/lib.rs | 58 +++++++++++++++++++++++++----- src/main.rs | 14 ++++---- 7 files changed, 167 insertions(+), 68 deletions(-) create mode 100644 src/access.rs delete mode 100644 src/auth.rs diff --git a/Cargo.lock b/Cargo.lock index a6fb052..7757ce3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1482,9 +1482,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.154" +version = "0.2.157" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" +checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86" [[package]] name = "libp2p" @@ -1936,13 +1936,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.11" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi 0.3.9", "libc", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3167,26 +3168,25 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "pin-project-lite", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", diff --git a/proto/command_v1.proto b/proto/command_v1.proto index 72c5634..421df2a 100644 --- a/proto/command_v1.proto +++ b/proto/command_v1.proto @@ -28,8 +28,15 @@ message ConnectRelayResponse { string relaied_multiaddr = 1; } +message ExpirePeerAccessRequest { + string peer_id = 1; +} + +message ExpirePeerAccessResponse {} + service CommandService { rpc AddPeer(AddPeerRequest) returns (AddPeerResponse); rpc CreateTunnelServer(CreateTunnelServerRequest) returns (CreateTunnelServerResponse); rpc ConnectRelay(ConnectRelayRequest) returns (ConnectRelayResponse); + rpc ExpirePeerAccess(ExpirePeerAccessRequest) returns (ExpirePeerAccessResponse); } diff --git a/src/access.rs b/src/access.rs new file mode 100644 index 0000000..0cc4f4e --- /dev/null +++ b/src/access.rs @@ -0,0 +1,80 @@ +use std::collections::HashMap; +use std::time::Duration; +use std::time::Instant; + +use libp2p::PeerId; +use serde::Deserialize; + +const ACCESS_TTL: Duration = Duration::from_secs(10 * 60); + +pub struct AccessClient { + local_id: libp2p::PeerId, + endpoint: reqwest::Url, + client: reqwest::Client, + cache: HashMap, +} + +#[derive(Deserialize)] +struct AccessClientResponse { + data: bool, +} + +impl AccessClient { + pub fn new(local_id: libp2p::PeerId, endpoint: reqwest::Url) -> AccessClient { + AccessClient { + local_id, + endpoint, + client: reqwest::Client::new(), + cache: HashMap::default(), + } + } + + async fn request_endpoint(&mut self, peer: &PeerId) -> Result { + let url = self.endpoint.join("access-control").unwrap(); + let params = [ + ("device", self.local_id.to_string()), + ("token", peer.to_string()), + ]; + + let response = self + .client + .get(url) + .query(¶ms) + .send() + .await? + .json::() + .await?; + + Ok(response.data) + } + + pub async fn is_valid(&mut self, peer: &PeerId) -> bool { + self.cache + .retain(|_, (_, created)| created.elapsed() < ACCESS_TTL); + + if let Some((valid, _)) = self.cache.get(peer) { + return *valid; + } + + match self.request_endpoint(peer).await { + Err(e) => { + tracing::error!( + "error while requesting access endpoint (will return false) for {}: {}", + peer, + e + ); + false + } + + Ok(valid) => { + self.cache.insert(*peer, (valid, Instant::now())); + valid + } + } + } + + pub fn expire(&mut self, token: &PeerId) { + tracing::debug!("expire token: {} from cache", token); + self.cache.remove(token); + } +} diff --git a/src/auth.rs b/src/auth.rs deleted file mode 100644 index 8d0157b..0000000 --- a/src/auth.rs +++ /dev/null @@ -1,41 +0,0 @@ -use serde::Deserialize; - -pub struct AuthClient { - local_id: libp2p::PeerId, - endpoint: reqwest::Url, - client: reqwest::Client, -} - -#[derive(Deserialize)] -struct AuthClientResponse { - data: bool, -} - -impl AuthClient { - pub fn new(local_id: libp2p::PeerId, endpoint: reqwest::Url) -> AuthClient { - AuthClient { - local_id, - endpoint, - client: reqwest::Client::new(), - } - } - - pub async fn is_valid(&self, token: &str) -> Result { - let url = self.endpoint.join("access-control").unwrap(); - let params = [ - ("device", self.local_id.to_string()), - ("token", token.to_string()), - ]; - - let response = self - .client - .get(url) - .query(¶ms) - .send() - .await? - .json::() - .await?; - - Ok(response.data) - } -} diff --git a/src/command.rs b/src/command.rs index e44b76a..0207e0e 100644 --- a/src/command.rs +++ b/src/command.rs @@ -63,4 +63,17 @@ impl proto::command_service_server::CommandService for PProxyCommander { .map(Response::new) .map_err(|e| tonic::Status::internal(format!("{:?}", e))) } + + async fn expire_peer_access( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + tracing::debug!("handle request: {:?}", request); + + self.handle + .expire_peer_access(request.into_inner()) + .await + .map(Response::new) + .map_err(|e| tonic::Status::internal(format!("{:?}", e))) + } } diff --git a/src/lib.rs b/src/lib.rs index 4b3c5b5..8aab8a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,13 +15,15 @@ use libp2p::PeerId; use libp2p::Swarm; use tokio::sync::mpsc; -use crate::auth::AuthClient; +use crate::access::AccessClient; use crate::command::proto::AddPeerRequest; use crate::command::proto::AddPeerResponse; use crate::command::proto::ConnectRelayRequest; use crate::command::proto::ConnectRelayResponse; use crate::command::proto::CreateTunnelServerRequest; use crate::command::proto::CreateTunnelServerResponse; +use crate::command::proto::ExpirePeerAccessRequest; +use crate::command::proto::ExpirePeerAccessResponse; use crate::p2p::PProxyNetworkBehaviour; use crate::p2p::PProxyNetworkBehaviourEvent; use crate::tunnel::proto; @@ -30,7 +32,7 @@ use crate::tunnel::Tunnel; use crate::tunnel::TunnelServer; use crate::types::*; -mod auth; +mod access; pub mod command; pub mod error; mod p2p; @@ -75,6 +77,9 @@ pub enum PProxyCommand { tunnel_id: TunnelId, data: Vec, }, + ExpirePeerAccess { + peer_id: PeerId, + }, } pub enum PProxyCommandResponse { @@ -82,6 +87,7 @@ pub enum PProxyCommandResponse { ConnectRelay { relaied_multiaddr: Multiaddr }, SendConnectCommand {}, SendOutboundPackageCommand {}, + ExpirePeerAccess {}, } pub struct PProxy { @@ -92,7 +98,7 @@ pub struct PProxy { outbound_ready_notifiers: HashMap, inbound_tunnels: HashMap<(PeerId, TunnelId), Tunnel>, tunnel_txs: HashMap<(PeerId, TunnelId), mpsc::Sender>>, - auth_client: Option, + access_client: Option, } pub struct PProxyHandle { @@ -106,13 +112,13 @@ impl PProxy { keypair: Keypair, listen_addr: SocketAddr, proxy_addr: Option, - auth_server_endpoint: Option, + access_server_endpoint: Option, ) -> Result<(Self, PProxyHandle)> { let (command_tx, command_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); let swarm = crate::p2p::new_swarm(keypair, listen_addr) .map_err(|e| Error::Libp2pSwarmCreateError(e.to_string()))?; - let auth_client = - auth_server_endpoint.map(|endpoint| AuthClient::new(*swarm.local_peer_id(), endpoint)); + let access_client = access_server_endpoint + .map(|endpoint| AccessClient::new(*swarm.local_peer_id(), endpoint)); Ok(( Self { @@ -123,7 +129,7 @@ impl PProxy { outbound_ready_notifiers: HashMap::new(), inbound_tunnels: HashMap::new(), tunnel_txs: HashMap::new(), - auth_client, + access_client, }, PProxyHandle { command_tx, @@ -187,8 +193,8 @@ impl PProxy { request_response::Message::Request { request, channel, .. } => { - if let Some(auth_client) = &mut self.auth_client { - if !auth_client.is_valid(&peer.to_string()).await? { + if let Some(ac) = &mut self.access_client { + if !ac.is_valid(&peer).await { // TODO: Manage tunnel lifecycle return Err(Error::Tunnel(error::TunnelError::ConnectionClosed)); } @@ -329,6 +335,9 @@ impl PProxy { self.on_send_outbound_package_command(peer_id, tunnel_id, data, tx) .await } + PProxyCommand::ExpirePeerAccess { peer_id } => { + self.on_expire_peer_access(peer_id, tx).await + } } } @@ -405,6 +414,17 @@ impl PProxy { tx.send(Ok(PProxyCommandResponse::SendOutboundPackageCommand {})) .map_err(|_| Error::EssentialTaskClosed) } + + async fn on_expire_peer_access(&mut self, peer_id: PeerId, tx: CommandNotifier) -> Result<()> { + if let Some(ref mut ac) = self.access_client { + ac.expire(&peer_id); + } + + tx.send(Ok(PProxyCommandResponse::ExpirePeerAccess {})) + .map_err(|_| Error::EssentialTaskClosed)?; + + Ok(()) + } } impl PProxyHandle { @@ -495,6 +515,26 @@ impl PProxyHandle { _ => Err(Error::UnexpectedResponseType), } } + + pub async fn expire_peer_access( + &self, + request: ExpirePeerAccessRequest, + ) -> Result { + let (tx, rx) = oneshot::channel(); + + let peer_id = request + .peer_id + .parse() + .map_err(|_| Error::PeerIdParseError(request.peer_id))?; + + self.command_tx + .send((PProxyCommand::ExpirePeerAccess { peer_id }, tx)) + .await?; + + rx.await??; + + Ok(ExpirePeerAccessResponse {}) + } } fn extract_peer_id_from_multiaddr(multiaddr: &Multiaddr) -> Result { diff --git a/src/main.rs b/src/main.rs index 03eb0c3..7b2872c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -53,11 +53,11 @@ fn parse_args() -> Command { .help("Will reverse proxy this address via tunnel protocol if set"), ) .arg( - Arg::new("AUTH_SERVER_ENDPOINT") - .long("auth-server-endpoint") + Arg::new("ACCESS_SERVER_ENDPOINT") + .long("access-server-endpoint") .num_args(1) .action(ArgAction::Set) - .help("Authentication server endpoint is used to verify if one peer can access another. If not set, all access is allowed."), + .help("Access server endpoint is used to verify if one peer can access another. If not set, all access is allowed."), ); let create_tunnel_server = Command::new("create_tunnel_server") @@ -139,9 +139,9 @@ async fn serve(args: &ArgMatches) { let proxy_addr = args .get_one::("PROXY_ADDR") .map(|addr| addr.parse().expect("Invalid proxy address")); - let auth_server_endpoint = args - .get_one::("AUTH_SERVER_ENDPOINT") - .map(|endpoint| Url::parse(endpoint).expect("Invalid authentication server endpoint")); + let access_server_endpoint = args + .get_one::("ACCESS_SERVER_ENDPOINT") + .map(|endpoint| Url::parse(endpoint).expect("Invalid access server endpoint")); println!("server_addr: {}", server_addr); println!("commander_server_addr: {}", commander_server_addr); @@ -150,7 +150,7 @@ async fn serve(args: &ArgMatches) { identity::ed25519::Keypair::from(key).into(), server_addr, proxy_addr, - auth_server_endpoint, + access_server_endpoint, ) .expect("Create pproxy failed");