//! Kademlia DHT logic: store, find_node, find_value. //! //! Uses explicit state machines for iterative queries //! instead of nested callbacks. use std::collections::{HashMap, HashSet}; use std::time::{Duration, Instant}; use crate::id::NodeId; use crate::peers::PeerInfo; use crate::routing::NUM_FIND_NODE; // ── Constants ──────────────────────────────────────── /// Max parallel queries per lookup. pub const MAX_QUERY: usize = 6; /// Timeout for a single RPC query (seconds). pub const QUERY_TIMEOUT: Duration = Duration::from_secs(3); /// Interval between data restore cycles. pub const RESTORE_INTERVAL: Duration = Duration::from_secs(120); /// Slow maintenance timer (refresh + restore). pub const SLOW_TIMER_INTERVAL: Duration = Duration::from_secs(600); /// Fast maintenance timer (expire + sweep). pub const FAST_TIMER_INTERVAL: Duration = Duration::from_secs(60); /// Number of original replicas for a put. pub const ORIGINAL_PUT_NUM: i32 = 3; /// Timeout waiting for all values in find_value. pub const RECVD_VALUE_TIMEOUT: Duration = Duration::from_secs(3); /// RDP port for store operations. pub const RDP_STORE_PORT: u16 = 100; /// RDP port for get operations. pub const RDP_GET_PORT: u16 = 101; /// RDP connection timeout. pub const RDP_TIMEOUT: Duration = Duration::from_secs(30); // ── Stored data ───────────────────────────────────── /// A single stored value with metadata. #[derive(Debug, Clone)] pub struct StoredValue { pub key: Vec, pub value: Vec, pub id: NodeId, pub source: NodeId, pub ttl: u16, pub stored_at: Instant, pub is_unique: bool, /// Number of original puts remaining. Starts at /// `ORIGINAL_PUT_NUM` for originator, 0 for replicas. pub original: i32, /// Set of node IDs that already received this value /// during restore, to avoid duplicate sends. pub recvd: HashSet, /// Monotonic version timestamp. Newer versions /// (higher value) replace older ones from the same /// source. Prevents stale replicas from overwriting /// fresh data during restore/republish. pub version: u64, } /// Generate a monotonic version number based on the /// current time (milliseconds since epoch, truncated /// to u64). Sufficient for conflict resolution — two /// stores in the same millisecond will have the same /// version (tie-break: last write wins). pub fn now_version() -> u64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_millis() as u64) .unwrap_or(0) } impl StoredValue { /// Check if this value has expired. pub fn is_expired(&self) -> bool { self.stored_at.elapsed() >= Duration::from_secs(self.ttl as u64) } /// Remaining TTL in seconds. pub fn remaining_ttl(&self) -> u16 { let elapsed = self.stored_at.elapsed().as_secs(); if elapsed >= self.ttl as u64 { 0 } else { (self.ttl as u64 - elapsed) as u16 } } } // ── Storage container ─────────────────────────────── /// Key for the two-level storage map: /// first level is the target NodeId (SHA1 of key), /// second level is the raw key bytes. #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct StorageKey { raw: Vec, } /// Container for stored DHT values. /// /// Maps target_id -> raw_key -> set of values. /// Maps target_id -> raw_key -> set of values. pub struct DhtStorage { data: HashMap>>, /// Maximum number of stored values (0 = unlimited). max_entries: usize, } /// Default maximum storage entries. const DEFAULT_MAX_STORAGE: usize = 65536; impl DhtStorage { pub fn new() -> Self { Self { data: HashMap::new(), max_entries: DEFAULT_MAX_STORAGE, } } /// Set the maximum number of stored values. pub fn set_max_entries(&mut self, max: usize) { self.max_entries = max; } /// Store a value. Handles `is_unique` semantics and /// version-based conflict resolution. /// /// If a value with the same key from the same source /// already exists with a higher version, the store is /// rejected (prevents stale replicas from overwriting /// fresh data). pub fn store(&mut self, val: StoredValue) { // Enforce storage limit if self.max_entries > 0 && self.len() >= self.max_entries { log::warn!( "Storage full ({} entries), dropping store", self.max_entries ); return; } let key = StorageKey { raw: val.key.clone(), }; let entry = self.data.entry(val.id).or_default().entry(key).or_default(); if val.is_unique { // Unique: replace existing from same source, // but only if version is not older if let Some(pos) = entry.iter().position(|v| v.source == val.source) { if val.version > 0 && entry[pos].version > 0 && val.version < entry[pos].version { log::debug!( "Rejecting stale unique store: v{} < v{}", val.version, entry[pos].version, ); return; } entry[pos] = val; } else if entry.is_empty() || !entry[0].is_unique { entry.clear(); entry.push(val); } return; } // Non-unique: update if same value exists (any // source), or append. Check version on update. if let Some(pos) = entry.iter().position(|v| v.value == val.value) { if val.version > 0 && entry[pos].version > 0 && val.version < entry[pos].version { log::debug!( "Rejecting stale store: v{} < v{}", val.version, entry[pos].version, ); return; } entry[pos].ttl = val.ttl; entry[pos].stored_at = val.stored_at; entry[pos].version = val.version; } else { // Don't add if existing data is unique if entry.len() == 1 && entry[0].is_unique { return; } entry.push(val); } } /// Remove a specific value. pub fn remove(&mut self, id: &NodeId, raw_key: &[u8]) { let key = StorageKey { raw: raw_key.to_vec(), }; if let Some(inner) = self.data.get_mut(id) { inner.remove(&key); if inner.is_empty() { self.data.remove(id); } } } /// Get all values for a target ID and key. pub fn get(&self, id: &NodeId, raw_key: &[u8]) -> Vec { let key = StorageKey { raw: raw_key.to_vec(), }; self.data .get(id) .and_then(|inner| inner.get(&key)) .map(|vals| { vals.iter().filter(|v| !v.is_expired()).cloned().collect() }) .unwrap_or_default() } /// Remove all expired values. pub fn expire(&mut self) { self.data.retain(|_, inner| { inner.retain(|_, vals| { vals.retain(|v| !v.is_expired()); !vals.is_empty() }); !inner.is_empty() }); } /// Iterate over all stored values (for restore). pub fn all_values(&self) -> Vec { self.data .values() .flat_map(|inner| inner.values()) .flat_map(|vals| vals.iter()) .filter(|v| !v.is_expired()) .cloned() .collect() } /// Decrement the `original` counter for a value. /// Returns the new count, or -1 if not found /// (in which case the value is inserted). pub fn dec_original(&mut self, val: &StoredValue) -> i32 { let key = StorageKey { raw: val.key.clone(), }; if let Some(inner) = self.data.get_mut(&val.id) { if let Some(vals) = inner.get_mut(&key) { if let Some(existing) = vals .iter_mut() .find(|v| v.value == val.value && v.source == val.source) { if existing.original > 0 { existing.original -= 1; } return existing.original; } } } // Not found: insert it self.store(val.clone()); -1 } /// Mark a node as having received a stored value /// (for restore deduplication). pub fn mark_received(&mut self, val: &StoredValue, node_id: NodeId) { let key = StorageKey { raw: val.key.clone(), }; if let Some(inner) = self.data.get_mut(&val.id) { if let Some(vals) = inner.get_mut(&key) { if let Some(existing) = vals.iter_mut().find(|v| v.value == val.value) { existing.recvd.insert(node_id); } } } } /// Total number of stored values. pub fn len(&self) -> usize { self.data .values() .flat_map(|inner| inner.values()) .map(|vals| vals.len()) .sum() } pub fn is_empty(&self) -> bool { self.len() == 0 } } impl Default for DhtStorage { fn default() -> Self { Self::new() } } // ── Iterative query state machine ─────────────────── /// Phase of an iterative Kademlia query. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum QueryPhase { /// Actively searching: sending queries and processing /// replies. Searching, /// No closer nodes found in last round: query has /// converged. Converged, /// Query complete: results are ready. Done, } /// State of an iterative FIND_NODE or FIND_VALUE query. /// /// Uses an explicit state machine for iterative lookup. /// Maximum duration for an iterative query before /// returning best-effort results. pub const MAX_QUERY_DURATION: Duration = Duration::from_secs(30); pub struct IterativeQuery { pub target: NodeId, pub closest: Vec, pub queried: HashSet, pub pending: HashMap, pub phase: QueryPhase, pub is_find_value: bool, pub key: Vec, pub values: Vec>, pub nonce: u32, pub started_at: Instant, /// Number of iterative rounds completed. Each round /// is a batch of queries followed by reply processing. /// Measures the "depth" of the lookup — useful for /// diagnosing network topology and routing efficiency. pub hops: u32, } impl IterativeQuery { /// Create a new FIND_NODE query. pub fn find_node(target: NodeId, nonce: u32) -> Self { Self { target, closest: Vec::new(), queried: HashSet::new(), pending: HashMap::new(), phase: QueryPhase::Searching, is_find_value: false, key: Vec::new(), values: Vec::new(), nonce, started_at: Instant::now(), hops: 0, } } /// Create a new FIND_VALUE query. pub fn find_value(target: NodeId, key: Vec, nonce: u32) -> Self { Self { target, closest: Vec::new(), queried: HashSet::new(), pending: HashMap::new(), phase: QueryPhase::Searching, is_find_value: true, key, values: Vec::new(), nonce, started_at: Instant::now(), hops: 0, } } /// Select the next batch of peers to query. /// /// Returns up to `MAX_QUERY` un-queried peers from /// `closest`, sorted by XOR distance to target. pub fn next_to_query(&self) -> Vec { let max = MAX_QUERY.saturating_sub(self.pending.len()); self.closest .iter() .filter(|p| { !self.queried.contains(&p.id) && !self.pending.contains_key(&p.id) }) .take(max) .cloned() .collect() } /// Process a reply: merge new nodes into closest, /// remove from pending, detect convergence. /// Increments the hop counter for route length /// tracking. pub fn process_reply(&mut self, from: &NodeId, nodes: Vec) { self.pending.remove(from); self.queried.insert(*from); self.hops += 1; let prev_best = self.closest.first().map(|p| self.target.distance(&p.id)); // Merge new nodes for node in nodes { if node.id != self.target && !self.closest.iter().any(|c| c.id == node.id) { self.closest.push(node); } } // Sort by XOR distance self.closest.sort_by(|a, b| { let da = self.target.distance(&a.id); let db = self.target.distance(&b.id); da.cmp(&db) }); // Trim to NUM_FIND_NODE self.closest.truncate(NUM_FIND_NODE); // Check convergence: did the closest node change? let new_best = self.closest.first().map(|p| self.target.distance(&p.id)); if prev_best == new_best && self.pending.is_empty() { self.phase = QueryPhase::Converged; } } /// Process a value reply (for FIND_VALUE). pub fn process_value(&mut self, value: Vec) { self.values.push(value); self.phase = QueryPhase::Done; } /// Mark a peer as timed out. pub fn timeout(&mut self, id: &NodeId) { self.pending.remove(id); if self.pending.is_empty() && self.next_to_query().is_empty() { self.phase = QueryPhase::Done; } } /// Expire all pending queries that have exceeded /// the timeout. pub fn expire_pending(&mut self) { let expired: Vec = self .pending .iter() .filter(|(_, sent_at)| sent_at.elapsed() >= QUERY_TIMEOUT) .map(|(id, _)| *id) .collect(); for id in expired { self.timeout(&id); } } /// Check if the query is complete (converged, /// finished, or timed out). pub fn is_done(&self) -> bool { self.phase == QueryPhase::Done || self.phase == QueryPhase::Converged || self.started_at.elapsed() >= MAX_QUERY_DURATION } } // ── Maintenance: mask_bit exploration ─────────────── /// Systematic exploration of the 256-bit ID space. /// /// Generates target IDs for find_node queries that probe /// different regions of the network, populating distant /// k-buckets that would otherwise remain empty. /// /// Used by both DHT and DTUN maintenance. pub struct MaskBitExplorer { local_id: NodeId, mask_bit: usize, } impl MaskBitExplorer { pub fn new(local_id: NodeId) -> Self { Self { local_id, mask_bit: 1, } } /// Generate the next pair of exploration targets. /// /// Each call produces two targets by clearing specific /// bits in the local ID, then advances by 2 bits. /// After bit 20, resets to 1. pub fn next_targets(&mut self) -> (NodeId, NodeId) { let id_bytes = *self.local_id.as_bytes(); let t1 = Self::clear_bit(id_bytes, self.mask_bit); let t2 = Self::clear_bit(id_bytes, self.mask_bit + 1); self.mask_bit += 2; if self.mask_bit > 20 { self.mask_bit = 1; } (t1, t2) } /// Current mask_bit position (for testing). pub fn position(&self) -> usize { self.mask_bit } fn clear_bit( mut bytes: [u8; crate::id::ID_LEN], bit_from_msb: usize, ) -> NodeId { if bit_from_msb == 0 || bit_from_msb > crate::id::ID_BITS { return NodeId::from_bytes(bytes); } let pos = bit_from_msb - 1; // 0-indexed let byte_idx = pos / 8; let bit_idx = 7 - (pos % 8); bytes[byte_idx] &= !(1 << bit_idx); NodeId::from_bytes(bytes) } } #[cfg(test)] mod tests { use super::*; use std::net::SocketAddr; fn make_peer(byte: u8, port: u16) -> PeerInfo { PeerInfo::new( NodeId::from_bytes([byte; 32]), SocketAddr::from(([127, 0, 0, 1], port)), ) } // ── DhtStorage tests ──────────────────────────── #[test] fn storage_store_and_get() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"test-key"); let val = StoredValue { key: b"test-key".to_vec(), value: b"hello".to_vec(), id, source: NodeId::from_bytes([0x01; 32]), ttl: 300, stored_at: Instant::now(), is_unique: false, original: 3, recvd: HashSet::new(), version: 0, }; s.store(val); let got = s.get(&id, b"test-key"); assert_eq!(got.len(), 1); assert_eq!(got[0].value, b"hello"); } #[test] fn storage_unique_replaces() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"uk"); let src = NodeId::from_bytes([0x01; 32]); let v1 = StoredValue { key: b"uk".to_vec(), value: b"v1".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 0, }; s.store(v1); let v2 = StoredValue { key: b"uk".to_vec(), value: b"v2".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 0, }; s.store(v2); let got = s.get(&id, b"uk"); assert_eq!(got.len(), 1); assert_eq!(got[0].value, b"v2"); } #[test] fn storage_unique_rejects_other_source() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"uk"); let v1 = StoredValue { key: b"uk".to_vec(), value: b"v1".to_vec(), id, source: NodeId::from_bytes([0x01; 32]), ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 0, }; s.store(v1); let v2 = StoredValue { key: b"uk".to_vec(), value: b"v2".to_vec(), id, source: NodeId::from_bytes([0x02; 32]), ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 0, }; s.store(v2); let got = s.get(&id, b"uk"); assert_eq!(got.len(), 1); assert_eq!(got[0].value, b"v1"); } #[test] fn storage_multiple_non_unique() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"k"); for i in 0..3u8 { s.store(StoredValue { key: b"k".to_vec(), value: vec![i], id, source: NodeId::from_bytes([i; 32]), ttl: 300, stored_at: Instant::now(), is_unique: false, original: 0, recvd: HashSet::new(), version: 0, }); } assert_eq!(s.get(&id, b"k").len(), 3); } #[test] fn storage_remove() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"k"); s.store(StoredValue { key: b"k".to_vec(), value: b"v".to_vec(), id, source: NodeId::from_bytes([0x01; 32]), ttl: 300, stored_at: Instant::now(), is_unique: false, original: 0, recvd: HashSet::new(), version: 0, }); s.remove(&id, b"k"); assert!(s.get(&id, b"k").is_empty()); assert!(s.is_empty()); } #[test] fn storage_dec_original() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"k"); let val = StoredValue { key: b"k".to_vec(), value: b"v".to_vec(), id, source: NodeId::from_bytes([0x01; 32]), ttl: 300, stored_at: Instant::now(), is_unique: false, original: 3, recvd: HashSet::new(), version: 0, }; s.store(val.clone()); assert_eq!(s.dec_original(&val), 2); assert_eq!(s.dec_original(&val), 1); assert_eq!(s.dec_original(&val), 0); assert_eq!(s.dec_original(&val), 0); // stays at 0 } #[test] fn storage_mark_received() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"k"); let val = StoredValue { key: b"k".to_vec(), value: b"v".to_vec(), id, source: NodeId::from_bytes([0x01; 32]), ttl: 300, stored_at: Instant::now(), is_unique: false, original: 0, recvd: HashSet::new(), version: 0, }; s.store(val.clone()); let node = NodeId::from_bytes([0x42; 32]); s.mark_received(&val, node); let got = s.get(&id, b"k"); assert!(got[0].recvd.contains(&node)); } // ── IterativeQuery tests ──────────────────────── #[test] fn query_process_reply_sorts() { let target = NodeId::from_bytes([0x00; 32]); let mut q = IterativeQuery::find_node(target, 1); // Simulate: we have node 0xFF pending let far = NodeId::from_bytes([0xFF; 32]); q.pending.insert(far, Instant::now()); // Reply with closer nodes let nodes = vec![ make_peer(0x10, 3000), make_peer(0x01, 3001), make_peer(0x05, 3002), ]; q.process_reply(&far, nodes); // Should be sorted by distance from target assert_eq!(q.closest[0].id, NodeId::from_bytes([0x01; 32])); assert_eq!(q.closest[1].id, NodeId::from_bytes([0x05; 32])); assert_eq!(q.closest[2].id, NodeId::from_bytes([0x10; 32])); } #[test] fn query_converges_when_no_closer() { let target = NodeId::from_bytes([0x00; 32]); let mut q = IterativeQuery::find_node(target, 1); // Add initial closest q.closest.push(make_peer(0x01, 3000)); // Simulate reply with no closer nodes let from = NodeId::from_bytes([0x01; 32]); q.pending.insert(from, Instant::now()); q.process_reply(&from, vec![make_peer(0x02, 3001)]); // 0x01 is still closest, pending is empty -> converged assert_eq!(q.phase, QueryPhase::Converged); } #[test] fn query_find_value_done_on_value() { let target = NodeId::from_bytes([0x00; 32]); let mut q = IterativeQuery::find_value(target, b"key".to_vec(), 1); q.process_value(b"found-it".to_vec()); assert!(q.is_done()); assert_eq!(q.values, vec![b"found-it".to_vec()]); } // ── MaskBitExplorer tests ─────────────────────── #[test] fn mask_bit_cycles() { let id = NodeId::from_bytes([0xFF; 32]); let mut explorer = MaskBitExplorer::new(id); assert_eq!(explorer.position(), 1); explorer.next_targets(); assert_eq!(explorer.position(), 3); explorer.next_targets(); assert_eq!(explorer.position(), 5); // Run through full cycle for _ in 0..8 { explorer.next_targets(); } // 5 + 8*2 = 21 > 20, so reset to 1 assert_eq!(explorer.position(), 1); } #[test] fn mask_bit_produces_different_targets() { let id = NodeId::from_bytes([0xFF; 32]); let mut explorer = MaskBitExplorer::new(id); let (t1, t2) = explorer.next_targets(); assert_ne!(t1, t2); assert_ne!(t1, id); assert_ne!(t2, id); } // ── Content versioning tests ────────────────── #[test] fn version_rejects_stale_unique() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"vk"); let src = NodeId::from_bytes([0x01; 32]); // Store version 100 s.store(StoredValue { key: b"vk".to_vec(), value: b"new".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 100, }); // Try to store older version 50 — should be rejected s.store(StoredValue { key: b"vk".to_vec(), value: b"old".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 50, }); let got = s.get(&id, b"vk"); assert_eq!(got.len(), 1); assert_eq!(got[0].value, b"new"); assert_eq!(got[0].version, 100); } #[test] fn version_accepts_newer() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"vk"); let src = NodeId::from_bytes([0x01; 32]); s.store(StoredValue { key: b"vk".to_vec(), value: b"old".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 50, }); s.store(StoredValue { key: b"vk".to_vec(), value: b"new".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 100, }); let got = s.get(&id, b"vk"); assert_eq!(got[0].value, b"new"); } #[test] fn version_zero_always_accepted() { // version=0 means "no versioning" — always accepted let mut s = DhtStorage::new(); let id = NodeId::from_key(b"vk"); let src = NodeId::from_bytes([0x01; 32]); s.store(StoredValue { key: b"vk".to_vec(), value: b"v1".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 0, }); s.store(StoredValue { key: b"vk".to_vec(), value: b"v2".to_vec(), id, source: src, ttl: 300, stored_at: Instant::now(), is_unique: true, original: 3, recvd: HashSet::new(), version: 0, }); let got = s.get(&id, b"vk"); assert_eq!(got[0].value, b"v2"); } #[test] fn version_rejects_stale_non_unique() { let mut s = DhtStorage::new(); let id = NodeId::from_key(b"nk"); // Store value with version 100 s.store(StoredValue { key: b"nk".to_vec(), value: b"same".to_vec(), id, source: NodeId::from_bytes([0x01; 32]), ttl: 300, stored_at: Instant::now(), is_unique: false, original: 0, recvd: HashSet::new(), version: 100, }); // Same value with older version — rejected s.store(StoredValue { key: b"nk".to_vec(), value: b"same".to_vec(), id, source: NodeId::from_bytes([0x02; 32]), ttl: 600, stored_at: Instant::now(), is_unique: false, original: 0, recvd: HashSet::new(), version: 50, }); let got = s.get(&id, b"nk"); assert_eq!(got.len(), 1); assert_eq!(got[0].version, 100); assert_eq!(got[0].ttl, 300); // TTL not updated } // ── Route length tests ──────────────────────── #[test] fn query_hops_increment() { let target = NodeId::from_bytes([0x00; 32]); let mut q = IterativeQuery::find_node(target, 1); assert_eq!(q.hops, 0); let from = NodeId::from_bytes([0xFF; 32]); q.pending.insert(from, Instant::now()); q.process_reply(&from, vec![make_peer(0x10, 3000)]); assert_eq!(q.hops, 1); let from2 = NodeId::from_bytes([0x10; 32]); q.pending.insert(from2, Instant::now()); q.process_reply(&from2, vec![make_peer(0x05, 3001)]); assert_eq!(q.hops, 2); } #[test] fn now_version_monotonic() { let v1 = now_version(); std::thread::sleep(std::time::Duration::from_millis(2)); let v2 = now_version(); assert!(v2 >= v1); } }