From fea27de25e01db81c7b67fe9d6da331ec2474b21 Mon Sep 17 00:00:00 2001 From: Luis Herasme Date: Wed, 24 Jul 2024 18:21:05 -0400 Subject: [PATCH] feat: Add rate limit layer --- ghost-crab/src/cache/cache_layer.rs | 2 +- ghost-crab/src/cache/manager.rs | 13 +++- ghost-crab/src/lib.rs | 4 +- ghost-crab/src/rate_limit.rs | 112 ++++++++++++++++++++++++++++ 4 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 ghost-crab/src/rate_limit.rs diff --git a/ghost-crab/src/cache/cache_layer.rs b/ghost-crab/src/cache/cache_layer.rs index 8a71f5d..e0649ac 100644 --- a/ghost-crab/src/cache/cache_layer.rs +++ b/ghost-crab/src/cache/cache_layer.rs @@ -76,7 +76,7 @@ fn contains_invalid_word(input: &[u8]) -> bool { } fn cacheable_request(request: &SerializedRequest) -> bool { - if !matches!(request.method(), "eth_getBlockByNumber") { + if !matches!(request.method(), "eth_getBlockByNumber" | "eth_getLogs" | "eth_call") { return false; } diff --git a/ghost-crab/src/cache/manager.rs b/ghost-crab/src/cache/manager.rs index 7d7ad76..45b934a 100644 --- a/ghost-crab/src/cache/manager.rs +++ b/ghost-crab/src/cache/manager.rs @@ -3,12 +3,16 @@ use alloy::providers::RootProvider; use alloy::rpc::client::ClientBuilder; use alloy::transports::http::{Client, Http}; use std::collections::HashMap; +use std::time::Duration; + +use crate::rate_limit::RateLimit; +use crate::rate_limit::RateLimitLayer; use super::cache::load_cache; use super::cache_layer::CacheLayer; use super::cache_layer::CacheService; -pub type CacheProvider = RootProvider>>; +pub type CacheProvider = RootProvider>>>; pub struct RPCManager { rpcs: HashMap, @@ -26,8 +30,13 @@ impl RPCManager { let cache = load_cache(&network).unwrap(); let cache_layer = CacheLayer::new(cache); + let rate_limit_layer = RateLimitLayer::new(10_000, Duration::from_secs(1)); + + let client = ClientBuilder::default() + .layer(cache_layer) + .layer(rate_limit_layer) + .http(rpc_url.parse().unwrap()); - let client = ClientBuilder::default().layer(cache_layer).http(rpc_url.parse().unwrap()); let provider = ProviderBuilder::new().on_client(client); self.rpcs.insert(rpc_url.clone(), provider.clone()); diff --git a/ghost-crab/src/lib.rs b/ghost-crab/src/lib.rs index e9e7fc0..66e79dd 100644 --- a/ghost-crab/src/lib.rs +++ b/ghost-crab/src/lib.rs @@ -2,8 +2,10 @@ pub mod block_handler; pub mod cache; pub mod event_handler; pub mod indexer; -pub mod latest_block_manager; pub mod prelude; pub use ghost_crab_common::config; pub use indexer::Indexer; + +mod latest_block_manager; +mod rate_limit; diff --git a/ghost-crab/src/rate_limit.rs b/ghost-crab/src/rate_limit.rs new file mode 100644 index 0000000..cc16db2 --- /dev/null +++ b/ghost-crab/src/rate_limit.rs @@ -0,0 +1,112 @@ +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::time::{Instant, Sleep}; +use tower::Layer; +use tower::Service; + +#[derive(Debug, Copy, Clone)] +pub struct Rate { + limit: u64, + period: Duration, +} + +/// Enforces a rate limit on the number of requests the underlying +/// service can handle over a period of time. +#[derive(Debug, Clone)] +pub struct RateLimitLayer { + rate: Rate, +} + +impl RateLimitLayer { + /// Create new rate limit layer. + pub fn new(limit: u64, period: Duration) -> Self { + let rate = Rate { limit, period }; + RateLimitLayer { rate } + } +} + +impl Layer for RateLimitLayer { + type Service = RateLimit; + + fn layer(&self, service: S) -> Self::Service { + RateLimit::new(service, self.rate) + } +} + +/// Enforces a rate limit on the number of requests the underlying +/// service can handle over a period of time. +#[derive(Debug, Clone)] +pub struct RateLimit { + inner: T, + rate: Rate, + state: Arc>, +} + +#[derive(Debug)] +struct State { + until: Instant, + reserved: u64, + timer: Pin>, +} + +impl RateLimit { + /// Create a new rate limiter + pub fn new(inner: T, rate: Rate) -> Self { + let until = Instant::now() + rate.period; + + let state = Arc::new(Mutex::new(State { + until, + reserved: rate.limit, + timer: Box::pin(tokio::time::sleep_until(until)), + })); + + RateLimit { inner, rate, state } + } +} + +impl Service for RateLimit +where + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll> { + let now = Instant::now(); + let mut state = self.state.lock().unwrap(); + + if now >= state.until { + state.until = now + self.rate.period; + state.reserved = 0; + state.timer.as_mut().reset(now + self.rate.period); + } + + if state.reserved >= self.rate.limit { + ctx.waker().wake_by_ref(); + let _ = state.timer.as_mut().poll(ctx); + return Poll::Pending; + } + + match self.inner.poll_ready(ctx) { + Poll::Ready(value) => { + state.reserved += 1; + Poll::Ready(value) + } + Poll::Pending => { + ctx.waker().wake_by_ref(); + let _ = state.timer.as_mut().poll(ctx); + Poll::Pending + } + } + } + + fn call(&mut self, request: Request) -> Self::Future { + self.inner.call(request) + } +}