Skip to content

Commit

Permalink
middleware - various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jondot committed Sep 26, 2024
1 parent 732a059 commit 24f9bb8
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 96 deletions.
20 changes: 11 additions & 9 deletions src/controller/middleware/catch_panic.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
//! Catch Panic Middleware for Axum
//!
//! This middleware catches panics that occur during request handling in the Axum application.
//! When a panic occurs, it logs the error and returns an internal server error response.
//! This middleware helps ensure that the application can gracefully handle unexpected errors
//! without crashing the server.
//! This middleware catches panics that occur during request handling in the
//! Axum application. When a panic occurs, it logs the error and returns an
//! internal server error response. This middleware helps ensure that the
//! application can gracefully handle unexpected errors without crashing the
//! server.
use axum::Router as AXRouter;
use serde::{Deserialize, Serialize};
use tower_http::catch_panic::CatchPanicLayer;

use crate::{
app::AppContext,
controller::{middleware::MiddlewareLayer, IntoResponse},
errors, Result,
};
use axum::Router as AXRouter;
use serde::{Deserialize, Serialize};
use tower_http::catch_panic::CatchPanicLayer;

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CatchPanic {
Expand Down Expand Up @@ -42,7 +44,7 @@ fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response:
impl MiddlewareLayer for CatchPanic {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"catch panic"
"catch_panic"
}

/// Returns whether the middleware is enabled or not
Expand All @@ -59,7 +61,6 @@ impl MiddlewareLayer for CatchPanic {
#[cfg(test)]
mod tests {

use crate::tests_cfg;
use axum::{
body::Body,
http::{Method, Request, StatusCode},
Expand All @@ -69,6 +70,7 @@ mod tests {
use tower::ServiceExt;

use super::*;
use crate::tests_cfg;

#[tokio::test]
async fn panic_enabled() {
Expand Down
9 changes: 5 additions & 4 deletions src/controller/middleware/fallback.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
//! Fallback Middleware
//!
//! This middleware handles fallback logic for the application when routes do not match. It serves
//! a file, a custom not-found message, or a default HTML fallback page based on the configuration.
//! This middleware handles fallback logic for the application when routes do
//! not match. It serves a file, a custom not-found message, or a default HTML
//! fallback page based on the configuration.
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};
use axum::{http::StatusCode, response::Html, Router as AXRouter};
use serde::{Deserialize, Serialize};
use tower_http::services::ServeFile;

use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};

#[derive(Default, Debug, Clone, Deserialize, Serialize)]
pub struct Fallback {
/// By default when enabled, returns a prebaked 404 not found page optimized
Expand Down Expand Up @@ -58,7 +60,6 @@ impl MiddlewareLayer for Fallback {
let content = include_str!("fallback.html");
app.fallback(move || async move { (code, Html(content)) })
};
tracing::info!("[Middleware] +fallback");
Ok(app)
}
}
22 changes: 13 additions & 9 deletions src/controller/middleware/limit_payload.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
//! Limit Payload Middleware
//!
//! This middleware restricts the maximum allowed size for HTTP request payloads. It is configurable
//! based on the [`LimitPayloadMiddleware`] settings in the application's middleware configuration.
//! The middleware sets a limit on the request body size using Axum's `DefaultBodyLimit` layer.
//! This middleware restricts the maximum allowed size for HTTP request
//! payloads. It is configurable based on the [`LimitPayloadMiddleware`]
//! settings in the application's middleware configuration. The middleware sets
//! a limit on the request body size using Axum's `DefaultBodyLimit` layer.
//!
//! # Note
//!
//! Ensure that the `body: axum::body::Bytes` variable is properly set in the request action to
//! enforce the payload limit correctly. Without this, the middleware will not function as intended.
//! Ensure that the `body: axum::body::Bytes` variable is properly set in the
//! request action to enforce the payload limit correctly. Without this, the
//! middleware will not function as intended.
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
use axum::Router as AXRouter;
use serde::{Deserialize, Deserializer, Serialize};

use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LimitPayload {
pub enable: bool,
Expand All @@ -25,7 +28,7 @@ impl Default for LimitPayload {
fn default() -> Self {
Self {
enable: true,
body_limit: 1024,
body_limit: 2_000_000,
}
}
}
Expand All @@ -44,15 +47,16 @@ where
impl MiddlewareLayer for LimitPayload {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"limit payload"
"limit_payload"
}

/// Returns whether the middleware is enabled or not
fn is_enabled(&self) -> bool {
self.enable
}

/// Applies the payload limit middleware to the application router by adding a `DefaultBodyLimit` layer.
/// Applies the payload limit middleware to the application router by adding
/// a `DefaultBodyLimit` layer.
fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
Ok(app.layer(axum::extract::DefaultBodyLimit::max(self.body_limit)))
}
Expand Down
49 changes: 21 additions & 28 deletions src/controller/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Base Middleware for Loco Application
//!
//! This module defines the various middleware components that Loco provides.
//! Each middleware is responsible for handling different aspects of request processing, such as
//! authentication, logging, CORS, compression, and error handling. The middleware can be easily
//! configured and applied to the application's router.
//! Each middleware is responsible for handling different aspects of request
//! processing, such as authentication, logging, CORS, compression, and error
//! handling. The middleware can be easily configured and applied to the
//! application's router.
#[cfg(all(feature = "auth_jwt", feature = "with-db"))]
pub mod auth;
Expand All @@ -21,10 +22,11 @@ pub mod request_id;
pub mod secure_headers;
pub mod static_assets;
pub mod timeout;
use crate::{app::AppContext, Result};
use axum::Router as AXRouter;
use serde::{Deserialize, Serialize};

use crate::{app::AppContext, Result};

/// Trait representing the behavior of middleware components in the application.
pub trait MiddlewareLayer {
/// Returns the name of the middleware.
Expand All @@ -35,18 +37,21 @@ pub trait MiddlewareLayer {
true
}

/// Applies the middleware to the given Axum router and returns the modified router.
/// Applies the middleware to the given Axum router and returns the modified
/// router.
///
/// # Errors
///
/// If there is an issue when adding the middleware to the router.
fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>>;
}

/// Constructs a default stack of middleware for the Axum application based on the provided context.
/// Constructs a default stack of middleware for the Axum application based on
/// the provided context.
///
/// This function initializes and returns a vector of middleware components that are commonly used
/// in the application. Each middleware is created using its respective `new` function and
/// This function initializes and returns a vector of middleware components that
/// are commonly used in the application. Each middleware is created using its
/// respective `new` function and
#[must_use]
pub fn default_middleware_stack(ctx: &AppContext) -> Vec<Box<dyn MiddlewareLayer>> {
vec![
Expand All @@ -72,41 +77,29 @@ pub fn default_middleware_stack(ctx: &AppContext) -> Vec<Box<dyn MiddlewareLayer
/// Server middleware configuration structure.
#[derive(Default, Debug, Clone, Deserialize, Serialize)]
pub struct Config {
/// Middleware that enable compression for the response.
#[serde(default)]
/// Compression for the response.
pub compression: compression::Compression,
/// Middleware that enable etag cache headers.
#[serde(default)]
/// Etag cache headers.
pub etag: etag::Etag,
/// Middleware that limit the payload request.
#[serde(default)]
/// Limit the payload request.
pub limit_payload: limit_payload::LimitPayload,
/// Middleware that improve the tracing logger and adding trace id for each
/// request.
#[serde(default)]
/// Logger and augmenting trace id with request data
pub logger: logger::Config,
/// catch any code panic and log the error.
#[serde(default)]
/// Catch any code panic and log the error.
pub catch_panic: catch_panic::CatchPanic,
/// Setting a global timeout for the requests
#[serde(default)]
/// Setting a global timeout for requests
pub timeout_request: timeout::TimeOut,
/// Setting cors configuration
#[serde(default)]
/// CORS configuration
pub cors: cors::Cors,
/// Serving static assets
#[serde(rename = "static")]
#[serde(default)]
pub static_assets: static_assets::StaticAssets,
/// Sets a set of secure headers
#[serde(default)]
pub secure_headers: secure_headers::SecureHeader,
/// Calculates a remote IP based on `X-Forwarded-For` when behind a proxy
#[serde(default)]
pub remote_ip: remote_ip::RemoteIpMiddleware,
/// Configure fallback behavior when hitting a missing URL
#[serde(default)]
pub fallback: fallback::Fallback,
#[serde(default)]
/// Request ID
pub request_id: request_id::RequestId,
}
34 changes: 19 additions & 15 deletions src/controller/middleware/remote_ip.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
//! Remote IP Middleware for inferring the client's IP address based on the `X-Forwarded-For` header.
//! Remote IP Middleware for inferring the client's IP address based on the
//! `X-Forwarded-For` header.
//!
//! This middleware is useful when running behind proxies or load balancers that add the
//! `X-Forwarded-For` header, which includes the original client IP address.
//! This middleware is useful when running behind proxies or load balancers that
//! add the `X-Forwarded-For` header, which includes the original client IP
//! address.
//!
//! The middleware provides a mechanism to configure trusted proxies and extract the most likely
//! client IP from the `X-Forwarded-For` header, skipping any trusted proxy IPs.
//!
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};
//! The middleware provides a mechanism to configure trusted proxies and extract
//! the most likely client IP from the `X-Forwarded-For` header, skipping any
//! trusted proxy IPs.
use std::{
fmt,
iter::Iterator,
net::{IpAddr, SocketAddr},
str::FromStr,
task::{Context, Poll},
};

use async_trait::async_trait;
use axum::{
body::Body,
Expand All @@ -20,16 +29,11 @@ use hyper::HeaderMap;
use ipnetwork::IpNetwork;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::{
fmt,
iter::Iterator,
net::{IpAddr, SocketAddr},
str::FromStr,
task::{Context, Poll},
};
use tower::{Layer, Service};
use tracing::error;

use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};

lazy_static! {
// matching what Rails does is probably a smart idea:
// https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/remote_ip.rb#L40
Expand Down Expand Up @@ -98,7 +102,7 @@ pub struct RemoteIpMiddleware {
impl MiddlewareLayer for RemoteIpMiddleware {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"remote IP"
"remote_ip"
}

/// Returns whether the middleware is enabled or not
Expand Down
20 changes: 12 additions & 8 deletions src/controller/middleware/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
//! The request ID is stored in the `x-request-id` header, and it is either
//! generated or sanitized if already present in the request.
//!
//! This can be useful for tracking requests across services, logging, and debugging.
//! This can be useful for tracking requests across services, logging, and
//! debugging.
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
use axum::{
extract::Request, http::HeaderValue, middleware::Next, response::Response, Router as AXRouter,
};
Expand All @@ -13,6 +13,8 @@ use regex::Regex;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};

const X_REQUEST_ID: &str = "x-request-id";
const MAX_LEN: usize = 255;

Expand All @@ -34,7 +36,7 @@ impl Default for RequestId {
impl MiddlewareLayer for RequestId {
/// Returns the name of the middleware
fn name(&self) -> &'static str {
"request id"
"request_id"
}

/// Returns whether the middleware is enabled or not
Expand All @@ -44,8 +46,9 @@ impl MiddlewareLayer for RequestId {

/// Applies the request ID middleware to the Axum router.
///
/// This function sets up the middleware in the router and ensures that every
/// request passing through it will have a unique or sanitized request ID.
/// This function sets up the middleware in the router and ensures that
/// every request passing through it will have a unique or sanitized
/// request ID.
///
/// # Errors
/// This function returns an error if the middleware cannot be applied.
Expand All @@ -68,9 +71,10 @@ impl LocoRequestId {

/// Middleware function to ensure or generate a unique request ID.
///
/// This function intercepts requests, checks for the presence of the `x-request-id`
/// header, and either sanitizes its value or generates a new UUID if absent.
/// The resulting request ID is added to both the request extensions and the response headers.
/// This function intercepts requests, checks for the presence of the
/// `x-request-id` header, and either sanitizes its value or generates a new
/// UUID if absent. The resulting request ID is added to both the request
/// extensions and the response headers.
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
let header_request_id = request.headers().get(X_REQUEST_ID).cloned();
let request_id = make_request_id(header_request_id);
Expand Down
Loading

0 comments on commit 24f9bb8

Please sign in to comment.