aboutsummaryrefslogtreecommitdiffstats
path: root/src/banlist.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/banlist.rs')
-rw-r--r--src/banlist.rs207
1 files changed, 207 insertions, 0 deletions
diff --git a/src/banlist.rs b/src/banlist.rs
new file mode 100644
index 0000000..f01e31a
--- /dev/null
+++ b/src/banlist.rs
@@ -0,0 +1,207 @@
+//! Ban list for misbehaving peers.
+//!
+//! Tracks failure counts per peer. After exceeding a
+//! threshold, the peer is temporarily banned. Bans
+//! expire automatically after a configurable duration.
+
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::time::{Duration, Instant};
+
+/// Default number of failures before banning a peer.
+const DEFAULT_BAN_THRESHOLD: u32 = 3;
+
+/// Default ban duration (3 hours, matching gonode).
+const DEFAULT_BAN_DURATION: Duration = Duration::from_secs(3 * 3600);
+
+/// Tracks failures and bans for peers.
+pub struct BanList {
+ /// Failure counts per address.
+ failures: HashMap<SocketAddr, FailureEntry>,
+ /// Active bans: address → expiry time.
+ bans: HashMap<SocketAddr, Instant>,
+ /// Number of failures before a ban is applied.
+ threshold: u32,
+ /// How long a ban lasts.
+ ban_duration: Duration,
+}
+
+struct FailureEntry {
+ count: u32,
+ last_failure: Instant,
+}
+
+impl BanList {
+ pub fn new() -> Self {
+ Self {
+ failures: HashMap::new(),
+ bans: HashMap::new(),
+ threshold: DEFAULT_BAN_THRESHOLD,
+ ban_duration: DEFAULT_BAN_DURATION,
+ }
+ }
+
+ /// Set the failure threshold before banning.
+ pub fn set_threshold(&mut self, threshold: u32) {
+ self.threshold = threshold;
+ }
+
+ /// Set the ban duration.
+ pub fn set_ban_duration(&mut self, duration: Duration) {
+ self.ban_duration = duration;
+ }
+
+ /// Check if a peer is currently banned.
+ pub fn is_banned(&self, addr: &SocketAddr) -> bool {
+ if let Some(expiry) = self.bans.get(addr) {
+ Instant::now() < *expiry
+ } else {
+ false
+ }
+ }
+
+ /// Record a failure for a peer. Returns true if the
+ /// peer was just banned (crossed the threshold).
+ pub fn record_failure(&mut self, addr: SocketAddr) -> bool {
+ let entry = self.failures.entry(addr).or_insert(FailureEntry {
+ count: 0,
+ last_failure: Instant::now(),
+ });
+ entry.count += 1;
+ entry.last_failure = Instant::now();
+
+ if entry.count >= self.threshold {
+ let expiry = Instant::now() + self.ban_duration;
+ self.bans.insert(addr, expiry);
+ self.failures.remove(&addr);
+ log::info!(
+ "Banned peer {addr} for {}s",
+ self.ban_duration.as_secs()
+ );
+ true
+ } else {
+ false
+ }
+ }
+
+ /// Clear failure count for a peer (e.g. after a
+ /// successful interaction).
+ pub fn record_success(&mut self, addr: &SocketAddr) {
+ self.failures.remove(addr);
+ }
+
+ /// Remove expired bans and stale failure entries.
+ /// Call periodically from the event loop.
+ pub fn cleanup(&mut self) {
+ let now = Instant::now();
+ self.bans.retain(|_, expiry| now < *expiry);
+
+ // Clear failure entries older than ban_duration
+ // (stale failures shouldn't accumulate forever)
+ self.failures
+ .retain(|_, e| e.last_failure.elapsed() < self.ban_duration);
+ }
+
+ /// Number of currently active bans.
+ pub fn ban_count(&self) -> usize {
+ self.bans
+ .iter()
+ .filter(|(_, e)| Instant::now() < **e)
+ .count()
+ }
+
+ /// Number of peers with recorded failures.
+ pub fn failure_count(&self) -> usize {
+ self.failures.len()
+ }
+}
+
+impl Default for BanList {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn addr(port: u16) -> SocketAddr {
+ SocketAddr::from(([127, 0, 0, 1], port))
+ }
+
+ #[test]
+ fn not_banned_initially() {
+ let bl = BanList::new();
+ assert!(!bl.is_banned(&addr(1000)));
+ }
+
+ #[test]
+ fn ban_after_threshold() {
+ let mut bl = BanList::new();
+ let a = addr(1000);
+ assert!(!bl.record_failure(a));
+ assert!(!bl.record_failure(a));
+ assert!(bl.record_failure(a)); // 3rd failure → banned
+ assert!(bl.is_banned(&a));
+ }
+
+ #[test]
+ fn success_clears_failures() {
+ let mut bl = BanList::new();
+ let a = addr(1000);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ bl.record_success(&a);
+ // Failures cleared, next failure starts over
+ assert!(!bl.record_failure(a));
+ assert!(!bl.is_banned(&a));
+ }
+
+ #[test]
+ fn ban_expires() {
+ let mut bl = BanList::new();
+ bl.set_ban_duration(Duration::from_millis(1));
+ let a = addr(1000);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ assert!(bl.is_banned(&a));
+ std::thread::sleep(Duration::from_millis(5));
+ assert!(!bl.is_banned(&a));
+ }
+
+ #[test]
+ fn cleanup_removes_expired() {
+ let mut bl = BanList::new();
+ bl.set_ban_duration(Duration::from_millis(1));
+ let a = addr(1000);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ std::thread::sleep(Duration::from_millis(5));
+ bl.cleanup();
+ assert_eq!(bl.ban_count(), 0);
+ }
+
+ #[test]
+ fn custom_threshold() {
+ let mut bl = BanList::new();
+ bl.set_threshold(1);
+ let a = addr(1000);
+ assert!(bl.record_failure(a)); // 1 failure → banned
+ assert!(bl.is_banned(&a));
+ }
+
+ #[test]
+ fn independent_peers() {
+ let mut bl = BanList::new();
+ let a = addr(1000);
+ let b = addr(2000);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ bl.record_failure(a);
+ assert!(bl.is_banned(&a));
+ assert!(!bl.is_banned(&b));
+ }
+}