Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support modifing Upgraded connection #60

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ A HTTP proxy server library intended to be a backend of application like Burp pr
use std::path::PathBuf;

use clap::{Args, Parser};
use futures::StreamExt;
use http_mitm_proxy::{default_client::Upgrade, DefaultClient, MitmProxy};
use http_mitm_proxy::{DefaultClient, MitmProxy};
use moka::sync::Cache;
use tracing_subscriber::EnvFilter;

Expand Down
97 changes: 60 additions & 37 deletions examples/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::path::PathBuf;

use clap::{Args, Parser};
use futures::StreamExt;
use http_mitm_proxy::{
default_client::{websocket, Upgrade},
default_client::{websocket, Upgraded},
DefaultClient, MitmProxy,
};
use moka::sync::Cache;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing_subscriber::EnvFilter;
use winnow::Parser as _;

#[derive(Parser)]
struct Opt {
Expand Down Expand Up @@ -49,6 +50,7 @@ async fn main() {

tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.with_line_number(true)
.init();

let root_cert = if let Some(external_cert) = opt.external_cert {
Expand Down Expand Up @@ -77,7 +79,7 @@ async fn main() {
Some(Cache::new(128)),
);

let client = DefaultClient::new().unwrap();
let client = DefaultClient::new().unwrap().with_upgrades();
let server = proxy
.bind(("127.0.0.1", 3003), move |_client_addr, req| {
let client = client.clone();
Expand All @@ -96,49 +98,70 @@ async fn main() {

// You can try https://echo.websocket.org/.ws to test websocket.
println!("Upgrade connection");
let Upgrade {
mut client_to_server,
mut server_to_client,
} = upgrade;
let url = uri.to_string();

tokio::spawn(async move {
let mut buf = Vec::new();
while let Some(data) = client_to_server.next().await {
buf.extend(data);
let Upgraded { client, server } = upgrade.await.unwrap().unwrap();
let url = uri.to_string();

let (mut client_rx, mut client_tx) = tokio::io::split(client);
let (mut server_rx, mut server_tx) = tokio::io::split(server);

let url0 = url.clone();
let client_to_server = async move {
let mut buf = Vec::new();

loop {
let input = &mut buf.as_slice();
if let Ok(frame) = websocket::frame(input) {
println!(
"Client -> Server: {} {:?}",
url,
String::from_utf8_lossy(&frame.payload_data)
);
buf = input.to_vec();
} else {
if client_rx.read_buf(&mut buf).await.unwrap() == 0 {
break;
}
loop {
let input = &mut buf.as_slice();
if let Ok((frame, read)) =
websocket::frame.with_taken().parse_next(input)
{
println!(
"{} Client: {}",
&url0,
String::from_utf8_lossy(&frame.payload_data)
);
server_tx.write_all(read).await.unwrap();
buf = input.to_vec();
} else {
break;
}
}
}
}
});
let url = uri.to_string();
tokio::spawn(async move {
let mut buf = Vec::new();
while let Some(data) = server_to_client.next().await {
buf.extend(data);
};

let url0 = url.clone();
let server_to_client = async move {
let mut buf = Vec::new();

loop {
let input = &mut buf.as_slice();
if let Ok(frame) = websocket::frame(input) {
println!(
"Server -> Client: {} {:?}",
url,
String::from_utf8_lossy(&frame.payload_data)
);
buf = input.to_vec();
} else {
if server_rx.read_buf(&mut buf).await.unwrap() == 0 {
break;
}
loop {
let input = &mut buf.as_slice();
if let Ok((frame, read)) =
websocket::frame.with_taken().parse_next(input)
{
println!(
"{} Server: {}",
&url0,
String::from_utf8_lossy(&frame.payload_data)
);
client_tx.write_all(read).await.unwrap();
buf = input.to_vec();
} else {
break;
}
}
}
}
};

tokio::spawn(client_to_server);
tokio::spawn(server_to_client);
});
}

Expand Down
133 changes: 55 additions & 78 deletions src/default_client.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
use bytes::Bytes;
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender};
use http_body_util::Empty;
use hyper::{
body::{Body, Incoming},
client, header, Request, Response, StatusCode, Uri, Version,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::task::{Context, Poll};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tokio::{net::TcpStream, task::JoinHandle};

#[derive(thiserror::Error, Debug)]
pub enum Error {
Expand All @@ -28,18 +24,21 @@ pub enum Error {
TlsConnectError(Uri, native_tls::Error),
}

/// Upgraded connection
pub struct Upgrade {
/// Client to server traffic
pub client_to_server: UnboundedReceiver<Vec<u8>>,
/// Server to client traffic
pub server_to_client: UnboundedReceiver<Vec<u8>>,
/// Upgraded connections
pub struct Upgraded {
/// A socket to Client
pub client: TokioIo<hyper::upgrade::Upgraded>,
/// A socket to Server
pub server: TokioIo<hyper::upgrade::Upgraded>,
}
#[derive(Clone)]
/// Default HTTP client for this crate
pub struct DefaultClient {
tls_connector_no_alpn: tokio_native_tls::TlsConnector,
tls_connector_alpn_h2: tokio_native_tls::TlsConnector,
/// If true, send_request will returns an Upgraded struct when the response is an upgrade
/// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
pub with_upgrades: bool,
}
impl DefaultClient {
pub fn new() -> native_tls::Result<Self> {
Expand All @@ -51,9 +50,17 @@ impl DefaultClient {
Ok(Self {
tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
with_upgrades: false,
})
}

/// Enable HTTP upgrades
/// If you don't enable HTTP upgrades, send_request will just copy bidirectional when the response is an upgrade
pub fn with_upgrades(mut self) -> Self {
self.with_upgrades = true;
self
}

fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector {
match http_version {
Version::HTTP_2 => &self.tls_connector_alpn_h2,
Expand All @@ -67,7 +74,13 @@ impl DefaultClient {
pub async fn send_request<B>(
&self,
req: Request<B>,
) -> Result<(Response<Incoming>, Option<Upgrade>), Error>
) -> Result<
(
Response<Incoming>,
Option<JoinHandle<Result<Upgraded, hyper::Error>>>,
),
Error,
>
where
B: Body + Unpin + Send + 'static,
B::Data: Send,
Expand All @@ -82,39 +95,41 @@ impl DefaultClient {
.await?;

if res.status() == StatusCode::SWITCHING_PROTOCOLS {
let (tx_client, rx_client) = futures::channel::mpsc::unbounded();
let (tx_server, rx_server) = futures::channel::mpsc::unbounded();

let (res_parts, res_body) = res.into_parts();

let res0 = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());
tokio::task::spawn(async move {
if let (Ok(client), Ok(server)) = (
hyper::upgrade::on(Request::from_parts(req_parts, Empty::<Bytes>::new())).await,
hyper::upgrade::on(res0).await,
) {
upgrade(
TokioIo::new(client),
TokioIo::new(server),
tx_client,
tx_server,
let client_request = Request::from_parts(req_parts, Empty::<Bytes>::new());
let server_response = Response::from_parts(res_parts.clone(), Empty::<Bytes>::new());

let upgrade = if self.with_upgrades {
Some(tokio::task::spawn(async move {
let client = hyper::upgrade::on(client_request).await?;
let server = hyper::upgrade::on(server_response).await?;

Ok(Upgraded {
client: TokioIo::new(client),
server: TokioIo::new(server),
})
}))
} else {
tokio::task::spawn(async move {
let client = hyper::upgrade::on(client_request).await?;
let server = hyper::upgrade::on(server_response).await?;

let _ = tokio::io::copy_bidirectional(
&mut TokioIo::new(client),
&mut TokioIo::new(server),
)
.await;
} else {
tracing::error!("Failed to upgrade connection (HTTP)");
}
});

return Ok((
Response::from_parts(res_parts, res_body),
Some(Upgrade {
client_to_server: rx_client,
server_to_client: rx_server,
}),
));
}

Ok((res, None))
Ok::<_, hyper::Error>(())
});
None
};

Ok((Response::from_parts(res_parts, res_body), upgrade))
} else {
Ok((res, None))
}
}

async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
Expand Down Expand Up @@ -228,44 +243,6 @@ impl<B> SendRequest<B> {
}
}

async fn upgrade<
S1: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
S2: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
>(
client: S1,
server: S2,
tx_client: UnboundedSender<Vec<u8>>,
tx_server: UnboundedSender<Vec<u8>>,
) {
let (mut client_read, mut client_write) = tokio::io::split(client);
let (mut server_read, mut server_write) = tokio::io::split(server);

tokio::spawn(async move {
loop {
let mut buf = vec![];
let n = client_read.read_buf(&mut buf).await?;
if n == 0 {
break;
}
server_write.write_all(&buf).await?;
let _ = tx_client.unbounded_send(buf);
}
Ok::<(), std::io::Error>(())
});
tokio::spawn(async move {
loop {
let mut buf = vec![];
let n = server_read.read_buf(&mut buf).await?;
if n == 0 {
break;
}
client_write.write_all(&buf).await?;
let _ = tx_server.unbounded_send(buf);
}
Ok::<(), std::io::Error>(())
});
}

fn remove_authority<B>(req: &mut Request<B>) {
let mut parts = req.uri().clone().into_parts();
parts.scheme = None;
Expand Down
Loading