aboutsummaryrefslogtreecommitdiffstats
path: root/src/dgram.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/dgram.rs')
-rw-r--r--src/dgram.rs346
1 files changed, 346 insertions, 0 deletions
diff --git a/src/dgram.rs b/src/dgram.rs
new file mode 100644
index 0000000..6fca81b
--- /dev/null
+++ b/src/dgram.rs
@@ -0,0 +1,346 @@
+//! Datagram transport with automatic fragmentation.
+//!
+//! Messages larger
+//! than `MAX_DGRAM_PAYLOAD` (896 bytes) are split into
+//! fragments and reassembled at the destination.
+//!
+//! Routing is automatic: if the local node is behind
+//! symmetric NAT, datagrams are sent through the proxy.
+
+use std::collections::HashMap;
+use std::time::{Duration, Instant};
+
+use crate::id::NodeId;
+
+/// Maximum payload per datagram fragment (bytes).
+///
+/// Ensures each fragment
+/// fits in a single UDP packet with protocol overhead.
+pub const MAX_DGRAM_PAYLOAD: usize = 896;
+
+/// Timeout for incomplete fragment reassembly.
+pub const REASSEMBLY_TIMEOUT: Duration = Duration::from_secs(10);
+
+/// A queued datagram waiting for address resolution.
+#[derive(Debug, Clone)]
+pub struct QueuedDgram {
+ pub data: Vec<u8>,
+ pub src: NodeId,
+ pub queued_at: Instant,
+}
+
+// ── Fragmentation ───────────────────────────────────
+
+/// Fragment header: 4 bytes.
+///
+/// - `total` (u16 BE): total number of fragments.
+/// - `index` (u16 BE): this fragment's index (0-based).
+const FRAG_HEADER_SIZE: usize = 4;
+
+/// Split a message into fragments, each with a 4-byte
+/// header and up to `MAX_DGRAM_PAYLOAD` bytes of data.
+pub fn fragment(data: &[u8]) -> Vec<Vec<u8>> {
+ if data.is_empty() {
+ return vec![make_fragment(1, 0, &[])];
+ }
+
+ let chunk_size = MAX_DGRAM_PAYLOAD;
+ let total = data.len().div_ceil(chunk_size);
+ let total = total as u16;
+
+ data.chunks(chunk_size)
+ .enumerate()
+ .map(|(i, chunk)| make_fragment(total, i as u16, chunk))
+ .collect()
+}
+
+fn make_fragment(total: u16, index: u16, data: &[u8]) -> Vec<u8> {
+ let mut buf = Vec::with_capacity(FRAG_HEADER_SIZE + data.len());
+ buf.extend_from_slice(&total.to_be_bytes());
+ buf.extend_from_slice(&index.to_be_bytes());
+ buf.extend_from_slice(data);
+ buf
+}
+
+/// Parse a fragment header.
+///
+/// Returns `(total_fragments, fragment_index, payload)`.
+pub fn parse_fragment(buf: &[u8]) -> Option<(u16, u16, &[u8])> {
+ if buf.len() < FRAG_HEADER_SIZE {
+ return None;
+ }
+ let total = u16::from_be_bytes([buf[0], buf[1]]);
+ let index = u16::from_be_bytes([buf[2], buf[3]]);
+ Some((total, index, &buf[FRAG_HEADER_SIZE..]))
+}
+
+// ── Reassembly ──────────────────────────────────────
+
+/// State for reassembling fragments from a single sender.
+#[derive(Debug)]
+struct ReassemblyState {
+ total: u16,
+ fragments: HashMap<u16, Vec<u8>>,
+ started_at: Instant,
+}
+
+/// Fragment reassembler.
+///
+/// Tracks incoming fragments per sender and produces
+/// complete messages when all fragments arrive.
+pub struct Reassembler {
+ pending: HashMap<NodeId, ReassemblyState>,
+}
+
+impl Reassembler {
+ pub fn new() -> Self {
+ Self {
+ pending: HashMap::new(),
+ }
+ }
+
+ /// Feed a fragment. Returns the complete message if
+ /// all fragments have arrived.
+ pub fn feed(
+ &mut self,
+ sender: NodeId,
+ total: u16,
+ index: u16,
+ data: Vec<u8>,
+ ) -> Option<Vec<u8>> {
+ // S2-8: cap fragments to prevent memory bomb
+ const MAX_FRAGMENTS: u16 = 10;
+ if total == 0 {
+ log::debug!("Dgram: dropping fragment with total=0");
+ return None;
+ }
+ if total > MAX_FRAGMENTS {
+ log::debug!(
+ "Dgram: dropping fragment with total={total} > {MAX_FRAGMENTS}"
+ );
+ return None;
+ }
+
+ // Single fragment → no reassembly needed
+ if total == 1 && index == 0 {
+ self.pending.remove(&sender);
+ return Some(data);
+ }
+
+ let state =
+ self.pending
+ .entry(sender)
+ .or_insert_with(|| ReassemblyState {
+ total,
+ fragments: HashMap::new(),
+ started_at: Instant::now(),
+ });
+
+ // Total mismatch → reset
+ if state.total != total {
+ *state = ReassemblyState {
+ total,
+ fragments: HashMap::new(),
+ started_at: Instant::now(),
+ };
+ }
+
+ if index < total {
+ state.fragments.insert(index, data);
+ }
+
+ if state.fragments.len() == total as usize {
+ // All fragments received → reassemble
+ let mut result = Vec::new();
+ for i in 0..total {
+ if let Some(frag) = state.fragments.get(&i) {
+ result.extend_from_slice(frag);
+ } else {
+ // Should not happen, but guard
+ self.pending.remove(&sender);
+ return None;
+ }
+ }
+ self.pending.remove(&sender);
+ Some(result)
+ } else {
+ None
+ }
+ }
+
+ /// Remove incomplete reassembly state older than the
+ /// timeout.
+ pub fn expire(&mut self) {
+ self.pending
+ .retain(|_, state| state.started_at.elapsed() < REASSEMBLY_TIMEOUT);
+ }
+
+ /// Number of pending incomplete messages.
+ pub fn pending_count(&self) -> usize {
+ self.pending.len()
+ }
+}
+
+impl Default for Reassembler {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// ── Send queue ──────────────────────────────────────
+
+/// Queue of datagrams waiting for address resolution.
+pub struct SendQueue {
+ queues: HashMap<NodeId, Vec<QueuedDgram>>,
+}
+
+impl SendQueue {
+ pub fn new() -> Self {
+ Self {
+ queues: HashMap::new(),
+ }
+ }
+
+ /// Enqueue a datagram for a destination.
+ pub fn push(&mut self, dst: NodeId, data: Vec<u8>, src: NodeId) {
+ self.queues.entry(dst).or_default().push(QueuedDgram {
+ data,
+ src,
+ queued_at: Instant::now(),
+ });
+ }
+
+ /// Drain the queue for a destination.
+ pub fn drain(&mut self, dst: &NodeId) -> Vec<QueuedDgram> {
+ self.queues.remove(dst).unwrap_or_default()
+ }
+
+ /// Check if there's a pending queue for a destination.
+ pub fn has_pending(&self, dst: &NodeId) -> bool {
+ self.queues.contains_key(dst)
+ }
+
+ /// Remove stale queued messages (>10s).
+ pub fn expire(&mut self) {
+ self.queues.retain(|_, q| {
+ q.retain(|d| d.queued_at.elapsed() < REASSEMBLY_TIMEOUT);
+ !q.is_empty()
+ });
+ }
+}
+
+impl Default for SendQueue {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ // ── Fragmentation tests ─────────────────────────
+
+ #[test]
+ fn small_message_single_fragment() {
+ let frags = fragment(b"hello");
+ assert_eq!(frags.len(), 1);
+ let (total, idx, data) = parse_fragment(&frags[0]).unwrap();
+ assert_eq!(total, 1);
+ assert_eq!(idx, 0);
+ assert_eq!(data, b"hello");
+ }
+
+ #[test]
+ fn large_message_multiple_fragments() {
+ let msg = vec![0xAB; MAX_DGRAM_PAYLOAD * 3 + 100];
+ let frags = fragment(&msg);
+ assert_eq!(frags.len(), 4);
+
+ for (i, frag) in frags.iter().enumerate() {
+ let (total, idx, _) = parse_fragment(frag).unwrap();
+ assert_eq!(total, 4);
+ assert_eq!(idx, i as u16);
+ }
+ }
+
+ #[test]
+ fn empty_message() {
+ let frags = fragment(b"");
+ assert_eq!(frags.len(), 1);
+ let (total, idx, data) = parse_fragment(&frags[0]).unwrap();
+ assert_eq!(total, 1);
+ assert_eq!(idx, 0);
+ assert!(data.is_empty());
+ }
+
+ #[test]
+ fn fragment_roundtrip() {
+ let msg = vec![0x42; MAX_DGRAM_PAYLOAD * 2 + 50];
+ let frags = fragment(&msg);
+
+ let mut reassembled = Vec::new();
+ for frag in &frags {
+ let (_, _, data) = parse_fragment(frag).unwrap();
+ reassembled.extend_from_slice(data);
+ }
+ assert_eq!(reassembled, msg);
+ }
+
+ // ── Reassembler tests ───────────────────────────
+
+ #[test]
+ fn reassemble_single() {
+ let mut r = Reassembler::new();
+ let sender = NodeId::from_bytes([0x01; 32]);
+ let result = r.feed(sender, 1, 0, b"hello".to_vec());
+ assert_eq!(result.unwrap(), b"hello");
+ assert_eq!(r.pending_count(), 0);
+ }
+
+ #[test]
+ fn reassemble_multi() {
+ let mut r = Reassembler::new();
+ let sender = NodeId::from_bytes([0x01; 32]);
+
+ assert!(r.feed(sender, 3, 0, b"aaa".to_vec()).is_none());
+ assert!(r.feed(sender, 3, 2, b"ccc".to_vec()).is_none());
+ let result = r.feed(sender, 3, 1, b"bbb".to_vec());
+
+ assert_eq!(result.unwrap(), b"aaabbbccc");
+ assert_eq!(r.pending_count(), 0);
+ }
+
+ #[test]
+ fn reassemble_out_of_order() {
+ let mut r = Reassembler::new();
+ let sender = NodeId::from_bytes([0x01; 32]);
+
+ // Fragments arrive in reverse order
+ assert!(r.feed(sender, 2, 1, b"world".to_vec()).is_none());
+ let result = r.feed(sender, 2, 0, b"hello".to_vec());
+ assert_eq!(result.unwrap(), b"helloworld");
+ }
+
+ // ── SendQueue tests ─────────────────────────────
+
+ #[test]
+ fn send_queue_push_drain() {
+ let mut q = SendQueue::new();
+ let dst = NodeId::from_bytes([0x01; 32]);
+ let src = NodeId::from_bytes([0x02; 32]);
+
+ q.push(dst, b"msg1".to_vec(), src);
+ q.push(dst, b"msg2".to_vec(), src);
+
+ assert!(q.has_pending(&dst));
+ let msgs = q.drain(&dst);
+ assert_eq!(msgs.len(), 2);
+ assert!(!q.has_pending(&dst));
+ }
+
+ #[test]
+ fn parse_truncated() {
+ assert!(parse_fragment(&[0, 1]).is_none());
+ }
+}