1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
|
//! Per-IP rate limiting using a token bucket algorithm.
//!
//! Limits the number of inbound messages processed per
//! source IP address per second. Prevents a single
//! source from overwhelming the node.
use std::collections::HashMap;
use std::net::IpAddr;
use std::time::Instant;
/// Default rate: 50 messages per second per IP.
pub const DEFAULT_RATE: f64 = 50.0;

/// Default burst: 100 messages.
pub const DEFAULT_BURST: u32 = 100;

/// Idle time, in seconds, after which a bucket becomes eligible for eviction.
///
/// Also used as the minimum interval between two cleanup sweeps.
const STALE_SECS: u64 = 60;

/// Token bucket state tracked for a single source IP.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; fractional while the bucket refills.
    tokens: f64,
    /// Instant of the most recent `allow` call for this IP (allowed or not).
    last_refill: Instant,
}

/// Per-IP token bucket rate limiter.
///
/// Each source IP gets its own bucket that starts full (`burst` tokens) and
/// refills at `rate` tokens per second. Buckets are only removed by
/// [`RateLimiter::cleanup`]; `allow` never evicts, so callers are expected to
/// invoke `cleanup` periodically to bound memory use.
#[derive(Debug)]
pub struct RateLimiter {
    /// One bucket per source IP seen so far.
    buckets: HashMap<IpAddr, Bucket>,
    /// Refill rate in tokens per second (never NaN or negative).
    rate: f64,
    /// Bucket capacity, and the initial token count for a new IP.
    burst: u32,
    /// Instant of the last cleanup sweep that actually ran.
    last_cleanup: Instant,
}

impl RateLimiter {
    /// Create a new rate limiter.
    ///
    /// - `rate`: tokens added per second per IP. A NaN or negative value is
    ///   treated as `0.0` (no refill), so a bad configuration fails closed
    ///   instead of silently disabling the limit.
    /// - `burst`: maximum tokens (burst capacity).
    pub fn new(rate: f64, burst: u32) -> Self {
        Self {
            buckets: HashMap::new(),
            // `f64::max` returns the non-NaN operand, so this maps both NaN
            // and negative rates to 0.0 while leaving valid rates untouched.
            rate: rate.max(0.0),
            burst,
            last_cleanup: Instant::now(),
        }
    }

    /// Check if a message from `ip` should be allowed.
    ///
    /// Returns `true` if allowed (token consumed),
    /// `false` if rate-limited (drop the message).
    pub fn allow(&mut self, ip: IpAddr) -> bool {
        let now = Instant::now();
        let burst = f64::from(self.burst);
        let rate = self.rate;

        // An IP seen for the first time starts with a full bucket.
        let bucket = self.buckets.entry(ip).or_insert(Bucket {
            tokens: burst,
            last_refill: now,
        });

        // Refill tokens based on elapsed time, capped at the burst capacity.
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * rate).min(burst);
        bucket.last_refill = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Remove buckets that are idle and carry no rate-limiting state.
    ///
    /// The sweep runs at most once every `STALE_SECS` seconds; earlier calls
    /// return immediately. A bucket is evicted only if it has seen no `allow`
    /// call for at least `STALE_SECS` seconds *and* would have refilled to
    /// full capacity by now. Evicting a partially drained bucket would hand
    /// that IP a fresh full burst on its next message, bypassing the limit
    /// whenever `burst / rate` exceeds the stale threshold.
    ///
    /// Consequence: with a refill rate of `0.0`, a bucket that has consumed
    /// any token never refills and is therefore never evicted.
    pub fn cleanup(&mut self) {
        let now = Instant::now();
        if now.duration_since(self.last_cleanup).as_secs() < STALE_SECS {
            return;
        }
        self.last_cleanup = now;

        let burst = f64::from(self.burst);
        let rate = self.rate;
        self.buckets.retain(|_, bucket| {
            let idle = now.duration_since(bucket.last_refill);
            let recently_active = idle.as_secs() < STALE_SECS;
            // Token count the bucket would hold if it were refilled right now.
            let still_drained = bucket.tokens + idle.as_secs_f64() * rate < burst;
            recently_active || still_drained
        });
    }

    /// Number of tracked IPs.
    pub fn tracked_count(&self) -> usize {
        self.buckets.len()
    }
}

impl Default for RateLimiter {
    /// Build a limiter with [`DEFAULT_RATE`] and [`DEFAULT_BURST`].
    fn default() -> Self {
        Self::new(DEFAULT_RATE, DEFAULT_BURST)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a literal IP address used as a test fixture.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// A fresh bucket admits exactly `burst` back-to-back messages.
    #[test]
    fn allow_within_burst() {
        let mut limiter = RateLimiter::new(10.0, 5);
        let source = addr("1.2.3.4");

        // Every one of the first 5 messages must pass (burst capacity).
        let admitted = (0..5).filter(|_| limiter.allow(source)).count();
        assert_eq!(admitted, 5);

        // The bucket is now empty, so the 6th message is dropped.
        assert!(!limiter.allow(source));
    }

    /// Draining one IP's bucket does not affect another IP.
    #[test]
    fn different_ips_independent() {
        let mut limiter = RateLimiter::new(1.0, 1);
        let first = addr("1.2.3.4");
        let second = addr("5.6.7.8");

        // Each IP owns a separate single-token bucket.
        assert!(limiter.allow(first));
        assert!(limiter.allow(second));

        // Both buckets are exhausted now.
        assert!(!limiter.allow(first));
        assert!(!limiter.allow(second));
    }

    /// The first message from an IP creates exactly one tracked bucket.
    #[test]
    fn tracked_count() {
        let mut limiter = RateLimiter::default();
        limiter.allow(addr("1.2.3.4"));
        assert_eq!(limiter.tracked_count(), 1);
    }
}
|