diff options
Diffstat (limited to 'src/dgram.rs')
| -rw-r--r-- | src/dgram.rs | 346 |
1 files changed, 346 insertions, 0 deletions
diff --git a/src/dgram.rs b/src/dgram.rs new file mode 100644 index 0000000..6fca81b --- /dev/null +++ b/src/dgram.rs @@ -0,0 +1,346 @@ +//! Datagram transport with automatic fragmentation. +//! +//! Messages larger +//! than `MAX_DGRAM_PAYLOAD` (896 bytes) are split into +//! fragments and reassembled at the destination. +//! +//! Routing is automatic: if the local node is behind +//! symmetric NAT, datagrams are sent through the proxy. + +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use crate::id::NodeId; + +/// Maximum payload per datagram fragment (bytes). +/// +/// Ensures each fragment +/// fits in a single UDP packet with protocol overhead. +pub const MAX_DGRAM_PAYLOAD: usize = 896; + +/// Timeout for incomplete fragment reassembly. +pub const REASSEMBLY_TIMEOUT: Duration = Duration::from_secs(10); + +/// A queued datagram waiting for address resolution. +#[derive(Debug, Clone)] +pub struct QueuedDgram { + pub data: Vec<u8>, + pub src: NodeId, + pub queued_at: Instant, +} + +// ── Fragmentation ─────────────────────────────────── + +/// Fragment header: 4 bytes. +/// +/// - `total` (u16 BE): total number of fragments. +/// - `index` (u16 BE): this fragment's index (0-based). +const FRAG_HEADER_SIZE: usize = 4; + +/// Split a message into fragments, each with a 4-byte +/// header and up to `MAX_DGRAM_PAYLOAD` bytes of data. +pub fn fragment(data: &[u8]) -> Vec<Vec<u8>> { + if data.is_empty() { + return vec![make_fragment(1, 0, &[])]; + } + + let chunk_size = MAX_DGRAM_PAYLOAD; + let total = data.len().div_ceil(chunk_size); + let total = total as u16; + + data.chunks(chunk_size) + .enumerate() + .map(|(i, chunk)| make_fragment(total, i as u16, chunk)) + .collect() +} + +fn make_fragment(total: u16, index: u16, data: &[u8]) -> Vec<u8> { + let mut buf = Vec::with_capacity(FRAG_HEADER_SIZE + data.len()); + buf.extend_from_slice(&total.to_be_bytes()); + buf.extend_from_slice(&index.to_be_bytes()); + buf.extend_from_slice(data); + buf +} + +/// Parse a fragment header. +/// +/// Returns `(total_fragments, fragment_index, payload)`. +pub fn parse_fragment(buf: &[u8]) -> Option<(u16, u16, &[u8])> { + if buf.len() < FRAG_HEADER_SIZE { + return None; + } + let total = u16::from_be_bytes([buf[0], buf[1]]); + let index = u16::from_be_bytes([buf[2], buf[3]]); + Some((total, index, &buf[FRAG_HEADER_SIZE..])) +} + +// ── Reassembly ────────────────────────────────────── + +/// State for reassembling fragments from a single sender. +#[derive(Debug)] +struct ReassemblyState { + total: u16, + fragments: HashMap<u16, Vec<u8>>, + started_at: Instant, +} + +/// Fragment reassembler. +/// +/// Tracks incoming fragments per sender and produces +/// complete messages when all fragments arrive. +pub struct Reassembler { + pending: HashMap<NodeId, ReassemblyState>, +} + +impl Reassembler { + pub fn new() -> Self { + Self { + pending: HashMap::new(), + } + } + + /// Feed a fragment. Returns the complete message if + /// all fragments have arrived. + pub fn feed( + &mut self, + sender: NodeId, + total: u16, + index: u16, + data: Vec<u8>, + ) -> Option<Vec<u8>> { + // S2-8: cap fragments to prevent memory bomb + const MAX_FRAGMENTS: u16 = 10; + if total == 0 { + log::debug!("Dgram: dropping fragment with total=0"); + return None; + } + if total > MAX_FRAGMENTS { + log::debug!( + "Dgram: dropping fragment with total={total} > {MAX_FRAGMENTS}" + ); + return None; + } + + // Single fragment → no reassembly needed + if total == 1 && index == 0 { + self.pending.remove(&sender); + return Some(data); + } + + let state = + self.pending + .entry(sender) + .or_insert_with(|| ReassemblyState { + total, + fragments: HashMap::new(), + started_at: Instant::now(), + }); + + // Total mismatch → reset + if state.total != total { + *state = ReassemblyState { + total, + fragments: HashMap::new(), + started_at: Instant::now(), + }; + } + + if index < total { + state.fragments.insert(index, data); + } + + if state.fragments.len() == total as usize { + // All fragments received → reassemble + let mut result = Vec::new(); + for i in 0..total { + if let Some(frag) = state.fragments.get(&i) { + result.extend_from_slice(frag); + } else { + // Should not happen, but guard + self.pending.remove(&sender); + return None; + } + } + self.pending.remove(&sender); + Some(result) + } else { + None + } + } + + /// Remove incomplete reassembly state older than the + /// timeout. + pub fn expire(&mut self) { + self.pending + .retain(|_, state| state.started_at.elapsed() < REASSEMBLY_TIMEOUT); + } + + /// Number of pending incomplete messages. + pub fn pending_count(&self) -> usize { + self.pending.len() + } +} + +impl Default for Reassembler { + fn default() -> Self { + Self::new() + } +} + +// ── Send queue ────────────────────────────────────── + +/// Queue of datagrams waiting for address resolution. +pub struct SendQueue { + queues: HashMap<NodeId, Vec<QueuedDgram>>, +} + +impl SendQueue { + pub fn new() -> Self { + Self { + queues: HashMap::new(), + } + } + + /// Enqueue a datagram for a destination. + pub fn push(&mut self, dst: NodeId, data: Vec<u8>, src: NodeId) { + self.queues.entry(dst).or_default().push(QueuedDgram { + data, + src, + queued_at: Instant::now(), + }); + } + + /// Drain the queue for a destination. + pub fn drain(&mut self, dst: &NodeId) -> Vec<QueuedDgram> { + self.queues.remove(dst).unwrap_or_default() + } + + /// Check if there's a pending queue for a destination. + pub fn has_pending(&self, dst: &NodeId) -> bool { + self.queues.contains_key(dst) + } + + /// Remove stale queued messages (>10s). + pub fn expire(&mut self) { + self.queues.retain(|_, q| { + q.retain(|d| d.queued_at.elapsed() < REASSEMBLY_TIMEOUT); + !q.is_empty() + }); + } +} + +impl Default for SendQueue { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Fragmentation tests ───────────────────────── + + #[test] + fn small_message_single_fragment() { + let frags = fragment(b"hello"); + assert_eq!(frags.len(), 1); + let (total, idx, data) = parse_fragment(&frags[0]).unwrap(); + assert_eq!(total, 1); + assert_eq!(idx, 0); + assert_eq!(data, b"hello"); + } + + #[test] + fn large_message_multiple_fragments() { + let msg = vec![0xAB; MAX_DGRAM_PAYLOAD * 3 + 100]; + let frags = fragment(&msg); + assert_eq!(frags.len(), 4); + + for (i, frag) in frags.iter().enumerate() { + let (total, idx, _) = parse_fragment(frag).unwrap(); + assert_eq!(total, 4); + assert_eq!(idx, i as u16); + } + } + + #[test] + fn empty_message() { + let frags = fragment(b""); + assert_eq!(frags.len(), 1); + let (total, idx, data) = parse_fragment(&frags[0]).unwrap(); + assert_eq!(total, 1); + assert_eq!(idx, 0); + assert!(data.is_empty()); + } + + #[test] + fn fragment_roundtrip() { + let msg = vec![0x42; MAX_DGRAM_PAYLOAD * 2 + 50]; + let frags = fragment(&msg); + + let mut reassembled = Vec::new(); + for frag in &frags { + let (_, _, data) = parse_fragment(frag).unwrap(); + reassembled.extend_from_slice(data); + } + assert_eq!(reassembled, msg); + } + + // ── Reassembler tests ─────────────────────────── + + #[test] + fn reassemble_single() { + let mut r = Reassembler::new(); + let sender = NodeId::from_bytes([0x01; 32]); + let result = r.feed(sender, 1, 0, b"hello".to_vec()); + assert_eq!(result.unwrap(), b"hello"); + assert_eq!(r.pending_count(), 0); + } + + #[test] + fn reassemble_multi() { + let mut r = Reassembler::new(); + let sender = NodeId::from_bytes([0x01; 32]); + + assert!(r.feed(sender, 3, 0, b"aaa".to_vec()).is_none()); + assert!(r.feed(sender, 3, 2, b"ccc".to_vec()).is_none()); + let result = r.feed(sender, 3, 1, b"bbb".to_vec()); + + assert_eq!(result.unwrap(), b"aaabbbccc"); + assert_eq!(r.pending_count(), 0); + } + + #[test] + fn reassemble_out_of_order() { + let mut r = Reassembler::new(); + let sender = NodeId::from_bytes([0x01; 32]); + + // Fragments arrive in reverse order + assert!(r.feed(sender, 2, 1, b"world".to_vec()).is_none()); + let result = r.feed(sender, 2, 0, b"hello".to_vec()); + assert_eq!(result.unwrap(), b"helloworld"); + } + + // ── SendQueue tests ───────────────────────────── + + #[test] + fn send_queue_push_drain() { + let mut q = SendQueue::new(); + let dst = NodeId::from_bytes([0x01; 32]); + let src = NodeId::from_bytes([0x02; 32]); + + q.push(dst, b"msg1".to_vec(), src); + q.push(dst, b"msg2".to_vec(), src); + + assert!(q.has_pending(&dst)); + let msgs = q.drain(&dst); + assert_eq!(msgs.len(), 2); + assert!(!q.has_pending(&dst)); + } + + #[test] + fn parse_truncated() { + assert!(parse_fragment(&[0, 1]).is_none()); + } +} |