From f1f81eb613087b5fd9cefbffe175291820c4f827 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 22 Jun 2024 15:36:04 +0900 Subject: [PATCH 1/4] rename tests --- tests/test.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test.rs b/tests/test.rs index fecc1d5..3620901 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -313,7 +313,7 @@ async fn test_modify_header() { assert_eq!(String::from_utf8(body).unwrap(), "MODIFIED"); } #[tokio::test] -async fn test_modify_url() { +async fn test_modify_url_http_to_http() { let app = Router::new().route("/", get(|| async { "Hello, World!" })); let mut setup = setup(app, false).await; @@ -340,7 +340,7 @@ async fn test_modify_url() { } #[tokio::test] -async fn test_modify_url_https() { +async fn test_modify_url_http_to_https() { let app = Router::new().route("/", get(|| async { "Hello, World!" })); let mut setup = setup(app, true).await; @@ -529,7 +529,7 @@ async fn test_tls_simple() { } #[tokio::test] -async fn test_tls_modify_url() { +async fn test_tls_modify_url_https_to_https() { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .init(); @@ -566,7 +566,7 @@ async fn test_tls_modify_url() { } #[tokio::test] -async fn test_tls_modify_url_http() { +async fn test_tls_modify_url_https_to_http() { let app = Router::new().route("/", get(|| async { "Hello, World!" })); let mut setup = setup_tls(app, false, true).await; From 09300fdc6e6638b9316c5567fac1656200da282c Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 22 Jun 2024 15:53:31 +0900 Subject: [PATCH 2/4] add async fn send_and_receive_request --- src/lib.rs | 65 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d8f68e1..dde258c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,20 +163,9 @@ impl + Send + Sync + 'static> MitmProxy { B::Error: Into>, { let original_uri = req.uri().clone(); - let (req_back_tx, req_back_rx) = futures::channel::oneshot::channel(); - let (res_tx, res_rx) = futures::channel::oneshot::channel(); - let (upgrade_tx, upgrade_rx) = futures::channel::oneshot::channel(); - // Used tokio::spawn above to middle_man can consume rx in request() - let _ = tx.unbounded_send(Communication { - client_addr, - request: req, - request_back: req_back_tx, - response: res_rx, - upgrade: upgrade_rx, - }); - - let Ok(req) = req_back_rx.await else { - tracing::info!("Request canceled"); + + let (Some(req), res_tx, upgrade_tx) = send_and_receive_request(&tx, client_addr, req).await + else { return Ok(no_body(StatusCode::INTERNAL_SERVER_ERROR)); }; @@ -228,20 +217,11 @@ impl + Send + Sync + 'static> MitmProxy { let proxy = proxy.clone(); async move { - let (req_back_tx, req_back_rx) = futures::channel::oneshot::channel(); - let (res_tx, res_rx) = futures::channel::oneshot::channel(); - let (upgrade_tx, upgrade_rx) = futures::channel::oneshot::channel(); - inject_authority(&mut req, connect_authority.clone()); - let _ = tx.unbounded_send(Communication { - client_addr, - request: req, - request_back: req_back_tx, - response: res_rx, - upgrade: upgrade_rx, - }); - let Ok(req) = req_back_rx.await else { - tracing::info!("Request canceled"); + + let (Some(req), res_tx, upgrade_tx) = + send_and_receive_request(&tx, client_addr, req).await + else { return Ok::<_, hyper::Error>(no_body( StatusCode::INTERNAL_SERVER_ERROR, )); @@ -595,3 +575,34 @@ where (StreamBody::new(body), rx) } + +async fn send_and_receive_request( + tx: &UnboundedSender>, + client_addr: std::net::SocketAddr, + req: Request, +) -> ( + Option>, + futures::channel::oneshot::Sender< + Result, Arc>>>, hyper::Error>, + >, + futures::channel::oneshot::Sender, +) { + let (req_back_tx, req_back_rx) = futures::channel::oneshot::channel(); + let (res_tx, res_rx) = futures::channel::oneshot::channel(); + let (upgrade_tx, upgrade_rx) = futures::channel::oneshot::channel(); + // Used tokio::spawn above to middle_man can consume rx in request() + let _ = tx.unbounded_send(Communication { + client_addr, + request: req, + request_back: req_back_tx, + response: res_rx, + upgrade: upgrade_rx, + }); + + if let Ok(req) = req_back_rx.await { + tracing::info!("Request canceled"); + (Some(req), res_tx, upgrade_tx) + } else { + (None, res_tx, upgrade_tx) + } +} From 0892239d3207eed22603afd0f727f00f9f32d911 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 22 Jun 2024 16:06:16 +0900 Subject: [PATCH 3/4] good error message --- src/lib.rs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index dde258c..8caa575 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -206,9 +206,16 @@ impl + Send + Sync + 'static> MitmProxy { // TODO: Cache server_config let server_config = Arc::new(server_config); let tls_acceptor = tokio_rustls::TlsAcceptor::from(server_config); - let Ok(client) = tls_acceptor.accept(TokioIo::new(client)).await else { - tracing::error!("Failed to accept TLS connection for {}", original_host); - return; + let client = match tls_acceptor.accept(TokioIo::new(client)).await { + Ok(client) => client, + Err(err) => { + tracing::error!( + "Failed to accept TLS connection for {}, {}", + original_host, + err + ); + return; + } }; let f = move |mut req: Request<_>| { @@ -403,7 +410,13 @@ impl MitmProxyImpl { 80 }); - let tcp = TcpStream::connect((host, port)).await.unwrap(); + let tcp = match TcpStream::connect((host, port)).await { + Ok(tcp) => tcp, + Err(err) => { + tracing::error!("Failed to connect to {}:{} {}", host, port, err); + panic!(); + } + }; // This is actually needed to some servers let _ = tcp.set_nodelay(true); From e9785578834fba28b2defff3014cc6110e86c81f Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 22 Jun 2024 16:38:14 +0900 Subject: [PATCH 4/4] fix comment --- src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 8caa575..a3a3cad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,7 +170,6 @@ impl + Send + Sync + 'static> MitmProxy { }; if req.method() == Method::CONNECT { - // Modified CONNECT request is ignored // HTTPS connection let Some(connect_authority) = req.uri().authority().cloned() else { tracing::error!(