//! Per-IP rate limiting with token bucket algorithm.
//!
//! Limits the number of inbound messages processed per
//! source IP address per second. Prevents a single
//! source from overwhelming the node.

use std::collections::HashMap;
use std::net::IpAddr;
use std::time::Instant;

/// Default rate: 50 messages per second per IP.
pub const DEFAULT_RATE: f64 = 50.0;

/// Default burst: 100 messages.
pub const DEFAULT_BURST: u32 = 100;

/// Seconds of inactivity after which a bucket becomes eligible for
/// eviction. Also the minimum interval between two cleanup passes.
const STALE_SECS: u64 = 60;

/// Token-bucket state for a single source IP.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; each allowed message consumes one.
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

/// Per-IP token bucket rate limiter.
///
/// Each source IP gets its own bucket holding at most `burst` tokens,
/// refilled continuously at `rate` tokens per second. A message is allowed
/// when its bucket holds at least one token, which it then consumes.
#[derive(Debug)]
pub struct RateLimiter {
    buckets: HashMap<IpAddr, Bucket>,
    rate: f64,
    burst: u32,
    last_cleanup: Instant,
}

impl RateLimiter {
    /// Create a new rate limiter.
    ///
    /// - `rate`: tokens added per second per IP. Must be finite and
    ///   non-negative (checked in debug builds); a NaN rate would silently
    ///   disable limiting.
    /// - `burst`: maximum tokens (burst capacity). A newly seen IP starts
    ///   with a full bucket.
    pub fn new(rate: f64, burst: u32) -> Self {
        debug_assert!(
            rate.is_finite() && rate >= 0.0,
            "rate must be finite and non-negative, got {rate}"
        );
        Self {
            buckets: HashMap::new(),
            rate,
            burst,
            last_cleanup: Instant::now(),
        }
    }

    /// Check if a message from `ip` should be allowed.
    ///
    /// Returns `true` if allowed (token consumed),
    /// `false` if rate-limited (drop the message).
    ///
    /// Also invokes [`cleanup`](Self::cleanup). That call returns
    /// immediately unless `STALE_SECS` seconds have passed since the last
    /// pass, so only an occasional call scans the tracked IPs.
    #[must_use]
    pub fn allow(&mut self, ip: IpAddr) -> bool {
        // Without this, a flood of distinct (possibly spoofed) source IPs
        // grows `buckets` without bound unless every caller remembers to
        // call `cleanup` on its own.
        self.cleanup();

        let now = Instant::now();
        let burst = f64::from(self.burst);
        let rate = self.rate;
        let bucket = self.buckets.entry(ip).or_insert_with(|| Bucket {
            tokens: burst,
            last_refill: now,
        });

        // Credit tokens for the time since the last update, capped at burst.
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * rate).min(burst);
        bucket.last_refill = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Remove stale buckets.
    ///
    /// Does nothing if the previous pass ran less than `STALE_SECS` seconds
    /// ago. Otherwise it evicts every bucket that meets both conditions:
    /// - it has been idle for at least `STALE_SECS` seconds, and
    /// - it would already have refilled to `burst`.
    ///
    /// The second condition keeps eviction unobservable. A re-created bucket
    /// starts full, so evicting one that is still refilling would hand that
    /// IP its whole burst back early. This can happen whenever
    /// `rate * STALE_SECS < burst`.
    pub fn cleanup(&mut self) {
        if self.last_cleanup.elapsed().as_secs() < STALE_SECS {
            return;
        }
        let now = Instant::now();
        self.last_cleanup = now;

        let rate = self.rate;
        let burst = f64::from(self.burst);
        self.buckets.retain(|_, b| {
            let idle = now.duration_since(b.last_refill);
            let is_stale = idle.as_secs() >= STALE_SECS;
            let is_full = b.tokens + idle.as_secs_f64() * rate >= burst;
            !(is_stale && is_full)
        });
    }

    /// Number of tracked IPs.
    pub fn tracked_count(&self) -> usize {
        self.buckets.len()
    }
}

impl Default for RateLimiter {
    /// Limiter using [`DEFAULT_RATE`] and [`DEFAULT_BURST`].
    fn default() -> Self {
        Self::new(DEFAULT_RATE, DEFAULT_BURST)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn allow_within_burst() {
        let mut rl = RateLimiter::new(10.0, 5);
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        // First 5 should be allowed (burst)
        for _ in 0..5 {
            assert!(rl.allow(ip));
        }
        // 6th should be denied
        assert!(!rl.allow(ip));
    }

    #[test]
    fn different_ips_independent() {
        let mut rl = RateLimiter::new(1.0, 1);
        let ip1: IpAddr = "1.2.3.4".parse().unwrap();
        let ip2: IpAddr = "5.6.7.8".parse().unwrap();
        assert!(rl.allow(ip1));
        assert!(rl.allow(ip2));
        // Both exhausted
        assert!(!rl.allow(ip1));
        assert!(!rl.allow(ip2));
    }

    #[test]
    fn tracked_count() {
        let mut rl = RateLimiter::default();
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        assert!(rl.allow(ip));
        assert_eq!(rl.tracked_count(), 1);
    }

    #[test]
    fn cleanup_keeps_buckets_that_are_still_refilling() {
        // Skip on platforms whose monotonic clock cannot go back 120s.
        let past = match Instant::now().checked_sub(Duration::from_secs(120)) {
            Some(t) => t,
            None => return,
        };
        let mut rl = RateLimiter::new(0.1, 100);
        let drained: IpAddr = "1.2.3.4".parse().unwrap();
        let idle: IpAddr = "5.6.7.8".parse().unwrap();
        rl.buckets.insert(drained, Bucket { tokens: 0.0, last_refill: past });
        rl.buckets.insert(idle, Bucket { tokens: 100.0, last_refill: past });
        rl.last_cleanup = past;

        rl.cleanup();

        assert_eq!(rl.tracked_count(), 1);
        assert!(rl.buckets.contains_key(&drained));
    }
}