From 4743cd66bc93eb4963110594863040a8d4e79fc9 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 28 Sep 2024 17:31:08 +0900 Subject: [PATCH] cache cert --- Cargo.lock | 105 ++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + README.md | 2 + examples/dev_proxy.rs | 2 + examples/proxy.rs | 2 + src/lib.rs | 79 ++++++++++++++++++++++++++----- src/tls.rs | 36 ++++++--------- tests/test.rs | 5 +- 8 files changed, 196 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 161a04c..34163d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -389,6 +389,30 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "data-encoding" version = "2.6.0" @@ -711,6 +735,7 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", + "moka", "native-tls", "rcgen", "reqwest", @@ -983,6 +1008,26 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" +[[package]] +name = "moka" +version = "0.12.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "once_cell", + "parking_lot", + "quanta", + "rustc_version", + "smallvec", + "tagptr", + "thiserror", + "triomphe", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.12" @@ -1216,6 +1261,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quanta" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quote" version = "1.0.37" @@ -1225,6 +1285,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "raw-cpuid" +version = "11.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb9ee317cfe3fbd54b36a511efc1edd42e216903c9cd575e686dd68a2ba90d8d" +dependencies = [ + "bitflags", +] + [[package]] name = "rcgen" version = "0.13.1" @@ -1362,6 +1431,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rusticata-macros" version = "4.1.0" @@ -1477,6 +1555,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "serde" version = "1.0.210" @@ -1647,6 +1731,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tempfile" version = "3.12.0" @@ -1888,6 +1978,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "triomphe" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" + [[package]] name = "try-lock" version = "0.2.5" @@ -1938,6 +2034,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f2935cc..4c9cca6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ tracing = "0.1.40" hyper-util = { version = "0.1.7", features = ["tokio"] } native-tls = { version = "0.2.12", features = ["alpn"] } thiserror = "1.0.62" +moka = { version = "0.12.8", features = ["sync"] } [dev-dependencies] axum = { version = "0.7.2", features = ["http2"] } diff --git a/README.md b/README.md index 01911fa..c9292ff 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ use std::path::PathBuf; use clap::{Args, Parser}; use http_mitm_proxy::{DefaultClient, MitmProxy}; +use moka::sync::Cache; use tracing_subscriber::EnvFilter; #[derive(Parser)] @@ -82,6 +83,7 @@ async fn main() { let proxy = MitmProxy::new( // This is the root cert that will be used to sign the fake certificates Some(root_cert), + Some(Cache::new(128)), ); let client = DefaultClient::new( diff --git a/examples/dev_proxy.rs b/examples/dev_proxy.rs index a5ecd62..15cb149 100644 --- a/examples/dev_proxy.rs +++ b/examples/dev_proxy.rs @@ -6,6 +6,7 @@ use clap::{Args, Parser}; use http_body_util::{BodyExt, Full}; use http_mitm_proxy::{DefaultClient, MitmProxy}; use hyper::Response; +use moka::sync::Cache; #[derive(Parser)] struct Opt { @@ -76,6 +77,7 @@ async fn main() { let proxy = MitmProxy::new( // This is the root cert that will be used to sign the fake certificates Some(root_cert), + Some(Cache::new(128)), ); let client = DefaultClient::new( diff --git a/examples/proxy.rs b/examples/proxy.rs index e46ecff..2d69c3a 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use clap::{Args, Parser}; use http_mitm_proxy::{DefaultClient, MitmProxy}; +use moka::sync::Cache; use tracing_subscriber::EnvFilter; #[derive(Parser)] @@ -69,6 +70,7 @@ async fn main() { let proxy = MitmProxy::new( // This is the root cert that will be used to sign the fake certificates Some(root_cert), + Some(Cache::new(128)), ); let client = DefaultClient::new( diff --git a/src/lib.rs b/src/lib.rs index 5599c8b..7d4582c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,12 +9,14 @@ use hyper::{ Method, Request, Response, StatusCode, }; use hyper_util::rt::{TokioExecutor, TokioIo}; +use moka::sync::Cache; use std::{borrow::Borrow, future::Future, net::SocketAddr, sync::Arc}; -use tls::server_config; +use tls::{generate_cert, CertifiedKeyDer}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; pub use futures; pub use hyper; +pub use moka; pub use tokio_native_tls; pub mod default_client; @@ -29,12 +31,20 @@ pub struct MitmProxy { /// /// If None, proxy will just tunnel HTTPS traffic and will not observe HTTPS traffic. pub root_cert: Option, + /// Cache to store generated certificates. If None, cache will not be used. + /// If root_cert is None, cache will not be used. + /// + /// The key of cache is hostname. + pub cert_cache: Option>, } impl MitmProxy { /// Create a new MitmProxy - pub fn new(root_cert: Option) -> Self { - Self { root_cert } + pub fn new(root_cert: Option, cache: Option>) -> Self { + Self { + root_cert, + cert_cache: cache, + } } } @@ -119,15 +129,20 @@ impl + Send + Sync + 'static> MitmProxy { ); return; }; - if let Some(root_cert) = proxy.root_cert.as_ref() { - let Ok(server_config) = - // Even if URL is modified by middleman, we should sign with original host name to communicate client. - server_config(connect_authority.host().to_string(), root_cert.borrow(), true) - else { - tracing::error!("Failed to create server config for {}", connect_authority.host()); - return; + if let Some(server_config) = + proxy.server_config(connect_authority.host().to_string(), true) + { + let server_config = match server_config { + Ok(server_config) => server_config, + Err(err) => { + tracing::error!( + "Failed to create server config for {}, {}", + connect_authority.host(), + err + ); + return; + } }; - // TODO: Cache server_config let server_config = Arc::new(server_config); let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config); let client = match tls_acceptor.accept(TokioIo::new(client)).await { @@ -189,6 +204,48 @@ impl + Send + Sync + 'static> MitmProxy { .map(|res| res.map(|b| b.boxed())) } } + + fn get_certified_key(&self, host: String) -> Option { + if let Some(root_cert) = self.root_cert.as_ref() { + Some(if let Some(cache) = self.cert_cache.as_ref() { + cache.get_with(host.clone(), move || { + generate_cert(host, root_cert.borrow()) + }) + } else { + generate_cert(host, root_cert.borrow()) + }) + } else { + None + } + } + + fn server_config( + &self, + host: String, + h2: bool, + ) -> Option> { + if let Some(cert) = self.get_certified_key(host) { + let config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert( + vec![rustls::pki_types::CertificateDer::from(cert.cert_der)], + rustls::pki_types::PrivateKeyDer::Pkcs8( + rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der), + ), + ); + + Some(if h2 { + config.map(|mut server_config| { + server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; + server_config + }) + } else { + config + }) + } else { + None + } + } } fn no_body(status: StatusCode) -> Response> { diff --git a/src/tls.rs b/src/tls.rs index 627dcb8..fccb6e1 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,8 +1,11 @@ -pub fn server_config( - host: String, - root_cert: &rcgen::CertifiedKey, - h2: bool, -) -> Result { +#[derive(Debug, Clone)] +pub struct CertifiedKeyDer { + pub cert_der: Vec, + /// Pkcs8 + pub key_der: Vec, +} + +pub fn generate_cert(host: String, root_cert: &rcgen::CertifiedKey) -> CertifiedKeyDer { let mut cert_params = rcgen::CertificateParams::new(vec![host]).unwrap(); cert_params .key_usages @@ -14,27 +17,14 @@ pub fn server_config( .extended_key_usages .push(rcgen::ExtendedKeyUsagePurpose::ClientAuth); - let private_key = rcgen::KeyPair::generate().unwrap(); + let key_pair = rcgen::KeyPair::generate().unwrap(); let cert = cert_params - .signed_by(&private_key, &root_cert.cert, &root_cert.key_pair) + .signed_by(&key_pair, &root_cert.cert, &root_cert.key_pair) .unwrap(); - let config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_single_cert( - vec![rustls::pki_types::CertificateDer::from(cert)], - rustls::pki_types::PrivateKeyDer::Pkcs8(rustls::pki_types::PrivatePkcs8KeyDer::from( - private_key.serialize_der(), - )), - ); - - if h2 { - config.map(|mut server_config| { - server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()]; - server_config - }) - } else { - config + CertifiedKeyDer { + cert_der: cert.der().to_vec(), + key_der: key_pair.serialize_der(), } } diff --git a/tests/test.rs b/tests/test.rs index 96f363a..41424a8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -17,6 +17,7 @@ use hyper::{ body::{Body, Incoming}, Response, Uri, }; +use moka::sync::Cache; static PORT: AtomicU16 = AtomicU16::new(3666); @@ -88,7 +89,7 @@ where S: Fn(SocketAddr, Request) -> F + Send + Sync + Clone + 'static, F: std::future::Future, E2>> + Send + 'static, { - let proxy = MitmProxy::new(Some(root_cert())); + let proxy = MitmProxy::new(Some(root_cert()), Some(Cache::new(128))); let proxy_port = get_port(); let proxy = proxy .bind(("127.0.0.1", proxy_port), service) @@ -114,7 +115,7 @@ where S: Fn(SocketAddr, Request) -> F + Send + Sync + Clone + 'static, F: std::future::Future, E2>> + Send + 'static, { - let proxy = MitmProxy::new(Some(root_cert)); + let proxy = MitmProxy::new(Some(root_cert), Some(Cache::new(128))); let proxy_port = get_port(); let proxy = proxy .bind(("127.0.0.1", proxy_port), service)