diff --git a/examples/dev_proxy.rs b/examples/dev_proxy.rs index ae59dad..da36688 100644 --- a/examples/dev_proxy.rs +++ b/examples/dev_proxy.rs @@ -102,11 +102,15 @@ async fn main() { let original_url = req.uri().clone(); - // Forward connection from http/https dev.example to http://127.0.0.1:3333 - if req.uri().host() == Some("dev.example") { + // Forward connection from http/https www.marscode.com/ide/ to http://127.0.0.1:3333 + if req.uri().host() == Some("www.marscode.com") + && req.method() != hyper::Method::CONNECT + && req.uri().path().starts_with("/ide/") + { req.headers_mut().insert( hyper::header::HOST, - hyper::header::HeaderValue::from_static("127.0.0.1"), + hyper::header::HeaderValue::from_maybe_shared(format!("127.0.0.1:{}", port)) + .unwrap(), ); let mut parts = req.uri().clone().into_parts(); @@ -115,6 +119,15 @@ async fn main() { hyper::http::uri::Authority::from_maybe_shared(format!("127.0.0.1:{}", port)) .unwrap(), ); + parts.path_and_query = Some( + parts + .path_and_query + .unwrap() + .to_string() + .trim_start_matches("/ide") + .parse() + .unwrap(), + ); *req.uri_mut() = hyper::Uri::from_parts(parts).unwrap(); } diff --git a/src/lib.rs b/src/lib.rs index c345f88..46e1dae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,6 @@ use tls::server_config; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, ToSocketAddrs}, - sync::Mutex, }; pub use futures; @@ -237,43 +236,6 @@ impl + Send + Sync + 'static> MitmProxy { return; }; - let Ok(server) = TcpStream::connect(authority.as_str()).await else { - tracing::error!("Failed to connect to {}", authority); - return; - }; - let sender = if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) { - let Ok(server) = proxy.tls_connector.connect(&host, server).await else { - tracing::error!("Failed to handshake TLS to {}", host); - return; - }; - let Ok((sender, conn)) = client::conn::http1::Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(TokioIo::new(server)) - .await - else { - tracing::error!("Failed to handshake HTTP to {}", host); - return; - }; - - tokio::spawn(conn.with_upgrades()); - sender - } else { - let Ok((sender, conn)) = client::conn::http1::Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(TokioIo::new(server)) - .await - else { - tracing::error!("Failed to handshake HTTP to {}", host); - return; - }; - - tokio::spawn(conn.with_upgrades()); - sender - }; - - let sender = Arc::new(Mutex::new(sender)); let _ = server::conn::http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) @@ -281,8 +243,10 @@ impl + Send + Sync + 'static> MitmProxy { TokioIo::new(client), service_fn(move |mut req| { let original_authority = original_authority.clone(); - let sender = sender.clone(); let tx = tx.clone(); + let authority = authority.clone(); + let host = host.clone(); + let proxy = proxy.clone(); async move { let (req_back_tx, req_back_rx) = @@ -305,24 +269,70 @@ impl + Send + Sync + 'static> MitmProxy { StatusCode::INTERNAL_SERVER_ERROR, )); }; + + let host = req.uri().host().map(str::to_string).unwrap_or(host); + let authority = + req.uri().authority().cloned().unwrap_or(authority); + let Ok(server) = TcpStream::connect(authority.as_str()).await + else { + tracing::error!("Failed to connect to {}", authority); + return Ok(no_body(StatusCode::BAD_REQUEST)); + }; + let mut sender = if req.uri().scheme() + == Some(&hyper::http::uri::Scheme::HTTPS) + { + let Ok(server) = + proxy.tls_connector.connect(&host, server).await + else { + tracing::error!("Failed to handshake TLS to {}", host); + return Ok(no_body(StatusCode::BAD_REQUEST)); + }; + let Ok((sender, conn)) = + client::conn::http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(TokioIo::new(server)) + .await + else { + tracing::error!("Failed to handshake HTTP to {}", host); + return Ok(no_body(StatusCode::BAD_REQUEST)); + }; + + tokio::spawn(conn.with_upgrades()); + sender + } else { + let Ok((sender, conn)) = + client::conn::http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(TokioIo::new(server)) + .await + else { + tracing::error!("Failed to handshake HTTP to {}", host); + return Ok(no_body(StatusCode::BAD_REQUEST)); + }; + + tokio::spawn(conn.with_upgrades()); + sender + }; + remove_authority(&mut req); let (req, req_parts) = dup_request(req); - let (res, res_upgrade) = - match sender.lock().await.send_request(req).await { - Ok(res) => { - let (res, res_upgrade, res_middleman) = - dup_response(res); - let _ = res_tx.send(Ok(res_middleman)); - (res, res_upgrade) - } - Err(err) => { - let _ = res_tx.send(Err(err)); - return Ok::<_, hyper::Error>(no_body( - StatusCode::INTERNAL_SERVER_ERROR, - )); - } - }; + let (res, res_upgrade) = match sender.send_request(req).await { + Ok(res) => { + let (res, res_upgrade, res_middleman) = + dup_response(res); + let _ = res_tx.send(Ok(res_middleman)); + (res, res_upgrade) + } + Err(err) => { + let _ = res_tx.send(Err(err)); + return Ok::<_, hyper::Error>(no_body( + StatusCode::INTERNAL_SERVER_ERROR, + )); + } + }; if res.status() == StatusCode::SWITCHING_PROTOCOLS { tokio::task::spawn(async move {