diff --git a/.cargo/config.toml b/.cargo/config.toml index 50cdadf..63d6138 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,5 +1,6 @@ [build] -target="wasm32-wasi" +target = "wasm32-wasi" +rustflags = "--cfg tokio_unstable" [target.wasm32-wasi] -runner = "wasmedge" \ No newline at end of file +runner = "wasmedge" diff --git a/client-https/Cargo.toml b/client-https/Cargo.toml index c9fef62..5f3eb82 100644 --- a/client-https/Cargo.toml +++ b/client-https/Cargo.toml @@ -6,9 +6,14 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -hyper_wasi = { version = "0.15", features = ["full"]} -http-body-util = "0.1.0-rc.2" -tokio_wasi = { version = "1", features = ["rt", "macros", "net", "time", "io-util"]} +hyper = { version = "1", features = ["full"] } +tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] } pretty_env_logger = "0.4.0" -wasmedge_rustls_api = { version = "0.1", features = [ "tokio_async" ] } -wasmedge_hyper_rustls = "0.1.0" + +wasmedge_wasi_socket = "0.5" +pin-project = "1.1.3" +http-body-util = "0.1.0" + +tokio-rustls = "0.25.0" +webpki-roots = "0.26.0" +rustls = "0.22.2" diff --git a/client-https/src/main.rs b/client-https/src/main.rs index 4630c90..8bf6212 100644 --- a/client-https/src/main.rs +++ b/client-https/src/main.rs @@ -1,30 +1,165 @@ -use hyper::Client; +#![deny(warnings)] +#![warn(rust_2018_idioms)] -type Result = std::result::Result>; +// use tokio::io::{self, AsyncWriteExt as _}; + +use std::{ + os::fd::{FromRawFd, IntoRawFd}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use http_body_util::{BodyExt, Empty}; +use hyper::{body::Bytes, Request}; + +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; + +type MainResult = std::result::Result>; #[tokio::main(flavor = "current_thread")] -async fn main() { +async fn main() -> MainResult<()> { + pretty_env_logger::init(); + let url = "https://httpbin.org/get?msg=WasmEdge" .parse::() .unwrap(); - fetch_https_url(url).await.unwrap(); + fetch_https_url(url).await +} + +use pin_project::pin_project; +use tokio_rustls::TlsConnector; + +#[pin_project] +#[derive(Debug)] +struct TokioIo { + #[pin] + inner: T, +} + +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } + } + + #[allow(dead_code)] + pub fn inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } } -async fn fetch_https_url(url: hyper::Uri) -> Result<()> { - let https = wasmedge_hyper_rustls::connector::new_https_connector( - wasmedge_rustls_api::ClientConfig::default(), - ); - let client = Client::builder().build::<_, hyper::Body>(https); +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +async fn fetch_https_url(url: hyper::Uri) -> MainResult<()> { + let host = url.host().expect("uri has no host"); + let port = url.port_u16().unwrap_or(443); + let addr = format!("{}:{}", host, port); + + let mut root_store = rustls::RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + let connector = TlsConnector::from(Arc::new(config)); + let stream = unsafe { + let fd = wasmedge_wasi_socket::TcpStream::connect(addr)?.into_raw_fd(); + TcpStream::from_std(std::net::TcpStream::from_raw_fd(fd))? + }; + + let domain = ServerName::try_from(host.to_string()).unwrap(); + let stream = connector.connect(domain, stream).await.unwrap(); + + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let authority = url.authority().unwrap().clone(); + + let req = Request::builder() + .uri(url) + .header(hyper::header::HOST, authority.as_str()) + .body(Empty::::new())?; - let res = client.get(url).await?; + let mut res = sender.send_request(req).await?; println!("Response: {}", res.status()); println!("Headers: {:#?}\n", res.headers()); - let body = hyper::body::to_bytes(res.into_body()).await.unwrap(); - println!("{}", String::from_utf8(body.into()).unwrap()); + let mut resp_data = Vec::new(); + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + resp_data.extend_from_slice(&chunk); + } + } - println!("\n\nDone!"); + println!("{}", String::from_utf8_lossy(&resp_data)); Ok(()) } diff --git a/client/Cargo.toml b/client/Cargo.toml index 4648009..73768e0 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -4,12 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] -hyper_wasi = { version = "0.15", features = ["full"] } -tokio_wasi = { version = "1", features = [ - "rt", - "macros", - "net", - "time", - "io-util", -] } +hyper = { version = "1", features = ["full"] } +tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] } pretty_env_logger = "0.4.0" +wasmedge_wasi_socket = "0.5" +pin-project = "1.1.3" +http-body-util = "0.1.0" diff --git a/client/src/main.rs b/client/src/main.rs index c30973a..2aaed55 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,14 +1,24 @@ #![deny(warnings)] #![warn(rust_2018_idioms)] -use hyper::{body::HttpBody as _, Client}; -use hyper::{Body, Method, Request}; + // use tokio::io::{self, AsyncWriteExt as _}; +use std::{ + os::fd::{FromRawFd, IntoRawFd}, + pin::Pin, + task::{Context, Poll}, +}; + +use http_body_util::{BodyExt, Empty}; +use hyper::{body::Bytes, Request}; + +use tokio::net::TcpStream; + // A simple type alias so as to DRY. -type Result = std::result::Result>; +type MainResult = std::result::Result>; #[tokio::main(flavor = "current_thread")] -async fn main() -> Result<()> { +async fn main() -> MainResult<()> { pretty_env_logger::init(); let url_str = "http://eu.httpbin.org/get?msg=Hello"; @@ -24,7 +34,7 @@ async fn main() -> Result<()> { let url_str = "http://eu.httpbin.org/get?msg=WasmEdge"; println!("\nGET and get result as string: {}", url_str); let url = url_str.parse::().unwrap(); - fetch_url_return_str(url).await?; + fetch_url(url).await?; // tokio::time::sleep(std::time::Duration::from_secs(5)).await; let url_str = "http://eu.httpbin.org/post"; @@ -35,51 +45,168 @@ async fn main() -> Result<()> { post_url_return_str(url, post_body_str.as_bytes()).await } -async fn fetch_url(url: hyper::Uri) -> Result<()> { - let client = Client::new(); - let mut res = client.get(url).await?; +use pin_project::pin_project; - println!("Response: {}", res.status()); - println!("Headers: {:#?}\n", res.headers()); +#[pin_project] +#[derive(Debug)] +struct TokioIo { + #[pin] + inner: T, +} - // Stream the body, writing each chunk to stdout as we get it - // (instead of buffering and printing at the end). - while let Some(next) = res.data().await { - let chunk = next?; - println!("{:#?}", chunk); - // io::stdout().write_all(&chunk).await?; +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } } - Ok(()) + #[allow(dead_code)] + pub fn inner(self) -> T { + self.inner + } } -async fn fetch_url_return_str(url: hyper::Uri) -> Result<()> { - let client = Client::new(); - let mut res = client.get(url).await?; +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +async fn fetch_url(url: hyper::Uri) -> MainResult<()> { + let host = url.host().expect("uri has no host"); + let port = url.port_u16().unwrap_or(80); + let addr = format!("{}:{}", host, port); + let stream = unsafe { + let fd = wasmedge_wasi_socket::TcpStream::connect(addr)?.into_raw_fd(); + TcpStream::from_std(std::net::TcpStream::from_raw_fd(fd))? + }; + + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let authority = url.authority().unwrap().clone(); + + let req = Request::builder() + .uri(url) + .header(hyper::header::HOST, authority.as_str()) + .body(Empty::::new())?; + + let mut res = sender.send_request(req).await?; + + println!("Response: {}", res.status()); + println!("Headers: {:#?}\n", res.headers()); let mut resp_data = Vec::new(); - while let Some(next) = res.data().await { - let chunk = next?; - resp_data.extend_from_slice(&chunk); + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + resp_data.extend_from_slice(&chunk); + } } + println!("{}", String::from_utf8_lossy(&resp_data)); Ok(()) } -async fn post_url_return_str(url: hyper::Uri, post_body: &'static [u8]) -> Result<()> { - let client = Client::new(); +async fn post_url_return_str(url: hyper::Uri, post_body: &'static [u8]) -> MainResult<()> { + let host = url.host().expect("uri has no host"); + let port = url.port_u16().unwrap_or(80); + let addr = format!("{}:{}", host, port); + let stream = unsafe { + let fd = wasmedge_wasi_socket::TcpStream::connect(addr)?.into_raw_fd(); + TcpStream::from_std(std::net::TcpStream::from_raw_fd(fd))? + }; + + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let authority = url.authority().unwrap().clone(); + let req = Request::builder() - .method(Method::POST) .uri(url) - .body(Body::from(post_body))?; - let mut res = client.request(req).await?; + .method("POST") + .header(hyper::header::HOST, authority.as_str()) + .body(http_body_util::Full::new(post_body))?; + + let mut res = sender.send_request(req).await?; + + println!("Response: {}", res.status()); + println!("Headers: {:#?}\n", res.headers()); let mut resp_data = Vec::new(); - while let Some(next) = res.data().await { - let chunk = next?; - resp_data.extend_from_slice(&chunk); + while let Some(next) = res.frame().await { + let frame = next?; + if let Some(chunk) = frame.data_ref() { + resp_data.extend_from_slice(&chunk); + } } + println!("{}", String::from_utf8_lossy(&resp_data)); Ok(()) diff --git a/server-tflite/Cargo.toml b/server-tflite/Cargo.toml index 1ef1994..2745ab1 100644 --- a/server-tflite/Cargo.toml +++ b/server-tflite/Cargo.toml @@ -4,8 +4,22 @@ version = "0.1.0" edition = "2021" [dependencies] -hyper_wasi = { version = "0.15", features = ["full"]} -tokio_wasi = { version = "1", features = ["rt", "macros", "net", "time", "io-util"]} -image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "tiff", "webp", "bmp"] } +hyper = { version = "1", features = ["full"] } +tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] } +wasmedge_wasi_socket = "0.5" + +image = { version = "0.23.14", default-features = false, features = [ + "gif", + "jpeg", + "ico", + "png", + "tiff", + "webp", + "bmp", +] } wasi-nn = "0.4.0" anyhow = "1.0" + +pin-project = "1.1.3" +http-body-util = "0.1.0" +bytes = "1" diff --git a/server-tflite/src/main.rs b/server-tflite/src/main.rs index 44d2355..b0dc153 100644 --- a/server-tflite/src/main.rs +++ b/server-tflite/src/main.rs @@ -1,19 +1,29 @@ -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, StatusCode, Server}; -use std::convert::Infallible; -use std::net::SocketAddr; -use std::result::Result; -use std::io::Cursor; +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Method, Request, Response, StatusCode}; use image::io::Reader; use image::DynamicImage; -use wasi_nn::{GraphBuilder, GraphEncoding, ExecutionTarget, TensorType}; +use std::io::Cursor; +use std::net::SocketAddr; +use std::os::fd::{FromRawFd, IntoRawFd}; +use std::pin::Pin; +use std::result::Result; +use std::task::{Context, Poll}; +use tokio::net::TcpListener; +use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding, TensorType}; /// This is our service handler. It receives a Request, routes on its /// path, and returns a Future of a Response. -async fn classify(req: Request) -> Result, anyhow::Error> { - let model_data: &[u8] = include_bytes!("models/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_quant.tflite"); +async fn classify( + req: Request, +) -> Result>, anyhow::Error> { + let model_data: &[u8] = + include_bytes!("models/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_quant.tflite"); let labels = include_str!("models/mobilenet_v1_1.0_224/labels_mobilenet_quant_v1_224.txt"); - let graph = GraphBuilder::new(GraphEncoding::TensorflowLite, ExecutionTarget::CPU).build_from_bytes(&[model_data])?; + let graph = GraphBuilder::new(GraphEncoding::TensorflowLite, ExecutionTarget::CPU) + .build_from_bytes(&[model_data])?; let mut ctx = graph.init_execution_context()?; /* let graph = unsafe { @@ -29,14 +39,15 @@ async fn classify(req: Request) -> Result, anyhow::Error> { match (req.method(), req.uri().path()) { // Serve some instructions at / - (&Method::GET, "/") => Ok(Response::new(Body::from( + (&Method::GET, "/") => Ok(Response::new(full( "Try POSTing data to /classify such as: `curl http://localhost:3000/classify -X POST --data-binary '@grace_hopper.jpg'`", ))), (&Method::POST, "/classify") => { - let buf = hyper::body::to_bytes(req.into_body()).await?; + let buf = req.collect().await?.to_bytes(); + let tensor_data = image_to_tensor(&buf, 224, 224); - ctx.set_input(0, TensorType::U8, &[1, 224, 224, 3], &tensor_data)?; + ctx.set_input(0, TensorType::U8, &[1, 224, 224, 3], &tensor_data)?; /* let tensor = wasi_nn::Tensor { dimensions: &[1, 224, 224, 3], @@ -79,7 +90,7 @@ async fn classify(req: Request) -> Result, anyhow::Error> { let class_name = labels.lines().nth(results[0].0).unwrap_or("Unknown"); println!("result: {} {}", class_name, results[0].1); - Ok(Response::new(Body::from(format!("{} is detected with {}/255 confidence", class_name, results[0].1)))) + Ok(Response::new(full(format!("{} is detected with {}/255 confidence", class_name, results[0].1)))) } // Return the 404 Not Found for other routes. @@ -91,21 +102,114 @@ async fn classify(req: Request) -> Result, anyhow::Error> { } } +use pin_project::pin_project; + +#[pin_project] +#[derive(Debug)] +struct TokioIo { + #[pin] + inner: T, +} + +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } + } + + #[allow(dead_code)] + pub fn inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let addr = SocketAddr::from(([0, 0, 0, 0], 8080)); - let make_svc = make_service_fn(|_| { - async move { - Ok::<_, Infallible>(service_fn(move |req| { - classify(req) - })) - } - }); - let server = Server::bind(&addr).serve(make_svc); - if let Err(e) = server.await { - eprintln!("server error: {}", e); + let listener = unsafe { + let fd = wasmedge_wasi_socket::TcpListener::bind(addr, true)?.into_raw_fd(); + TcpListener::from_std(std::net::TcpListener::from_raw_fd(fd))? + }; + + loop { + let (stream, _) = listener.accept().await?; + println!("accept"); + let io = TokioIo::new(stream); + + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection(io, service_fn(classify)) + .await + { + println!("Error serving connection: {:?}", err); + } + }); } - Ok(()) /* let addr = SocketAddr::from(([0, 0, 0, 0], 8080)); diff --git a/server/Cargo.toml b/server/Cargo.toml index 0c78db6..c01ee17 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -4,5 +4,10 @@ version = "0.1.0" edition = "2021" [dependencies] -hyper_wasi = { version = "0.15", features = ["full"]} -tokio_wasi = { version = "1", features = ["rt", "macros", "net", "time", "io-util"]} +hyper = { version = "1", features = ["full"] } +tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] } +wasmedge_wasi_socket = "0.5" + +pin-project = "1.1.3" +http-body-util = "0.1.0" +bytes = "1" diff --git a/server/src/main.rs b/server/src/main.rs index f3e5d48..af52a60 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,49 +1,191 @@ +#![deny(warnings)] + use std::net::SocketAddr; +use std::os::fd::{FromRawFd, IntoRawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; -use hyper::server::conn::Http; +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; +use hyper::body::Frame; +use hyper::server::conn::http1; use hyper::service::service_fn; -use hyper::{Body, Method, Request, Response, StatusCode}; +use hyper::{body::Body, Method, Request, Response, StatusCode}; +use pin_project::pin_project; use tokio::net::TcpListener; +#[pin_project] +#[derive(Debug)] +struct TokioIo { + #[pin] + inner: T, +} + +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } + } + + #[allow(dead_code)] + pub fn inner(self) -> T { + self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + /// This is our service handler. It receives a Request, routes on its /// path, and returns a Future of a Response. -async fn echo(req: Request) -> Result, hyper::Error> { +async fn echo( + req: Request, +) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { // Serve some instructions at / - (&Method::GET, "/") => Ok(Response::new(Body::from( - "Try POSTing data to /echo such as: `curl localhost:8080/echo -XPOST -d 'hello world'`", + (&Method::GET, "/") => Ok(Response::new(full( + "Try POSTing data to /echo such as: `curl localhost:3000/echo -XPOST -d \"hello world\"`", ))), // Simply echo the body back to the client. - (&Method::POST, "/echo") => Ok(Response::new(req.into_body())), + (&Method::POST, "/echo") => Ok(Response::new(req.into_body().boxed())), + + // Convert to uppercase before sending back to client using a stream. + (&Method::POST, "/echo/uppercase") => { + let frame_stream = req.into_body().map_frame(|frame| { + let frame = if let Ok(data) = frame.into_data() { + data.iter() + .map(|byte| byte.to_ascii_uppercase()) + .collect::() + } else { + Bytes::new() + }; + + Frame::data(frame) + }); + Ok(Response::new(frame_stream.boxed())) + } + + // Reverse the entire body before sending back to the client. + // + // Since we don't know the end yet, we can't simply stream + // the chunks as they arrive as we did with the above uppercase endpoint. + // So here we do `.await` on the future, waiting on concatenating the full body, + // then afterwards the content can be reversed. Only then can we return a `Response`. (&Method::POST, "/echo/reversed") => { - let whole_body = hyper::body::to_bytes(req.into_body()).await?; + // To protect our server, reject requests with bodies larger than + // 64kbs of data. + let max = req.body().size_hint().upper().unwrap_or(u64::MAX); + if max > 1024 * 64 { + let mut resp = Response::new(full("Body too big")); + *resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE; + return Ok(resp); + } + + let whole_body = req.collect().await?.to_bytes(); let reversed_body = whole_body.iter().rev().cloned().collect::>(); - Ok(Response::new(Body::from(reversed_body))) + Ok(Response::new(full(reversed_body))) } // Return the 404 Not Found for other routes. _ => { - let mut not_found = Response::default(); + let mut not_found = Response::new(empty()); *not_found.status_mut() = StatusCode::NOT_FOUND; Ok(not_found) } } } +fn empty() -> BoxBody { + Empty::::new() + .map_err(|never| match never {}) + .boxed() +} + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { - let addr = SocketAddr::from(([0, 0, 0, 0], 8080)); + let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); + + let listener = unsafe { + let fd = wasmedge_wasi_socket::TcpListener::bind(addr, true)?.into_raw_fd(); + TcpListener::from_std(std::net::TcpListener::from_raw_fd(fd))? + }; - let listener = TcpListener::bind(addr).await?; println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + println!("accept"); + let io = TokioIo::new(stream); tokio::task::spawn(async move { - if let Err(err) = Http::new().serve_connection(stream, service_fn(echo)).await { + if let Err(err) = http1::Builder::new() + .serve_connection(io, service_fn(echo)) + .await + { println!("Error serving connection: {:?}", err); } });