diff options
Diffstat (limited to 'src/banlist.rs')
| -rw-r--r-- | src/banlist.rs | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/src/banlist.rs b/src/banlist.rs new file mode 100644 index 0000000..f01e31a --- /dev/null +++ b/src/banlist.rs @@ -0,0 +1,207 @@ +//! Ban list for misbehaving peers. +//! +//! Tracks failure counts per peer. After exceeding a +//! threshold, the peer is temporarily banned. Bans +//! expire automatically after a configurable duration. + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +/// Default number of failures before banning a peer. +const DEFAULT_BAN_THRESHOLD: u32 = 3; + +/// Default ban duration (3 hours, matching gonode). +const DEFAULT_BAN_DURATION: Duration = Duration::from_secs(3 * 3600); + +/// Tracks failures and bans for peers. +pub struct BanList { + /// Failure counts per address. + failures: HashMap<SocketAddr, FailureEntry>, + /// Active bans: address → expiry time. + bans: HashMap<SocketAddr, Instant>, + /// Number of failures before a ban is applied. + threshold: u32, + /// How long a ban lasts. + ban_duration: Duration, +} + +struct FailureEntry { + count: u32, + last_failure: Instant, +} + +impl BanList { + pub fn new() -> Self { + Self { + failures: HashMap::new(), + bans: HashMap::new(), + threshold: DEFAULT_BAN_THRESHOLD, + ban_duration: DEFAULT_BAN_DURATION, + } + } + + /// Set the failure threshold before banning. + pub fn set_threshold(&mut self, threshold: u32) { + self.threshold = threshold; + } + + /// Set the ban duration. + pub fn set_ban_duration(&mut self, duration: Duration) { + self.ban_duration = duration; + } + + /// Check if a peer is currently banned. + pub fn is_banned(&self, addr: &SocketAddr) -> bool { + if let Some(expiry) = self.bans.get(addr) { + Instant::now() < *expiry + } else { + false + } + } + + /// Record a failure for a peer. Returns true if the + /// peer was just banned (crossed the threshold). + pub fn record_failure(&mut self, addr: SocketAddr) -> bool { + let entry = self.failures.entry(addr).or_insert(FailureEntry { + count: 0, + last_failure: Instant::now(), + }); + entry.count += 1; + entry.last_failure = Instant::now(); + + if entry.count >= self.threshold { + let expiry = Instant::now() + self.ban_duration; + self.bans.insert(addr, expiry); + self.failures.remove(&addr); + log::info!( + "Banned peer {addr} for {}s", + self.ban_duration.as_secs() + ); + true + } else { + false + } + } + + /// Clear failure count for a peer (e.g. after a + /// successful interaction). + pub fn record_success(&mut self, addr: &SocketAddr) { + self.failures.remove(addr); + } + + /// Remove expired bans and stale failure entries. + /// Call periodically from the event loop. + pub fn cleanup(&mut self) { + let now = Instant::now(); + self.bans.retain(|_, expiry| now < *expiry); + + // Clear failure entries older than ban_duration + // (stale failures shouldn't accumulate forever) + self.failures + .retain(|_, e| e.last_failure.elapsed() < self.ban_duration); + } + + /// Number of currently active bans. + pub fn ban_count(&self) -> usize { + self.bans + .iter() + .filter(|(_, e)| Instant::now() < **e) + .count() + } + + /// Number of peers with recorded failures. + pub fn failure_count(&self) -> usize { + self.failures.len() + } +} + +impl Default for BanList { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn addr(port: u16) -> SocketAddr { + SocketAddr::from(([127, 0, 0, 1], port)) + } + + #[test] + fn not_banned_initially() { + let bl = BanList::new(); + assert!(!bl.is_banned(&addr(1000))); + } + + #[test] + fn ban_after_threshold() { + let mut bl = BanList::new(); + let a = addr(1000); + assert!(!bl.record_failure(a)); + assert!(!bl.record_failure(a)); + assert!(bl.record_failure(a)); // 3rd failure → banned + assert!(bl.is_banned(&a)); + } + + #[test] + fn success_clears_failures() { + let mut bl = BanList::new(); + let a = addr(1000); + bl.record_failure(a); + bl.record_failure(a); + bl.record_success(&a); + // Failures cleared, next failure starts over + assert!(!bl.record_failure(a)); + assert!(!bl.is_banned(&a)); + } + + #[test] + fn ban_expires() { + let mut bl = BanList::new(); + bl.set_ban_duration(Duration::from_millis(1)); + let a = addr(1000); + bl.record_failure(a); + bl.record_failure(a); + bl.record_failure(a); + assert!(bl.is_banned(&a)); + std::thread::sleep(Duration::from_millis(5)); + assert!(!bl.is_banned(&a)); + } + + #[test] + fn cleanup_removes_expired() { + let mut bl = BanList::new(); + bl.set_ban_duration(Duration::from_millis(1)); + let a = addr(1000); + bl.record_failure(a); + bl.record_failure(a); + bl.record_failure(a); + std::thread::sleep(Duration::from_millis(5)); + bl.cleanup(); + assert_eq!(bl.ban_count(), 0); + } + + #[test] + fn custom_threshold() { + let mut bl = BanList::new(); + bl.set_threshold(1); + let a = addr(1000); + assert!(bl.record_failure(a)); // 1 failure → banned + assert!(bl.is_banned(&a)); + } + + #[test] + fn independent_peers() { + let mut bl = BanList::new(); + let a = addr(1000); + let b = addr(2000); + bl.record_failure(a); + bl.record_failure(a); + bl.record_failure(a); + assert!(bl.is_banned(&a)); + assert!(!bl.is_banned(&b)); + } +} |