diff --git a/Cargo.toml b/Cargo.toml index e96312588..a5650cbee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -124,7 +124,7 @@ tokio-native-tls = "0.3" tokio-rustls = {version = "0.26", default-features = false } tokio-openssl = "0.6" tokio-stream = { version = "0.1", default-features = false } -tokio-tungstenite = { version = "0.24", default-features = false } +tokio-tungstenite = { version = "0.25", default-features = false } tokio-util = "0.7" tower = { version = "0.5", default-features = false } tracing-subscriber = { version = "0.3" } diff --git a/crates/extra/src/websocket.rs b/crates/extra/src/websocket.rs index de9676ac7..fcf4ead7b 100644 --- a/crates/extra/src/websocket.rs +++ b/crates/extra/src/websocket.rs @@ -29,7 +29,7 @@ //! // client disconnected //! return; //! }; -//! +//! //! if ws.send(msg).await.is_err() { //! // client disconnected //! return; @@ -47,11 +47,11 @@ //! #[tokio::main] //! async fn main() { //! let router = Router::new().get(index).push(Router::with_path("ws").goal(connect)); -//! +//! //! let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; //! Server::new(acceptor).serve(router).await; //! } -//! +//! //! static INDEX_HTML: &str = r#" //! //! @@ -81,21 +81,22 @@ use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll, ready}; +use std::task::{ready, Context, Poll}; use futures_util::sink::{Sink, SinkExt}; use futures_util::stream::{Stream, StreamExt}; use futures_util::{future, FutureExt, TryFutureExt}; use hyper::upgrade::OnUpgrade; use salvo_core::http::header::{SEC_WEBSOCKET_VERSION, UPGRADE}; -use salvo_core::http::headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade}; +use salvo_core::http::headers::{ + Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade, +}; use salvo_core::http::{StatusCode, StatusError}; use salvo_core::rt::tokio::TokioIo; use salvo_core::{Error, Request, Response}; -use tokio_tungstenite::{ - tungstenite::protocol::{self, WebSocketConfig}, - WebSocketStream, -}; +use tokio_tungstenite::tungstenite::protocol::frame::{Payload, Utf8Payload}; +use tokio_tungstenite::tungstenite::protocol::{self, WebSocketConfig}; +use tokio_tungstenite::WebSocketStream; /// Creates a WebSocket Handler. /// Request: @@ -132,7 +133,9 @@ impl WebSocketUpgrade { /// Create new `WebSocketUpgrade` with config. #[inline] pub fn with_config(config: WebSocketConfig) -> Self { - WebSocketUpgrade { config: Some(config) } + WebSocketUpgrade { + config: Some(config), + } } /// The target minimum size of the write buffer to reach before writing the data @@ -143,7 +146,9 @@ impl WebSocketUpgrade { /// It is often more optimal to allow them to buffer a little, hence the default value. #[inline] pub fn write_buffer_size(mut self, max: usize) -> Self { - self.config.get_or_insert_with(WebSocketConfig::default).write_buffer_size = max; + self.config + .get_or_insert_with(WebSocketConfig::default) + .write_buffer_size = max; self } @@ -159,7 +164,9 @@ impl WebSocketUpgrade { /// and probably a little more depending on error handling strategy. #[inline] pub fn max_write_buffer_size(mut self, max: usize) -> Self { - self.config.get_or_insert_with(WebSocketConfig::default).max_write_buffer_size = max; + self.config + .get_or_insert_with(WebSocketConfig::default) + .max_write_buffer_size = max; self } @@ -180,7 +187,9 @@ impl WebSocketUpgrade { /// by a malicious user. #[inline] pub fn max_frame_size(mut self, max: usize) -> Self { - self.config.get_or_insert_with(WebSocketConfig::default).max_frame_size = Some(max); + self.config + .get_or_insert_with(WebSocketConfig::default) + .max_frame_size = Some(max); self } @@ -191,13 +200,19 @@ impl WebSocketUpgrade { /// By default this option is set to `false`, i.e. according to RFC 6455. #[inline] pub fn accept_unmasked_frames(mut self, accept: bool) -> Self { - self.config.get_or_insert_with(WebSocketConfig::default).accept_unmasked_frames = accept; + self.config + .get_or_insert_with(WebSocketConfig::default) + .accept_unmasked_frames = accept; self } - /// Upgrade websocket request. - pub async fn upgrade(&self, req: &mut Request, res: &mut Response, callback: F) -> Result<(), StatusError> + pub async fn upgrade( + &self, + req: &mut Request, + res: &mut Response, + callback: F, + ) -> Result<(), StatusError> where F: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future + Send + 'static, @@ -218,7 +233,8 @@ impl WebSocketUpgrade { .unwrap_or(false); if !matched { tracing::debug!("missing upgrade header or it is not equal websocket"); - return Err(StatusError::bad_request().brief("Missing upgrade header or it is not equal websocket.")); + return Err(StatusError::bad_request() + .brief("Missing upgrade header or it is not equal websocket.")); } let matched = !req_headers .get(SEC_WEBSOCKET_VERSION) @@ -233,14 +249,16 @@ impl WebSocketUpgrade { key } else { tracing::debug!("sec_websocket_key is not exist in request headers"); - return Err(StatusError::bad_request().brief("sec_websocket_key is not exist in request headers.")); + return Err(StatusError::bad_request() + .brief("sec_websocket_key is not exist in request headers.")); }; res.status_code(StatusCode::SWITCHING_PROTOCOLS); res.headers_mut().typed_insert(Connection::upgrade()); res.headers_mut().typed_insert(Upgrade::websocket()); - res.headers_mut().typed_insert(SecWebsocketAccept::from(sec_ws_key)); + res.headers_mut() + .typed_insert(SecWebsocketAccept::from(sec_ws_key)); if let Some(on_upgrade) = req.extensions_mut().remove::() { let config = self.config; @@ -257,7 +275,8 @@ impl WebSocketUpgrade { Ok(()) } else { tracing::debug!("websocket couldn't be upgraded since no upgrade state was present"); - Err(StatusError::bad_request().brief("Websocket couldn't be upgraded since no upgrade state was present.")) + Err(StatusError::bad_request() + .brief("Websocket couldn't be upgraded since no upgrade state was present.")) } } } @@ -326,22 +345,30 @@ impl Sink for WebSocket { #[inline] fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::other) + Pin::new(&mut self.inner) + .poll_ready(cx) + .map_err(Error::other) } #[inline] fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - Pin::new(&mut self.inner).start_send(item.inner).map_err(Error::other) + Pin::new(&mut self.inner) + .start_send(item.inner) + .map_err(Error::other) } #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::other) + Pin::new(&mut self.inner) + .poll_flush(cx) + .map_err(Error::other) } #[inline] fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx).map_err(Error::other) + Pin::new(&mut self.inner) + .poll_close(cx) + .map_err(Error::other) } } @@ -364,7 +391,7 @@ pub struct Message { impl Message { /// Construct a new Text `Message`. #[inline] - pub fn text>(s: S) -> Message { + pub fn text>(s: S) -> Message { Message { inner: protocol::Message::text(s), } @@ -372,7 +399,7 @@ impl Message { /// Construct a new Binary `Message`. #[inline] - pub fn binary>>(v: V) -> Message { + pub fn binary>(v: V) -> Message { Message { inner: protocol::Message::binary(v), } @@ -380,7 +407,7 @@ impl Message { /// Construct a new Ping `Message`. #[inline] - pub fn ping>>(v: V) -> Message { + pub fn ping>(v: V) -> Message { Message { inner: protocol::Message::Ping(v.into()), } @@ -388,7 +415,7 @@ impl Message { /// Construct a new Pong `Message`. #[inline] - pub fn pong>>(v: V) -> Message { + pub fn pong>(v: V) -> Message { Message { inner: protocol::Message::Pong(v.into()), } @@ -455,9 +482,9 @@ impl Message { /// Try to get a reference to the string text, if this is a Text message. #[inline] - pub fn to_str(&self) -> Result<&str, Error> { - match self.inner { - protocol::Message::Text(ref s) => Ok(s), + pub fn as_str(&self) -> Result<&str, Error> { + match &self.inner { + protocol::Message::Text(s) => Ok(s.as_str()), _ => Err(Error::Other("not a text message".into())), } } @@ -465,21 +492,15 @@ impl Message { /// Returns the bytes of this message, if the message can contain data. #[inline] pub fn as_bytes(&self) -> &[u8] { - match self.inner { - protocol::Message::Text(ref s) => s.as_bytes(), - protocol::Message::Binary(ref v) => v, - protocol::Message::Ping(ref v) => v, - protocol::Message::Pong(ref v) => v, + match &self.inner { + protocol::Message::Text(s) => s.as_slice(), + protocol::Message::Binary(v) => v.as_slice(), + protocol::Message::Ping(v) => v.as_slice(), + protocol::Message::Pong(v) => v.as_slice(), protocol::Message::Close(_) => &[], - protocol::Message::Frame(ref v) => v.payload(), + protocol::Message::Frame(v) => v.payload(), } } - - /// Destructure this message into binary data. - #[inline] - pub fn into_bytes(self) -> Vec { - self.inner.into_data() - } } impl Debug for Message { @@ -493,7 +514,7 @@ impl Debug for Message { impl Into> for Message { #[inline] fn into(self) -> Vec { - self.into_bytes() + self.as_bytes().into() } } @@ -529,7 +550,11 @@ mod tests { async fn test_websocket() { let router = Router::new().goal(connect); let acceptor = TcpListener::new("127.0.0.1:0").bind().await; - let addr = acceptor.holdings()[0].local_addr.clone().into_std().unwrap(); + let addr = acceptor.holdings()[0] + .local_addr + .clone() + .into_std() + .unwrap(); tokio::spawn(async move { Server::new(acceptor).serve(router).await; @@ -537,7 +562,9 @@ mod tests { let stream = tokio::net::TcpStream::connect(addr).await.unwrap(); - let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream)).await.unwrap(); + let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(stream)) + .await + .unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/otel-jaeger/src/server1.rs b/examples/otel-jaeger/src/server1.rs index 7e3462298..1aab9f9ab 100644 --- a/examples/otel-jaeger/src/server1.rs +++ b/examples/otel-jaeger/src/server1.rs @@ -24,11 +24,10 @@ fn init_tracer_provider() -> TracerProvider { .expect("failed to create exporter"); TracerProvider::builder() .with_batch_exporter(exporter, runtime::Tokio) - .with_config( - opentelemetry_sdk::trace::Config::default().with_resource(Resource::new(vec![ - KeyValue::new("service.name", "server1"), - ])), - ) + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + "server1", + )])) .build() } diff --git a/examples/otel-jaeger/src/server2.rs b/examples/otel-jaeger/src/server2.rs index 4aaebaf53..7ca0e431d 100644 --- a/examples/otel-jaeger/src/server2.rs +++ b/examples/otel-jaeger/src/server2.rs @@ -15,11 +15,10 @@ fn init_tracer_provider() -> TracerProvider { .build() .expect("failed to create exporter"); opentelemetry_sdk::trace::TracerProvider::builder() - .with_config( - opentelemetry_sdk::trace::Config::default().with_resource(Resource::new(vec![ - KeyValue::new("service.name", "server2"), - ])), - ) + .with_resource(Resource::new(vec![KeyValue::new( + "service.name", + "server2", + )])) .with_batch_exporter(exporter, runtime::Tokio) .build() } diff --git a/examples/websocket-chat/src/main.rs b/examples/websocket-chat/src/main.rs index 04a736b1c..e21950e63 100644 --- a/examples/websocket-chat/src/main.rs +++ b/examples/websocket-chat/src/main.rs @@ -73,7 +73,7 @@ async fn handle_socket(ws: WebSocket) { tokio::task::spawn(fut); } async fn user_message(my_id: usize, msg: Message) { - let msg = if let Ok(s) = msg.to_str() { + let msg = if let Ok(s) = msg.as_str() { s } else { return;