Skip to content

Commit

Permalink
cache cert
Browse files Browse the repository at this point in the history
  • Loading branch information
hatoo committed Sep 28, 2024
1 parent 5415403 commit 4743cd6
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 36 deletions.
105 changes: 105 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ tracing = "0.1.40"
hyper-util = { version = "0.1.7", features = ["tokio"] }
native-tls = { version = "0.2.12", features = ["alpn"] }
thiserror = "1.0.62"
moka = { version = "0.12.8", features = ["sync"] }

[dev-dependencies]
axum = { version = "0.7.2", features = ["http2"] }
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use std::path::PathBuf;
use clap::{Args, Parser};
use http_mitm_proxy::{DefaultClient, MitmProxy};
use moka::sync::Cache;
use tracing_subscriber::EnvFilter;
#[derive(Parser)]
Expand Down Expand Up @@ -82,6 +83,7 @@ async fn main() {
let proxy = MitmProxy::new(
// This is the root cert that will be used to sign the fake certificates
Some(root_cert),
Some(Cache::new(128)),
);
let client = DefaultClient::new(
Expand Down
2 changes: 2 additions & 0 deletions examples/dev_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use clap::{Args, Parser};
use http_body_util::{BodyExt, Full};
use http_mitm_proxy::{DefaultClient, MitmProxy};
use hyper::Response;
use moka::sync::Cache;

#[derive(Parser)]
struct Opt {
Expand Down Expand Up @@ -76,6 +77,7 @@ async fn main() {
let proxy = MitmProxy::new(
// This is the root cert that will be used to sign the fake certificates
Some(root_cert),
Some(Cache::new(128)),
);

let client = DefaultClient::new(
Expand Down
2 changes: 2 additions & 0 deletions examples/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::path::PathBuf;

use clap::{Args, Parser};
use http_mitm_proxy::{DefaultClient, MitmProxy};
use moka::sync::Cache;
use tracing_subscriber::EnvFilter;

#[derive(Parser)]
Expand Down Expand Up @@ -69,6 +70,7 @@ async fn main() {
let proxy = MitmProxy::new(
// This is the root cert that will be used to sign the fake certificates
Some(root_cert),
Some(Cache::new(128)),
);

let client = DefaultClient::new(
Expand Down
79 changes: 68 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ use hyper::{
Method, Request, Response, StatusCode,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use moka::sync::Cache;
use std::{borrow::Borrow, future::Future, net::SocketAddr, sync::Arc};
use tls::server_config;
use tls::{generate_cert, CertifiedKeyDer};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};

pub use futures;
pub use hyper;
pub use moka;
pub use tokio_native_tls;

pub mod default_client;
Expand All @@ -29,12 +31,20 @@ pub struct MitmProxy<C> {
///
/// If None, proxy will just tunnel HTTPS traffic and will not observe HTTPS traffic.
pub root_cert: Option<C>,
/// Cache to store generated certificates. If None, cache will not be used.
/// If root_cert is None, cache will not be used.
///
/// The key of cache is hostname.
pub cert_cache: Option<Cache<String, CertifiedKeyDer>>,
}

impl<C> MitmProxy<C> {
/// Create a new MitmProxy
pub fn new(root_cert: Option<C>) -> Self {
Self { root_cert }
pub fn new(root_cert: Option<C>, cache: Option<Cache<String, CertifiedKeyDer>>) -> Self {
Self {
root_cert,
cert_cache: cache,
}
}
}

Expand Down Expand Up @@ -119,15 +129,20 @@ impl<C: Borrow<rcgen::CertifiedKey> + Send + Sync + 'static> MitmProxy<C> {
);
return;
};
if let Some(root_cert) = proxy.root_cert.as_ref() {
let Ok(server_config) =
// Even if URL is modified by middleman, we should sign with original host name to communicate client.
server_config(connect_authority.host().to_string(), root_cert.borrow(), true)
else {
tracing::error!("Failed to create server config for {}", connect_authority.host());
return;
if let Some(server_config) =
proxy.server_config(connect_authority.host().to_string(), true)
{
let server_config = match server_config {
Ok(server_config) => server_config,
Err(err) => {
tracing::error!(
"Failed to create server config for {}, {}",
connect_authority.host(),
err
);
return;
}
};
// TODO: Cache server_config
let server_config = Arc::new(server_config);
let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config);
let client = match tls_acceptor.accept(TokioIo::new(client)).await {
Expand Down Expand Up @@ -189,6 +204,48 @@ impl<C: Borrow<rcgen::CertifiedKey> + Send + Sync + 'static> MitmProxy<C> {
.map(|res| res.map(|b| b.boxed()))
}
}

fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
if let Some(root_cert) = self.root_cert.as_ref() {
Some(if let Some(cache) = self.cert_cache.as_ref() {
cache.get_with(host.clone(), move || {
generate_cert(host, root_cert.borrow())
})
} else {
generate_cert(host, root_cert.borrow())
})
} else {
None
}
}

fn server_config(
&self,
host: String,
h2: bool,
) -> Option<Result<rustls::ServerConfig, rustls::Error>> {
if let Some(cert) = self.get_certified_key(host) {
let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![rustls::pki_types::CertificateDer::from(cert.cert_der)],
rustls::pki_types::PrivateKeyDer::Pkcs8(
rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_der),
),
);

Some(if h2 {
config.map(|mut server_config| {
server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
server_config
})
} else {
config
})
} else {
None
}
}
}

fn no_body<E>(status: StatusCode) -> Response<BoxBody<Bytes, E>> {
Expand Down
36 changes: 13 additions & 23 deletions src/tls.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
pub fn server_config(
host: String,
root_cert: &rcgen::CertifiedKey,
h2: bool,
) -> Result<rustls::ServerConfig, rustls::Error> {
#[derive(Debug, Clone)]
pub struct CertifiedKeyDer {
pub cert_der: Vec<u8>,
/// Pkcs8
pub key_der: Vec<u8>,
}

pub fn generate_cert(host: String, root_cert: &rcgen::CertifiedKey) -> CertifiedKeyDer {
let mut cert_params = rcgen::CertificateParams::new(vec![host]).unwrap();
cert_params
.key_usages
Expand All @@ -14,27 +17,14 @@ pub fn server_config(
.extended_key_usages
.push(rcgen::ExtendedKeyUsagePurpose::ClientAuth);

let private_key = rcgen::KeyPair::generate().unwrap();
let key_pair = rcgen::KeyPair::generate().unwrap();

let cert = cert_params
.signed_by(&private_key, &root_cert.cert, &root_cert.key_pair)
.signed_by(&key_pair, &root_cert.cert, &root_cert.key_pair)
.unwrap();

let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![rustls::pki_types::CertificateDer::from(cert)],
rustls::pki_types::PrivateKeyDer::Pkcs8(rustls::pki_types::PrivatePkcs8KeyDer::from(
private_key.serialize_der(),
)),
);

if h2 {
config.map(|mut server_config| {
server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
server_config
})
} else {
config
CertifiedKeyDer {
cert_der: cert.der().to_vec(),
key_der: key_pair.serialize_der(),
}
}
Loading

0 comments on commit 4743cd6

Please sign in to comment.