diff --git a/Cargo.toml b/Cargo.toml index 23f2e45..904d3a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ async-shared-timeout = "0.2" base64 = "0.21" bytes = "1.5" chrono = "0.4" -clap = { version = "4.4", features = ["derive"] } +clap = { version = "4.5", features = ["derive"] } ctrlc2 = { version = "3.5", features = ["tokio", "termination"] } dotenvy = "0.15" env_logger = "0.11" @@ -35,9 +35,10 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" socks5-impl = "0.5" thiserror = "1.0" -tokio = { version = "1.35", features = ["full"] } +tokio = { version = "1.36", features = ["full"] } tokio-rustls = "0.25" tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } +tokio-util = "0.7" trust-dns-proto = "0.23" tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } url = "2.5" diff --git a/src/android.rs b/src/android.rs index 02b23fe..323fac3 100644 --- a/src/android.rs +++ b/src/android.rs @@ -142,10 +142,7 @@ pub mod native { } } - lazy_static::lazy_static! { - pub static ref EXITING_FLAG: Arc = Arc::new(AtomicBool::new(false)); - pub static ref LISTEN_ADDR: Arc> = Arc::new(Mutex::new(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080))); - } + static EXITING_FLAG: std::sync::Mutex> = std::sync::Mutex::new(None); /// # Safety /// @@ -159,6 +156,16 @@ pub mod native { stat_path: JString, verbosity: jint, ) -> jint { + let shutdown_token = crate::CancellationToken::new(); + { + let mut lock = EXITING_FLAG.lock().unwrap(); + if lock.is_some() { + log::error!("tun2proxy already started"); + return -1; + } + *lock = Some(shutdown_token.clone()); + } + let mut env = env; let log_level = ArgVerbosity::try_from(verbosity).unwrap().to_string(); @@ -194,8 +201,7 @@ pub mod native { let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; rt.block_on(async { - EXITING_FLAG.store(false, Ordering::SeqCst); - crate::client::run_client(&config, Some(EXITING_FLAG.clone()), Some(callback)).await?; + crate::client::run_client(&config, shutdown_token, Some(callback)).await?; Ok::<(), Error>(()) }) }; @@ -218,16 +224,11 @@ pub mod native { pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_OverTlsWrapper_stopClient(_: JNIEnv, _: JClass) -> jint { stop_protect_socket(); - EXITING_FLAG.store(true, Ordering::SeqCst); - - let l_addr = *LISTEN_ADDR.lock().unwrap(); - let addr = if l_addr.is_ipv6() { - SocketAddr::from((Ipv6Addr::LOCALHOST, l_addr.port())) - } else { - SocketAddr::from((Ipv4Addr::LOCALHOST, l_addr.port())) - }; - let _ = std::net::TcpStream::connect(addr); - log::trace!("stopClient on listen address {l_addr}"); + if let Ok(mut token) = EXITING_FLAG.lock() { + if let Some(token) = token.take() { + token.cancel(); + } + } SocketProtector::release(); Jni::release(); diff --git a/src/api.rs b/src/api.rs index 7dccac3..1ea0d8d 100644 --- a/src/api.rs +++ b/src/api.rs @@ -5,12 +5,8 @@ use crate::{ ArgVerbosity, }; use std::{ - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + net::SocketAddr, os::raw::{c_char, c_int, c_void}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, - }, }; #[derive(Clone)] @@ -27,10 +23,7 @@ impl CCallback { unsafe impl Send for CCallback {} unsafe impl Sync for CCallback {} -lazy_static::lazy_static! { - static ref EXITING_FLAG: Arc = Arc::new(AtomicBool::new(false)); - static ref LISTEN_ADDR: Arc> = Arc::new(Mutex::new(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))); -} +static EXITING_FLAG: std::sync::Mutex> = std::sync::Mutex::new(None); /// # Safety /// @@ -53,6 +46,16 @@ unsafe fn _over_tls_client_run( callback: Option, ctx: *mut c_void, ) -> c_int { + let shutdown_token = crate::CancellationToken::new(); + { + let mut lock = EXITING_FLAG.lock().unwrap(); + if lock.is_some() { + log::error!("tun2proxy already started"); + return -1; + } + *lock = Some(shutdown_token.clone()); + } + let ccb = CCallback(callback, ctx); let block = || -> Result<()> { @@ -61,8 +64,6 @@ unsafe fn _over_tls_client_run( let cb = |addr: SocketAddr| { log::trace!("Listening on {}", addr); let port = addr.port(); - let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, port)); - *LISTEN_ADDR.lock().unwrap() = addr; unsafe { ccb.call(port as c_int); } @@ -72,8 +73,7 @@ unsafe fn _over_tls_client_run( config.check_correctness(false)?; let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; rt.block_on(async { - EXITING_FLAG.store(false, Ordering::SeqCst); - crate::client::run_client(&config, Some(EXITING_FLAG.clone()), Some(cb)).await?; + crate::client::run_client(&config, shutdown_token, Some(cb)).await?; Ok::<(), Error>(()) }) }; @@ -89,15 +89,10 @@ unsafe fn _over_tls_client_run( /// Shutdown the client. #[no_mangle] pub unsafe extern "C" fn over_tls_client_stop() -> c_int { - EXITING_FLAG.store(true, Ordering::SeqCst); - - let l_addr = *LISTEN_ADDR.lock().unwrap(); - let addr = if l_addr.is_ipv6() { - SocketAddr::from((Ipv6Addr::LOCALHOST, l_addr.port())) - } else { - SocketAddr::from((Ipv4Addr::LOCALHOST, l_addr.port())) - }; - let _ = std::net::TcpStream::connect(addr); - log::trace!("Client stop on listen address {}", l_addr); + if let Ok(mut token) = EXITING_FLAG.lock() { + if let Some(token) = token.take() { + token.cancel(); + } + } 0 } diff --git a/src/bin/overtls.rs b/src/bin/overtls.rs index b74a6d1..172d77d 100644 --- a/src/bin/overtls.rs +++ b/src/bin/overtls.rs @@ -1,8 +1,4 @@ use overtls::{client, config, server, CmdOpt, Error, Result}; -use std::{ - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, - sync::{atomic::AtomicBool, Arc}, -}; fn main() -> Result<()> { let opt = CmdOpt::parse_cmd(); @@ -43,13 +39,13 @@ fn main() -> Result<()> { } async fn async_main(config: config::Config) -> Result<()> { - let exiting_flag = Arc::new(AtomicBool::new(false)); - let exiting_flag_clone = exiting_flag.clone(); + let shutdown_token = overtls::CancellationToken::new(); + let shutdown_token_clone = shutdown_token.clone(); let main_body = async { if config.is_server { if config.exist_server() { - server::run_server(&config, Some(exiting_flag_clone)).await?; + server::run_server(&config, shutdown_token_clone).await?; } else { return Err(Error::from("Config is not a server config")); } @@ -57,7 +53,7 @@ async fn async_main(config: config::Config) -> Result<()> { let callback = |addr| { log::trace!("Listening on {}", addr); }; - client::run_client(&config, Some(exiting_flag_clone), Some(callback)).await?; + client::run_client(&config, shutdown_token_clone, Some(callback)).await?; } else { return Err("Config is not a client config".into()); } @@ -65,18 +61,9 @@ async fn async_main(config: config::Config) -> Result<()> { Ok(()) }; - let local_addr = config.listen_addr()?; - ctrlc2::set_async_handler(async move { - exiting_flag.store(true, std::sync::atomic::Ordering::Relaxed); - - let addr = if local_addr.is_ipv6() { - SocketAddr::from((Ipv6Addr::LOCALHOST, local_addr.port())) - } else { - SocketAddr::from((Ipv4Addr::LOCALHOST, local_addr.port())) - }; - let _ = std::net::TcpStream::connect(addr); - log::info!(""); + log::info!("Ctrl-C received, exiting..."); + shutdown_token.cancel(); }) .await; diff --git a/src/client.rs b/src/client.rs index 05aa06f..39ba09b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,13 +16,7 @@ use socks5_impl::{ AuthAdaptor, ClientConnection, Connect, IncomingConnection, Server, }, }; -use std::{ - net::SocketAddr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; +use std::{net::SocketAddr, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::TcpStream, @@ -38,7 +32,7 @@ use tungstenite::{ protocol::{Message, Role}, }; -pub async fn run_client(config: &Config, exiting_flag: Option>, callback: Option) -> Result<()> +pub async fn run_client(config: &Config, quit: crate::CancellationToken, callback: Option) -> Result<()> where F: FnOnce(SocketAddr) + Send + Sync + 'static, { @@ -52,14 +46,14 @@ where if let Some(user) = listen_user { let listen_password = client.listen_password.as_deref().unwrap_or(""); let key = UserKeyAuth::new(user, listen_password); - _run_client(config, Arc::new(key), exiting_flag, callback).await?; + _run_client(config, Arc::new(key), quit, callback).await?; } else { - _run_client(config, Arc::new(NoAuth), exiting_flag, callback).await?; + _run_client(config, Arc::new(NoAuth), quit, callback).await?; } Ok(()) } -async fn _run_client(config: &Config, auth: AuthAdaptor, exiting_flag: Option>, callback: Option) -> Result<()> +async fn _run_client(config: &Config, auth: AuthAdaptor, quit: crate::CancellationToken, callback: Option) -> Result<()> where F: FnOnce(SocketAddr) + Send + Sync + 'static, O: Send + Sync + 'static, @@ -74,23 +68,27 @@ where } let (udp_tx, _, incomings) = udprelay::create_udp_tunnel(); - udprelay::udp_handler_watchdog(config, &incomings, &udp_tx, exiting_flag.clone()).await?; + udprelay::udp_handler_watchdog(config, &incomings, &udp_tx, quit.clone()).await?; - while let Ok((conn, _)) = server.accept().await { - if let Some(exiting_flag) = &exiting_flag { - if exiting_flag.load(Ordering::Relaxed) { + loop { + tokio::select! { + _ = quit.cancelled() => { log::info!("exiting..."); break; } - } - let config = config.clone(); - let udp_tx = udp_tx.clone(); - let incomings = incomings.clone(); - tokio::spawn(async move { - if let Err(e) = handle_incoming(conn, config, Some(udp_tx), incomings).await { - log::debug!("{}", e); + result = server.accept() => { + if let Ok((conn, _)) = result { + let config = config.clone(); + let udp_tx = udp_tx.clone(); + let incomings = incomings.clone(); + tokio::spawn(async move { + if let Err(e) = handle_incoming(conn, config, Some(udp_tx), incomings).await { + log::debug!("{}", e); + } + }); + } } - }); + } } Ok(()) diff --git a/src/cmdopt.rs b/src/cmdopt.rs index ab0d07c..65f6b67 100644 --- a/src/cmdopt.rs +++ b/src/cmdopt.rs @@ -73,7 +73,7 @@ impl std::fmt::Display for ArgVerbosity { } /// Proxy tunnel over tls -#[derive(clap::Parser, Debug, Clone, PartialEq, Eq)] +#[derive(clap::Parser, Debug, Clone, PartialEq, Eq, Default)] #[command(author, version, about = "Proxy tunnel over tls.", long_about = None)] pub struct CmdOpt { /// Role of server or client diff --git a/src/lib.rs b/src/lib.rs index bc1aa28..31dab80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ use bytes::BytesMut; pub use cmdopt::{ArgVerbosity, CmdOpt, Role}; pub use error::{Error, Result}; use socks5_impl::protocol::{Address, StreamOperation}; +pub use tokio_util::sync::CancellationToken; #[cfg(target_os = "windows")] pub(crate) const STREAM_BUFFER_SIZE: usize = 1024 * 32; diff --git a/src/server.rs b/src/server.rs index c7ce803..903c231 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,10 +12,7 @@ use socks5_impl::protocol::{Address, StreamOperation}; use std::{ collections::HashMap, net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::Arc, }; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, @@ -33,7 +30,7 @@ use tungstenite::{ const WS_HANDSHAKE_LEN: usize = 1024; const WS_MSG_HEADER_LEN: usize = 14; -pub async fn run_server(config: &Config, exiting_flag: Option>) -> Result<()> { +pub async fn run_server(config: &Config, exiting_flag: crate::CancellationToken) -> Result<()> { log::info!("starting {} server...", env!("CARGO_PKG_NAME")); log::trace!("with following settings:"); log::trace!("{}", serde_json::to_string_pretty(config)?); @@ -86,33 +83,36 @@ pub async fn run_server(config: &Config, exiting_flag: Option>) let listener = TcpListener::bind(&addr).await?; loop { - let (stream, peer_addr) = listener.accept().await?; - if let Some(exiting_flag) = &exiting_flag { - if exiting_flag.load(Ordering::Relaxed) { + tokio::select! { + _ = exiting_flag.cancelled() => { log::info!("exiting..."); break; } - } - let acceptor = acceptor.clone(); - let config = config.clone(); - let traffic_audit = traffic_audit.clone(); - - let incoming_task = async move { - if let Some(acceptor) = acceptor { - let stream = acceptor.accept(stream).await?; - handle_incoming(stream, peer_addr, config, traffic_audit).await?; - } else { - handle_incoming(stream, peer_addr, config, traffic_audit).await?; - } - Ok::<_, Error>(()) - }; + ret = listener.accept() => { + let (stream, peer_addr) = ret?; + let acceptor = acceptor.clone(); + let config = config.clone(); + let traffic_audit = traffic_audit.clone(); + + let incoming_task = async move { + if let Some(acceptor) = acceptor { + let stream = acceptor.accept(stream).await?; + handle_incoming(stream, peer_addr, config, traffic_audit).await?; + } else { + handle_incoming(stream, peer_addr, config, traffic_audit).await?; + } + Ok::<_, Error>(()) + }; - tokio::spawn(async move { - if let Err(e) = incoming_task.await { - log::debug!("{peer_addr}: {e}"); + tokio::spawn(async move { + if let Err(e) = incoming_task.await { + log::debug!("{peer_addr}: {e}"); + } + }); } - }); + } } + Ok(()) } diff --git a/src/udprelay.rs b/src/udprelay.rs index 21bcdc9..819bc78 100644 --- a/src/udprelay.rs +++ b/src/udprelay.rs @@ -17,10 +17,7 @@ use socks5_impl::{ use std::{ collections::HashSet, net::{SocketAddr, ToSocketAddrs}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::Arc, time::Duration, }; use tokio::{ @@ -284,7 +281,7 @@ pub(crate) async fn udp_handler_watchdog( config: &Config, incomings: &SocketAddrHashSet, udp_tx: &UdpRequestSender, - exiting_flag: Option>, + quit: crate::CancellationToken, ) -> Result<()> { let config = config.clone(); let incomings = incomings.clone(); @@ -292,26 +289,33 @@ pub(crate) async fn udp_handler_watchdog( tokio::spawn(async move { loop { - if let Some(ref flag) = exiting_flag { - if flag.load(Ordering::Relaxed) { - break; - } - } - let (tx, mut rx) = mpsc::channel::<()>(10); - let udp_tx = udp_tx.clone(); let incomings = incomings.clone(); let config = config.clone(); - log::trace!("[UDP] udp client guard thread started"); - let _ = tokio::spawn(async move { - if let Err(e) = run_udp_loop(udp_tx, incomings, config).await { - log::trace!("[UDP] {}", e); + + let block = async move { + let (tx, mut rx) = mpsc::channel::<()>(10); + + log::trace!("[UDP] udp client guard thread started"); + let _ = tokio::spawn(async move { + if let Err(e) = run_udp_loop(udp_tx, incomings, config).await { + log::trace!("[UDP] {}", e); + } + let _ = tx.send(()).await; + }) + .await; + let _ = rx.recv().await; + time::sleep(Duration::from_secs(1)).await; + }; + + tokio::select! { + _ = quit.cancelled() => { + break; + }, + _ = block => { + log::trace!("[UDP] udp client guard thread exited"); } - let _ = tx.send(()).await; - }) - .await; - let _ = rx.recv().await; - time::sleep(Duration::from_secs(1)).await; + }; } }); Ok(())