diff options
| author | murilo ijanc | 2026-03-24 15:04:03 -0300 |
|---|---|---|
| committer | murilo ijanc | 2026-03-24 15:04:03 -0300 |
| commit | 9821aabf0b50d2487b07502d3d2cd89e7d62bdbe (patch) | |
| tree | 53da095ff90cc755bac3d4bf699172b5e8cd07d6 | |
| download | tesseras-dht-9821aabf0b50d2487b07502d3d2cd89e7d62bdbe.tar.gz | |
Initial commit v0.1.0
NAT-aware Kademlia DHT library for peer-to-peer networks.
Features:
- Distributed key-value storage (iterative FIND_NODE, FIND_VALUE, STORE)
- NAT traversal via DTUN hole-punching and proxy relay
- Reliable Datagram Protocol (RDP) with 7-state connection machine
- Datagram transport with automatic fragmentation/reassembly
- Ed25519 packet authentication
- 256-bit node IDs (Ed25519 public keys)
- Rate limiting, ban list, and eclipse attack mitigation
- Persistence and metrics
- OpenBSD and Linux support
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | .rustfmt.toml | 4 | ||||
| -rw-r--r-- | CHANGELOG.md | 34 | ||||
| -rw-r--r-- | Cargo.lock | 535 | ||||
| -rw-r--r-- | Cargo.toml | 26 | ||||
| -rw-r--r-- | LICENSE | 14 | ||||
| -rw-r--r-- | Makefile | 39 | ||||
| -rw-r--r-- | README.md | 5 | ||||
| -rw-r--r-- | benches/bench.rs | 144 | ||||
| -rw-r--r-- | deny.toml | 242 | ||||
| -rw-r--r-- | examples/dgram.rs | 100 | ||||
| -rw-r--r-- | examples/join.rs | 65 | ||||
| -rw-r--r-- | examples/network.rs | 86 | ||||
| -rw-r--r-- | examples/put_get.rs | 93 | ||||
| -rw-r--r-- | examples/rdp.rs | 147 | ||||
| -rw-r--r-- | examples/remote_get.rs | 106 | ||||
| -rw-r--r-- | examples/tesserasd.rs | 338 | ||||
| -rw-r--r-- | examples/two_nodes.rs | 133 | ||||
| -rw-r--r-- | fuzz/fuzz_parse.rs | 87 | ||||
| -rw-r--r-- | src/advertise.rs | 173 | ||||
| -rw-r--r-- | src/banlist.rs | 207 | ||||
| -rw-r--r-- | src/config.rs | 139 | ||||
| -rw-r--r-- | src/crypto.rs | 172 | ||||
| -rw-r--r-- | src/dgram.rs | 346 | ||||
| -rw-r--r-- | src/dht.rs | 1028 | ||||
| -rw-r--r-- | src/dtun.rs | 436 | ||||
| -rw-r--r-- | src/error.rs | 80 | ||||
| -rw-r--r-- | src/event.rs | 30 | ||||
| -rw-r--r-- | src/handlers.rs | 1049 | ||||
| -rw-r--r-- | src/id.rs | 238 | ||||
| -rw-r--r-- | src/lib.rs | 128 | ||||
| -rw-r--r-- | src/metrics.rs | 121 | ||||
| -rw-r--r-- | src/msg.rs | 830 | ||||
| -rw-r--r-- | src/nat.rs | 384 | ||||
| -rw-r--r-- | src/net.rs | 744 | ||||
| -rw-r--r-- | src/node.rs | 1395 | ||||
| -rw-r--r-- | src/peers.rs | 337 | ||||
| -rw-r--r-- | src/persist.rs | 84 | ||||
| -rw-r--r-- | src/proxy.rs | 370 | ||||
| -rw-r--r-- | src/ratelimit.rs | 136 | ||||
| -rw-r--r-- | src/rdp.rs | 1343 | ||||
| -rw-r--r-- | src/routing.rs | 843 | ||||
| -rw-r--r-- | src/socket.rs | 159 | ||||
| -rw-r--r-- | src/store_track.rs | 275 | ||||
| -rw-r--r-- | src/sys.rs | 127 | ||||
| -rw-r--r-- | src/timer.rs | 221 | ||||
| -rw-r--r-- | src/wire.rs | 368 | ||||
| -rw-r--r-- | tests/integration.rs | 704 | ||||
| -rw-r--r-- | tests/rdp_lossy.rs | 125 | ||||
| -rw-r--r-- | tests/scale.rs | 138 |
50 files changed, 14929 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..eb5a316 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +target diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..bf95564 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,4 @@ +# group_imports = "StdExternalCrate" +# imports_granularity = "Module" +max_width = 80 +reorder_imports = true diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3d5c3f2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,34 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [0.1.0] - 2026-03-24 + +### Added + +- Kademlia DHT with iterative FIND_NODE, FIND_VALUE, and STORE operations. +- NAT type detection (global, cone NAT, symmetric NAT). +- DTUN hole-punching for cone NAT traversal. +- Proxy relay for symmetric NAT nodes. +- Reliable Datagram Protocol (RDP) with 3-way handshake, sliding windows, + cumulative ACK, EACK/SACK, retransmission, and graceful close. +- Datagram transport with automatic fragmentation and reassembly. +- Ed25519 packet authentication (reject unsigned packets). +- 256-bit node IDs derived from Ed25519 public keys. +- Address advertisement for routing table updates. +- Rate limiting per source address. +- Ban list for misbehaving peers. +- Eclipse attack mitigation via bucket diversity checks. +- Persistence for routing table and DHT storage. +- Metrics collection (message counts, latency, peer churn). +- Configurable parameters via `Config` builder. +- Event-based API for async notification of DHT events. +- OpenBSD (`kqueue`) and Linux (`epoll`) support via mio. +- Examples: `join`, `put_get`, `dgram`, `rdp`, `two_nodes`, `network`, + `remote_get`, `tesserasd`. +- Integration, scale, and lossy-network test suites. 
+- Fuzz harness for wire protocol parsing. +- Benchmark suite. diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..67162f1 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,535 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + 
"proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "zeroize", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", +] + +[[package]] +name = "env_filter" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "jiff" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys", +] + 
+[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "rand_core", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tesseras-dht" +version = "0.1.0" +dependencies = [ + "ed25519-dalek", + "env_logger", + "log", + "mio", + "sha2", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = 
"windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ffbcb16 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "tesseras-dht" +version = "0.1.0" +edition = "2024" +authors = [ "murilo ijanc' <murilo@ijanc.org>" ] +categories = ["network-programming"] +description = "NAT-aware Kademlia DHT library" +homepage = "https://tesseras.net" +keywords = ["dht", "kademlia", "p2p", "nat", "distributed"] +license = "ISC" +readme = "README.md" +repository = "https://git.sr.ht/~ijanc/tesseras-dht" +rust-version = "1.93.0" + +[dependencies] +ed25519-dalek = "=2.2.0" +log = "=0.4.29" +mio = { version = "=1.1.1", features = ["net", "os-poll"] } +sha2 = "=0.10.9" + +[dev-dependencies] +env_logger = "0.11" + +[[bench]] +name = "bench" +harness = false diff --git a/LICENSE b/LICENSE new file mode 100644 --- /dev/null +++ b/LICENSE @@ -0,0 +1,14 @@ +Copyright (c) 2026 murilo ijanc' <murilo@ijanc.org> + +Permission to use, copy, modify, and distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2dd7cf3 --- /dev/null +++ b/Makefile @@ -0,0 +1,39 @@ +all: release + +release: + cargo build --release + +debug: + cargo build + +test: + cargo test + +test-release: + cargo test --release + +check: + cargo check + +clean: + cargo clean + +fmt: + cargo fmt + +clippy: + cargo clippy -- -D warnings + +examples: release + cargo build --release --examples + +doc: + cargo doc --no-deps --open + +audit: + cargo deny check + +bench: + cargo bench + +.PHONY: all release debug test test-release check clean fmt clippy examples doc audit bench diff --git a/README.md b/README.md new file mode 100644 index 0000000..e70a247 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# tesseras-dht + +## License + +ISC diff --git a/benches/bench.rs b/benches/bench.rs new file mode 100644 index 0000000..0b427ef --- /dev/null +++ b/benches/bench.rs @@ -0,0 +1,144 @@ +//! Benchmarks for core operations. +//! +//! 
Run with: cargo bench + +use std::net::SocketAddr; +use tesseras_dht::crypto::Identity; +use tesseras_dht::id::NodeId; +use tesseras_dht::peers::PeerInfo; +use tesseras_dht::routing::RoutingTable; +use tesseras_dht::wire::MsgHeader; + +fn main() { + println!("=== tesseras-dht benchmarks ===\n"); + + bench_sha256(); + bench_ed25519_sign(); + bench_ed25519_verify(); + bench_routing_add(); + bench_routing_closest(); + bench_header_roundtrip(); + bench_node_id_xor(); +} + +fn bench_sha256() { + let data = vec![0xABu8; 1024]; + let start = std::time::Instant::now(); + let iters = 100_000; + for _ in 0..iters { + let _ = NodeId::from_key(&data); + } + let elapsed = start.elapsed(); + println!( + "SHA-256 (1KB): {:>8.0} ns/op ({iters} iters)", + elapsed.as_nanos() as f64 / iters as f64 + ); +} + +fn bench_ed25519_sign() { + let id = Identity::generate(); + let data = vec![0x42u8; 256]; + let start = std::time::Instant::now(); + let iters = 10_000; + for _ in 0..iters { + let _ = id.sign(&data); + } + let elapsed = start.elapsed(); + println!( + "Ed25519 sign: {:>8.0} ns/op ({iters} iters)", + elapsed.as_nanos() as f64 / iters as f64 + ); +} + +fn bench_ed25519_verify() { + let id = Identity::generate(); + let data = vec![0x42u8; 256]; + let sig = id.sign(&data); + let start = std::time::Instant::now(); + let iters = 10_000; + for _ in 0..iters { + let _ = Identity::verify(id.public_key(), &data, &sig); + } + let elapsed = start.elapsed(); + println!( + "Ed25519 verify: {:>8.0} ns/op ({iters} iters)", + elapsed.as_nanos() as f64 / iters as f64 + ); +} + +fn bench_routing_add() { + let local = NodeId::random(); + let mut rt = RoutingTable::new(local); + let start = std::time::Instant::now(); + let iters = 10_000; + for i in 0..iters { + let id = NodeId::random(); + let addr = SocketAddr::from(([10, 0, (i % 256) as u8, 1], 3000)); + rt.add(PeerInfo::new(id, addr)); + } + let elapsed = start.elapsed(); + println!( + "Routing add: {:>8.0} ns/op ({iters} iters, size={})", + 
elapsed.as_nanos() as f64 / iters as f64, + rt.size() + ); +} + +fn bench_routing_closest() { + let local = NodeId::random(); + let mut rt = RoutingTable::new(local); + for i in 0..500 { + let id = NodeId::random(); + let addr = SocketAddr::from(([10, 0, (i % 256) as u8, 1], 3000)); + rt.add(PeerInfo::new(id, addr)); + } + + let target = NodeId::random(); + let start = std::time::Instant::now(); + let iters = 10_000; + for _ in 0..iters { + let _ = rt.closest(&target, 10); + } + let elapsed = start.elapsed(); + println!( + "Routing closest: {:>8.0} ns/op ({iters} iters, size={})", + elapsed.as_nanos() as f64 / iters as f64, + rt.size() + ); +} + +fn bench_header_roundtrip() { + let hdr = MsgHeader::new( + tesseras_dht::wire::MsgType::DhtPing, + 100, + NodeId::random(), + NodeId::random(), + ); + let mut buf = vec![0u8; tesseras_dht::wire::HEADER_SIZE]; + let start = std::time::Instant::now(); + let iters = 1_000_000; + for _ in 0..iters { + hdr.write(&mut buf).unwrap(); + let _ = MsgHeader::parse(&buf).unwrap(); + } + let elapsed = start.elapsed(); + println!( + "Header roundtrip: {:>8.0} ns/op ({iters} iters)", + elapsed.as_nanos() as f64 / iters as f64 + ); +} + +fn bench_node_id_xor() { + let a = NodeId::random(); + let b = NodeId::random(); + let start = std::time::Instant::now(); + let iters = 10_000_000; + for _ in 0..iters { + let _ = a.distance(&b); + } + let elapsed = start.elapsed(); + println!( + "NodeId XOR: {:>8.0} ns/op ({iters} iters)", + elapsed.as_nanos() as f64 / iters as f64 + ); +} diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..827b081 --- /dev/null +++ b/deny.toml @@ -0,0 +1,242 @@ +# This template contains all of the possible sections and their default values + +# Note that all fields that take a lint level have these possible values: +# * deny - An error will be produced and the check will fail +# * warn - A warning will be produced, but the check will not fail +# * allow - No warning or error will be produced, though 
in some cases a note +# will be + +# The values provided in this template are the default values that will be used +# when any section or field is not specified in your own configuration + +# Root options + +# The graph table configures how the dependency graph is constructed and thus +# which crates the checks are performed against +[graph] +# If 1 or more target triples (and optionally, target_features) are specified, +# only the specified targets will be checked when running `cargo deny check`. +# This means, if a particular package is only ever used as a target specific +# dependency, such as, for example, the `nix` crate only being used via the +# `target_family = "unix"` configuration, that only having windows targets in +# this list would mean the nix crate, as well as any of its exclusive +# dependencies not shared by any other crates, would be ignored, as the target +# list here is effectively saying which targets you are building for. +targets = [ + # The triple can be any string, but only the target triples built in to + # rustc (as of 1.40) can be checked against actual config expressions + #"x86_64-unknown-linux-musl", + # You can also specify which target_features you promise are enabled for a + # particular target. target_features are currently not validated against + # the actual valid features supported by the target architecture. + #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, +] +# When creating the dependency graph used as the source of truth when checks are +# executed, this field can be used to prune crates from the graph, removing them +# from the view of cargo-deny. This is an extremely heavy hammer, as if a crate +# is pruned from the graph, all of its dependencies will also be pruned unless +# they are connected to another crate in the graph that hasn't been pruned, +# so it should be used with care. 
The identifiers are [Package ID Specifications] +# (https://doc.rust-lang.org/cargo/reference/pkgid-spec.html) +#exclude = [] +# If true, metadata will be collected with `--all-features`. Note that this can't +# be toggled off if true, if you want to conditionally enable `--all-features` it +# is recommended to pass `--all-features` on the cmd line instead +all-features = false +# If true, metadata will be collected with `--no-default-features`. The same +# caveat with `all-features` applies +no-default-features = false +# If set, these feature will be enabled when collecting metadata. If `--features` +# is specified on the cmd line they will take precedence over this option. +#features = [] + +# The output table provides options for how/if diagnostics are outputted +[output] +# When outputting inclusion graphs in diagnostics that include features, this +# option can be used to specify the depth at which feature edges will be added. +# This option is included since the graphs can be quite large and the addition +# of features from the crate(s) to all of the graph roots can be far too verbose. +# This option can be overridden via `--feature-depth` on the cmd line +feature-depth = 1 + +# This section is considered when running `cargo deny check advisories` +# More documentation for the advisories section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html +[advisories] +# The path where the advisory databases are cloned/fetched into +#db-path = "$CARGO_HOME/advisory-dbs" +# The url(s) of the advisory databases to use +#db-urls = ["https://github.com/rustsec/advisory-db"] +# A list of advisory IDs to ignore. Note that ignored advisories will still +# output a note when they are encountered. 
+ignore = [ + #"RUSTSEC-0000-0000", + #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, + #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish + #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, +] +# If this is true, then cargo deny will use the git executable to fetch advisory database. +# If this is false, then it uses a built-in git library. +# Setting this to true can be helpful if you have special authentication requirements that cargo-deny does not support. +# See Git Authentication for more information about setting up git authentication. +#git-fetch-with-cli = true + +# This section is considered when running `cargo deny check licenses` +# More documentation for the licenses section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html +[licenses] +# List of explicitly allowed licenses +# See https://spdx.org/licenses/ for list of possible licenses +# [possible values: any SPDX 3.11 short identifier (+ optional exception)]. +allow = [ + "ISC", + "MIT", + "Apache-2.0", + "Apache-2.0 WITH LLVM-exception", + "BSD-3-Clause", + "Unicode-3.0", +] +# The confidence threshold for detecting a license from license text. +# The higher the value, the more closely the license text must be to the +# canonical license text of a valid SPDX license file. +# [possible values: any between 0.0 and 1.0]. 
+confidence-threshold = 0.8 +# Allow 1 or more licenses on a per-crate basis, so that particular licenses +# aren't accepted for every possible crate as with the normal allow list +exceptions = [ + # Each entry is the crate and version constraint, and its specific allow + # list + #{ allow = ["Zlib"], crate = "adler32" }, +] + +# Some crates don't have (easily) machine readable licensing information, +# adding a clarification entry for it allows you to manually specify the +# licensing information +#[[licenses.clarify]] +# The package spec the clarification applies to +#crate = "ring" +# The SPDX expression for the license requirements of the crate +#expression = "MIT AND ISC AND OpenSSL" +# One or more files in the crate's source used as the "source of truth" for +# the license expression. If the contents match, the clarification will be used +# when running the license check, otherwise the clarification will be ignored +# and the crate will be checked normally, which may produce warnings or errors +# depending on the rest of your configuration +#license-files = [ +# Each entry is a crate relative path, and the (opaque) hash of its contents +#{ path = "LICENSE", hash = 0xbd0eed23 } +#] + +[licenses.private] +# If true, ignores workspace crates that aren't published, or are only +# published to private registries. +# To see how to mark a crate as unpublished (to the official registry), +# visit https://doc.rust-lang.org/cargo/reference/manifest.html#the-publish-field. +ignore = false +# One or more private registries that you might publish crates to, if a crate +# is only published to private registries, and ignore is true, the crate will +# not have its license(s) checked +registries = [ + #"https://sekretz.com/registry +] + +# This section is considered when running `cargo deny check bans`. 
+# More documentation about the 'bans' section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html +[bans] +# Lint level for when multiple versions of the same crate are detected +multiple-versions = "warn" +# Lint level for when a crate version requirement is `*` +wildcards = "allow" +# The graph highlighting used when creating dotgraphs for crates +# with multiple versions +# * lowest-version - The path to the lowest versioned duplicate is highlighted +# * simplest-path - The path to the version with the fewest edges is highlighted +# * all - Both lowest-version and simplest-path are used +highlight = "all" +# The default lint level for `default` features for crates that are members of +# the workspace that is being checked. This can be overridden by allowing/denying +# `default` on a crate-by-crate basis if desired. +workspace-default-features = "allow" +# The default lint level for `default` features for external crates that are not +# members of the workspace. This can be overridden by allowing/denying `default` +# on a crate-by-crate basis if desired. +external-default-features = "allow" +# List of crates that are allowed. Use with care! 
+allow = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, +] +# If true, workspace members are automatically allowed even when using deny-by-default +# This is useful for organizations that want to deny all external dependencies by default +# but allow their own workspace crates without having to explicitly list them +allow-workspace = false +# List of crates to deny +deny = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is banned" }, + # Wrapper crates can optionally be specified to allow the crate when it + # is a direct dependency of the otherwise banned crate + #{ crate = "ansi_term@0.11.0", wrappers = ["this-crate-directly-depends-on-ansi_term"] }, +] + +# List of features to allow/deny +# Each entry the name of a crate and a version range. If version is +# not specified, all versions will be matched. +#[[bans.features]] +#crate = "reqwest" +# Features to not allow +#deny = ["json"] +# Features to allow +#allow = [ +# "rustls", +# "__rustls", +# "__tls", +# "hyper-rustls", +# "rustls", +# "rustls-pemfile", +# "rustls-tls-webpki-roots", +# "tokio-rustls", +# "webpki-roots", +#] +# If true, the allowed features must exactly match the enabled feature set. If +# this is set there is no point setting `deny` +#exact = true + +# Certain crates/versions that will be skipped when doing duplicate detection. +skip = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason why it can't be updated/removed" }, +] +# Similarly to `skip` allows you to skip certain crates during duplicate +# detection. Unlike skip, it also includes the entire tree of transitive +# dependencies starting at the specified crate, up to a certain depth, which is +# by default infinite. 
+skip-tree = [ + #"ansi_term@0.11.0", # will be skipped along with _all_ of its direct and transitive dependencies + #{ crate = "ansi_term@0.11.0", depth = 20 }, +] + +# This section is considered when running `cargo deny check sources`. +# More documentation about the 'sources' section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html +[sources] +# Lint level for what to happen when a crate from a crate registry that is not +# in the allow list is encountered +unknown-registry = "warn" +# Lint level for what to happen when a crate from a git repository that is not +# in the allow list is encountered +unknown-git = "warn" +# List of URLs for allowed crate registries. Defaults to the crates.io index +# if not specified. If it is specified but empty, no registries are allowed. +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +# List of URLs for allowed Git repositories +allow-git = [] + +[sources.allow-org] +# github.com organizations to allow git sources for +github = [] +# gitlab.com organizations to allow git sources for +gitlab = [] +# bitbucket.org organizations to allow git sources for +bitbucket = [] diff --git a/examples/dgram.rs b/examples/dgram.rs new file mode 100644 index 0000000..ab25e78 --- /dev/null +++ b/examples/dgram.rs @@ -0,0 +1,100 @@ +//! Datagram transport example (equivalent to example4.cpp). +//! +//! Creates two nodes and sends datagrams between them +//! using the dgram callback API. +//! +//! Usage: +//! 
cargo run --example dgram + +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::id::NodeId; +use tesseras_dht::nat::NatState; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + let mut node1 = Node::bind(0).expect("bind node1"); + node1.set_nat_state(NatState::Global); + let addr1 = node1.local_addr().unwrap(); + let id1 = *node1.id(); + + let mut node2 = Node::bind(0).expect("bind node2"); + node2.set_nat_state(NatState::Global); + let addr2 = node2.local_addr().unwrap(); + let id2 = *node2.id(); + + println!("Node 1: {} @ {addr1}", node1.id_hex()); + println!("Node 2: {} @ {addr2}", node2.id_hex()); + + // Join node2 to node1 + node2.join("127.0.0.1", addr1.port()).expect("join"); + + // Poll to establish routing + for _ in 0..10 { + node1.poll().ok(); + node2.poll().ok(); + std::thread::sleep(Duration::from_millis(50)); + } + + // Set up dgram callbacks + let received1: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new())); + let received2: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new())); + + let r1 = received1.clone(); + node1.set_dgram_callback(move |data: &[u8], from: &NodeId| { + let msg = String::from_utf8_lossy(data).to_string(); + println!("Node 1 received: '{msg}' from {from:?}"); + r1.lock().unwrap().push(msg); + }); + + let r2 = received2.clone(); + node2.set_dgram_callback(move |data: &[u8], from: &NodeId| { + let msg = String::from_utf8_lossy(data).to_string(); + println!("Node 2 received: '{msg}' from {from:?}"); + r2.lock().unwrap().push(msg); + }); + + // Send datagrams + println!("\n--- Sending datagrams ---"); + node1.send_dgram(b"hello from node1", &id2); + node2.send_dgram(b"hello from node2", &id1); + + // Poll to deliver + for _ in 0..10 { + node1.poll().ok(); + 
node2.poll().ok(); + std::thread::sleep(Duration::from_millis(50)); + } + + // Note: actual dgram delivery requires the full + // send queue → address resolution → send flow. + // This example demonstrates the API; full delivery + // is wired in integration. + + println!("\n--- Summary ---"); + println!( + "Node 1 received {} messages", + received1.lock().unwrap().len() + ); + println!( + "Node 2 received {} messages", + received2.lock().unwrap().len() + ); + println!("Node 1 send queue pending: queued for delivery"); + println!("Node 2 send queue pending: queued for delivery"); +} diff --git a/examples/join.rs b/examples/join.rs new file mode 100644 index 0000000..a478410 --- /dev/null +++ b/examples/join.rs @@ -0,0 +1,65 @@ +//! Basic bootstrap example (equivalent to example1.cpp). +//! +//! Usage: +//! cargo run --example join -- 10000 +//! cargo run --example join -- 10001 127.0.0.1 10000 +//! +//! The first invocation creates a standalone node. +//! The second joins via the first. + +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + let args: Vec<String> = std::env::args().collect(); + if args.len() < 2 { + eprintln!("usage: {} port [host port]", args[0]); + eprintln!(); + eprintln!("example:"); + eprintln!(" $ {} 10000 &", args[0]); + eprintln!(" $ {} 10001 127.0.0.1 10000", args[0]); + std::process::exit(1); + } + + let port: u16 = args[1].parse().expect("invalid port"); + + let mut node = Node::bind(port).expect("bind failed"); + node.set_nat_state(NatState::Global); + + println!("Node {} listening on port {port}", node.id_hex()); + + if args.len() >= 4 { + let dst_host = &args[2]; + let dst_port: u16 = args[3].parse().expect("invalid dst port"); + + match 
node.join(dst_host, dst_port) { + Ok(()) => println!("Join request sent"), + Err(e) => { + eprintln!("Join failed: {e}"); + std::process::exit(1); + } + } + } + + // Event loop + loop { + node.poll().ok(); + std::thread::sleep(Duration::from_millis(100)); + } +} diff --git a/examples/network.rs b/examples/network.rs new file mode 100644 index 0000000..be3e0ef --- /dev/null +++ b/examples/network.rs @@ -0,0 +1,86 @@ +//! Multi-node network example (equivalent to example2.cpp). +//! +//! Creates N nodes on localhost, joins them recursively +//! via the first node, then prints state periodically. +//! +//! Usage: +//! cargo run --example network +//! RUST_LOG=debug cargo run --example network + +use std::time::{Duration, Instant}; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +const NUM_NODES: usize = 20; +const POLL_ROUNDS: usize = 30; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + println!("Creating {NUM_NODES} nodes..."); + let start = Instant::now(); + + // Create bootstrap node + let mut nodes: Vec<Node> = Vec::new(); + let bootstrap = Node::bind(0).expect("bind bootstrap"); + let bootstrap_port = bootstrap.local_addr().unwrap().port(); + println!("Bootstrap: {} @ port {bootstrap_port}", bootstrap.id_hex()); + nodes.push(bootstrap); + + // Create and join remaining nodes + for i in 1..NUM_NODES { + let mut node = Node::bind(0).expect("bind node"); + node.set_nat_state(NatState::Global); + node.join("127.0.0.1", bootstrap_port).expect("join"); + println!("Node {i}: {} joined", &node.id_hex()[..8]); + nodes.push(node); + } + nodes[0].set_nat_state(NatState::Global); + + println!("\nAll {NUM_NODES} nodes created in {:?}", start.elapsed()); + + // Poll all nodes to exchange messages + println!("\nPolling {POLL_ROUNDS} rounds..."); + 
for round in 0..POLL_ROUNDS { + for node in nodes.iter_mut() { + node.poll().ok(); + } + std::thread::sleep(Duration::from_millis(50)); + + if (round + 1) % 10 == 0 { + let sizes: Vec<usize> = + nodes.iter().map(|n| n.routing_table_size()).collect(); + let avg = sizes.iter().sum::<usize>() / sizes.len(); + let max = sizes.iter().max().unwrap(); + println!( + " Round {}: avg routing table = {avg}, max = {max}", + round + 1 + ); + } + } + + // Print final state + println!("\n--- Final state ---"); + for (i, node) in nodes.iter().enumerate() { + println!( + "Node {i}: {} | rt={} peers={} storage={}", + &node.id_hex()[..8], + node.routing_table_size(), + node.peer_count(), + node.storage_count(), + ); + } +} diff --git a/examples/put_get.rs b/examples/put_get.rs new file mode 100644 index 0000000..e2bfaca --- /dev/null +++ b/examples/put_get.rs @@ -0,0 +1,93 @@ +//! DHT put/get example (equivalent to example3.cpp). +//! +//! Creates a small network, stores key-value pairs from +//! one node, and retrieves them from another. +//! +//! Usage: +//! 
cargo run --example put_get + +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +const NUM_NODES: usize = 5; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + // Create network + let mut nodes: Vec<Node> = Vec::new(); + let bootstrap = Node::bind(0).expect("bind"); + let bp = bootstrap.local_addr().unwrap().port(); + nodes.push(bootstrap); + nodes[0].set_nat_state(NatState::Global); + + for _ in 1..NUM_NODES { + let mut n = Node::bind(0).expect("bind"); + n.set_nat_state(NatState::Global); + n.join("127.0.0.1", bp).expect("join"); + nodes.push(n); + } + + // Poll to establish routing + for _ in 0..20 { + for n in nodes.iter_mut() { + n.poll().ok(); + } + std::thread::sleep(Duration::from_millis(50)); + } + + println!("Network ready: {NUM_NODES} nodes"); + + // Node 0 stores several key-value pairs + println!("\n--- Storing values from Node 0 ---"); + for i in 0..5u32 { + let key = format!("key-{i}"); + let val = format!("value-{i}"); + nodes[0].put(key.as_bytes(), val.as_bytes(), 300, false); + println!(" put({key}, {val})"); + } + + // Poll to distribute stores + for _ in 0..20 { + for n in nodes.iter_mut() { + n.poll().ok(); + } + std::thread::sleep(Duration::from_millis(50)); + } + + // Each node retrieves values + println!("\n--- Retrieving values ---"); + for (ni, node) in nodes.iter_mut().enumerate() { + for i in 0..5u32 { + let key = format!("key-{i}"); + let vals = node.get(key.as_bytes()); + let found: Vec<String> = vals + .iter() + .map(|v| String::from_utf8_lossy(v).to_string()) + .collect(); + if !found.is_empty() { + println!(" Node {ni} get({key}) = {:?}", found); + } + } + } + + // Summary + println!("\n--- Storage summary ---"); + for (i, n) in nodes.iter().enumerate() { + println!(" Node {i}: {} 
values stored", n.storage_count()); + } +} diff --git a/examples/rdp.rs b/examples/rdp.rs new file mode 100644 index 0000000..319c779 --- /dev/null +++ b/examples/rdp.rs @@ -0,0 +1,147 @@ +//! RDP reliable transport example (equivalent to example5.cpp). +//! +//! Two nodes: server listens, client connects, sends +//! data, server receives it. +//! +//! Usage: +//! cargo run --example rdp + +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; +use tesseras_dht::rdp::RdpState; + +const RDP_PORT: u16 = 5000; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + // Create two nodes + let mut server = Node::bind(0).expect("bind server"); + server.set_nat_state(NatState::Global); + let server_addr = server.local_addr().unwrap(); + let server_id = *server.id(); + println!("Server: {} @ {server_addr}", server.id_hex()); + + let mut client = Node::bind(0).expect("bind client"); + client.set_nat_state(NatState::Global); + println!("Client: {}", client.id_hex()); + + // Client joins server so they know each other + client.join("127.0.0.1", server_addr.port()).expect("join"); + + // Poll to exchange routing info + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(20)); + } + println!("Client knows {} peers", client.routing_table_size()); + + // Server listens on RDP port + let _listen = server.rdp_listen(RDP_PORT).expect("listen"); + println!("Server listening on RDP port {RDP_PORT}"); + + // Client connects + let desc = client + .rdp_connect(0, &server_id, RDP_PORT) + .expect("connect"); + println!("Client state: {:?}", client.rdp_state(desc).unwrap()); + + // Poll to complete handshake + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + 
std::thread::sleep(Duration::from_millis(20)); + } + + println!( + "Client state after handshake: {:?}", + client.rdp_state(desc).unwrap_or(RdpState::Closed) + ); + + // Send data if connection is open + match client.rdp_state(desc) { + Ok(RdpState::Open) => { + for i in 0..3u16 { + let msg = format!("hello {i}"); + match client.rdp_send(desc, msg.as_bytes()) { + Ok(n) => println!("Sent: '{msg}' ({n} bytes)"), + Err(e) => println!("Send error: {e}"), + } + } + + // Poll to deliver + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(20)); + } + + // Server reads received data + println!("\n--- Server reading ---"); + let server_status = server.rdp_status(); + for s in &server_status { + if s.state == RdpState::Open { + let mut buf = [0u8; 256]; + loop { + match server.rdp_recv(s.sport as i32 + 1, &mut buf) { + Ok(0) => break, + Ok(n) => { + let msg = String::from_utf8_lossy(&buf[..n]); + println!("Server received: '{msg}'"); + } + Err(_) => break, + } + } + } + } + // Try reading from desc 2 (server-side accepted desc) + let mut buf = [0u8; 256]; + for attempt_desc in 1..=5 { + loop { + match server.rdp_recv(attempt_desc, &mut buf) { + Ok(0) => break, + Ok(n) => { + let msg = String::from_utf8_lossy(&buf[..n]); + println!("Server desc={attempt_desc}: '{msg}'"); + } + Err(_) => break, + } + } + } + } + Ok(state) => { + println!("Connection not open: {state:?}"); + } + Err(e) => { + println!("Descriptor error: {e}"); + } + } + + // Show status + println!("\n--- RDP Status ---"); + for s in &client.rdp_status() { + println!(" state={:?} dport={} sport={}", s.state, s.dport, s.sport); + } + + // Cleanup + client.rdp_close(desc); + + println!("\n--- Done ---"); + println!("Server: {server}"); + println!("Client: {client}"); +} diff --git a/examples/remote_get.rs b/examples/remote_get.rs new file mode 100644 index 0000000..60add19 --- /dev/null +++ b/examples/remote_get.rs @@ -0,0 +1,106 @@ +//! 
Remote get example: FIND_VALUE across the network. +//! +//! Node 1 stores a value locally. Node 3 retrieves it +//! via iterative FIND_VALUE, even though it never +//! received a STORE. +//! +//! Usage: +//! cargo run --example remote_get + +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + // Create 3 nodes + let mut node1 = Node::bind(0).expect("bind"); + node1.set_nat_state(NatState::Global); + let port1 = node1.local_addr().unwrap().port(); + + let mut node2 = Node::bind(0).expect("bind"); + node2.set_nat_state(NatState::Global); + node2.join("127.0.0.1", port1).expect("join"); + + let mut node3 = Node::bind(0).expect("bind"); + node3.set_nat_state(NatState::Global); + node3.join("127.0.0.1", port1).expect("join"); + + println!("Node 1: {} (has the value)", &node1.id_hex()[..8]); + println!("Node 2: {} (relay)", &node2.id_hex()[..8]); + println!("Node 3: {} (will search)", &node3.id_hex()[..8]); + + // Let them discover each other + for _ in 0..10 { + node1.poll().ok(); + node2.poll().ok(); + node3.poll().ok(); + std::thread::sleep(Duration::from_millis(20)); + } + + println!( + "\nRouting tables: N1={} N2={} N3={}", + node1.routing_table_size(), + node2.routing_table_size(), + node3.routing_table_size(), + ); + + // Node 1 stores a value locally only (no STORE sent) + println!("\n--- Node 1 stores 'secret-key' ---"); + node1.put(b"secret-key", b"secret-value", 300, false); + + // Verify: Node 3 does NOT have it + assert!( + node3.get(b"secret-key").is_empty(), + "Node 3 should not have the value yet" + ); + println!("Node 3 get('secret-key'): [] (not found)"); + + // Node 3 does get() — triggers FIND_VALUE + println!("\n--- Node 3 searches via FIND_VALUE ---"); + 
let _ = node3.get(b"secret-key"); // starts query + + // Poll to let FIND_VALUE propagate + for _ in 0..15 { + node1.poll().ok(); + node2.poll().ok(); + node3.poll().ok(); + std::thread::sleep(Duration::from_millis(30)); + } + + // Now Node 3 should have cached the value + let result = node3.get(b"secret-key"); + println!( + "Node 3 get('secret-key'): {:?}", + result + .iter() + .map(|v| String::from_utf8_lossy(v).to_string()) + .collect::<Vec<_>>() + ); + + if result.is_empty() { + println!("\n(Value not found — may need more poll rounds)"); + } else { + println!("\nRemote get successful!"); + } + + // Storage summary + println!("\n--- Storage ---"); + println!("Node 1: {} values", node1.storage_count()); + println!("Node 2: {} values", node2.storage_count()); + println!("Node 3: {} values", node3.storage_count()); +} diff --git a/examples/tesserasd.rs b/examples/tesserasd.rs new file mode 100644 index 0000000..81fc3bc --- /dev/null +++ b/examples/tesserasd.rs @@ -0,0 +1,338 @@ +//! tesserasd: daemon with TCP command interface. +//! +//! Manages multiple DHT nodes via a line-oriented TCP +//! protocol. +//! +//! Usage: +//! cargo run --example tesserasd +//! cargo run --example tesserasd -- --host 0.0.0.0 --port 8080 +//! +//! Then connect with: +//! nc localhost 12080 +//! +//! Commands: +//! new,NAME,PORT[,global] Create a node +//! delete,NAME Delete a node +//! set_id,NAME,DATA Set node ID from data +//! join,NAME,HOST,PORT Join network +//! put,NAME,KEY,VALUE,TTL Store key-value +//! get,NAME,KEY Retrieve value +//! dump,NAME Print node state +//! list List all nodes +//! quit Close connection +//! +//! Response codes: +//! 200-205 Success +//! 
400-409 Error + +use std::collections::HashMap; +use std::io::{BufRead, BufReader, Write}; +use std::net::{TcpListener, TcpStream}; +use std::time::Duration; + +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +// Response codes +const OK_NEW: &str = "200"; +const OK_DELETE: &str = "201"; +const OK_JOIN: &str = "202"; +const OK_PUT: &str = "203"; +const OK_GET: &str = "204"; +const OK_SET_ID: &str = "205"; + +const ERR_UNKNOWN: &str = "400"; +const ERR_INVALID: &str = "401"; +const ERR_PORT: &str = "402"; +const ERR_EXISTS: &str = "403"; +const ERR_NO_NODE: &str = "404"; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + let mut host = "127.0.0.1".to_string(); + let mut port: u16 = 12080; + + let args: Vec<String> = std::env::args().collect(); + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--host" if i + 1 < args.len() => { + host = args[i + 1].clone(); + i += 2; + } + "--port" if i + 1 < args.len() => { + port = args[i + 1].parse().expect("invalid port"); + i += 2; + } + _ => i += 1, + } + } + + let listener = + TcpListener::bind(format!("{host}:{port}")).expect("bind TCP"); + listener.set_nonblocking(true).expect("set_nonblocking"); + + println!("tesserasd listening on 127.0.0.1:{port}"); + println!("Connect with: nc localhost {port}"); + + let mut nodes: HashMap<String, Node> = HashMap::new(); + let mut clients: Vec<TcpStream> = Vec::new(); + + loop { + // Accept new TCP connections + if let Ok((stream, addr)) = listener.accept() { + println!("Client connected: {addr}"); + stream.set_nonblocking(true).expect("set_nonblocking"); + let mut s = stream.try_clone().unwrap(); + let _ = s.write_all(b"tesserasd ready\n"); + clients.push(stream); + } + + // Process commands from connected clients + let mut to_remove = 
Vec::new(); + for (i, client) in clients.iter().enumerate() { + let mut reader = BufReader::new(client.try_clone().unwrap()); + let mut line = String::new(); + match reader.read_line(&mut line) { + Ok(0) => to_remove.push(i), // disconnected + Ok(_) => { + let line = line.trim().to_string(); + if !line.is_empty() { + let mut out = client.try_clone().unwrap(); + let response = handle_command(&line, &mut nodes); + let _ = out.write_all(response.as_bytes()); + let _ = out.write_all(b"\n"); + + if line == "quit" { + to_remove.push(i); + } + } + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // No data yet + } + Err(_) => to_remove.push(i), + } + } + // Remove disconnected (reverse order) + to_remove.sort(); + to_remove.dedup(); + for i in to_remove.into_iter().rev() { + println!("Client disconnected"); + clients.remove(i); + } + + // Poll all DHT nodes + for node in nodes.values_mut() { + node.poll().ok(); + } + + std::thread::sleep(Duration::from_millis(10)); + } +} + +fn handle_command(line: &str, nodes: &mut HashMap<String, Node>) -> String { + let parts: Vec<&str> = line.split(',').collect(); + if parts.is_empty() { + return format!("{ERR_UNKNOWN},unknown command"); + } + + match parts[0] { + "new" => cmd_new(&parts, nodes), + "delete" => cmd_delete(&parts, nodes), + "set_id" => cmd_set_id(&parts, nodes), + "join" => cmd_join(&parts, nodes), + "put" => cmd_put(&parts, nodes), + "get" => cmd_get(&parts, nodes), + "dump" => cmd_dump(&parts, nodes), + "list" => cmd_list(nodes), + "quit" => "goodbye".to_string(), + _ => format!("{ERR_UNKNOWN},unknown command: {}", parts[0]), + } +} + +// new,NAME,PORT[,global] +fn cmd_new(parts: &[&str], nodes: &mut HashMap<String, Node>) -> String { + if parts.len() < 3 { + return format!("{ERR_INVALID},usage: new,NAME,PORT[,global]"); + } + let name = parts[1]; + let port: u16 = match parts[2].parse() { + Ok(p) => p, + Err(_) => return format!("{ERR_INVALID},invalid port"), + }; + + if nodes.contains_key(name) { 
+ return format!("{ERR_EXISTS},new,{name},{port},already exists"); + } + + let mut node = match Node::bind(port) { + Ok(n) => n, + Err(e) => return format!("{ERR_PORT},new,{name},{port},{e}"), + }; + + if parts.len() >= 4 && parts[3] == "global" { + node.set_nat_state(NatState::Global); + } + + let id = node.id_hex(); + nodes.insert(name.to_string(), node); + format!("{OK_NEW},new,{name},{port},{id}") +} + +// delete,NAME +fn cmd_delete(parts: &[&str], nodes: &mut HashMap<String, Node>) -> String { + if parts.len() < 2 { + return format!("{ERR_INVALID},usage: delete,NAME"); + } + let name = parts[1]; + if nodes.remove(name).is_some() { + format!("{OK_DELETE},delete,{name}") + } else { + format!("{ERR_NO_NODE},delete,{name},not found") + } +} + +// set_id,NAME,DATA +fn cmd_set_id(parts: &[&str], nodes: &mut HashMap<String, Node>) -> String { + if parts.len() < 3 { + return format!("{ERR_INVALID},usage: set_id,NAME,DATA"); + } + let name = parts[1]; + let data = parts[2]; + match nodes.get_mut(name) { + Some(node) => { + node.set_id(data.as_bytes()); + format!("{OK_SET_ID},set_id,{name},{}", node.id_hex()) + } + None => format!("{ERR_NO_NODE},set_id,{name},not found"), + } +} + +// join,NAME,HOST,PORT +fn cmd_join(parts: &[&str], nodes: &mut HashMap<String, Node>) -> String { + if parts.len() < 4 { + return format!("{ERR_INVALID},usage: join,NAME,HOST,PORT"); + } + let name = parts[1]; + let host = parts[2]; + let port: u16 = match parts[3].parse() { + Ok(p) => p, + Err(_) => return format!("{ERR_INVALID},invalid port"), + }; + match nodes.get_mut(name) { + Some(node) => match node.join(host, port) { + Ok(()) => { + format!("{OK_JOIN},join,{name},{host},{port}") + } + Err(e) => format!("{ERR_INVALID},join,{name},{host},{port},{e}"), + }, + None => { + format!("{ERR_NO_NODE},join,{name},not found") + } + } +} + +// put,NAME,KEY,VALUE,TTL +fn cmd_put(parts: &[&str], nodes: &mut HashMap<String, Node>) -> String { + if parts.len() < 5 { + return 
format!("{ERR_INVALID},usage: put,NAME,KEY,VALUE,TTL"); + } + let name = parts[1]; + let key = parts[2]; + let value = parts[3]; + let ttl: u16 = match parts[4].parse() { + Ok(t) => t, + Err(_) => return format!("{ERR_INVALID},invalid TTL"), + }; + match nodes.get_mut(name) { + Some(node) => { + node.put(key.as_bytes(), value.as_bytes(), ttl, false); + format!("{OK_PUT},put,{name},{key}") + } + None => { + format!("{ERR_NO_NODE},put,{name},not found") + } + } +} + +// get,NAME,KEY +fn cmd_get(parts: &[&str], nodes: &mut HashMap<String, Node>) -> String { + if parts.len() < 3 { + return format!("{ERR_INVALID},usage: get,NAME,KEY"); + } + let name = parts[1]; + let key = parts[2]; + match nodes.get_mut(name) { + Some(node) => { + let vals = node.get(key.as_bytes()); + if vals.is_empty() { + format!("{OK_GET},get,{name},{key},<not found>") + } else { + let values: Vec<String> = vals + .iter() + .map(|v| String::from_utf8_lossy(v).to_string()) + .collect(); + format!("{OK_GET},get,{name},{key},{}", values.join(";")) + } + } + None => { + format!("{ERR_NO_NODE},get,{name},not found") + } + } +} + +// dump,NAME +fn cmd_dump(parts: &[&str], nodes: &HashMap<String, Node>) -> String { + if parts.len() < 2 { + return format!("{ERR_INVALID},usage: dump,NAME"); + } + let name = parts[1]; + match nodes.get(name) { + Some(node) => { + format!( + "{OK_NEW},dump,{name},id={},nat={:?},\ + rt={},peers={},storage={}", + node.id_hex(), + node.nat_state(), + node.routing_table_size(), + node.peer_count(), + node.storage_count(), + ) + } + None => { + format!("{ERR_NO_NODE},dump,{name},not found") + } + } +} + +// list +fn cmd_list(nodes: &HashMap<String, Node>) -> String { + if nodes.is_empty() { + return format!("{OK_NEW},list,<empty>"); + } + let mut lines = Vec::new(); + for (name, node) in nodes { + lines.push(format!( + "{name}={} rt={} storage={}", + &node.id_hex()[..8], + node.routing_table_size(), + node.storage_count(), + )); + } + format!("{OK_NEW},list,{}", lines.join(";")) 
+} diff --git a/examples/two_nodes.rs b/examples/two_nodes.rs new file mode 100644 index 0000000..13565c6 --- /dev/null +++ b/examples/two_nodes.rs @@ -0,0 +1,133 @@ +//! Two-node example: bootstrap, put, get. +//! +//! Creates two Node nodes on localhost. Node 2 joins +//! via Node 1, then Node 1 stores a value and Node 2 +//! retrieves it (via protocol exchange). +//! +//! Run with: +//! cargo run --example two_nodes +//! +//! With debug logging: +//! RUST_LOG=debug cargo run --example two_nodes + +use std::time::Duration; + +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +fn main() { + env_logger::Builder::from_env( + env_logger::Env::default().default_filter_or("info"), + ) + .format(|buf, record| { + use std::io::Write; + writeln!( + buf, + "{} [{}] {}", + record.level(), + record.target(), + record.args() + ) + }) + .init(); + + // ── Create two nodes ──────────────────────────── + + let mut node1 = Node::bind(0).expect("bind node1"); + node1.set_nat_state(NatState::Global); + let addr1 = node1.local_addr().expect("local addr"); + println!("Node 1: {} @ {}", node1.id_hex(), addr1); + + let mut node2 = Node::bind(0).expect("bind node2"); + node2.set_nat_state(NatState::Global); + let addr2 = node2.local_addr().expect("local addr"); + println!("Node 2: {} @ {}", node2.id_hex(), addr2); + + // ── Node 2 joins via Node 1 ───────────────────── + + println!("\n--- Node 2 joining via Node 1 ---"); + node2.join("127.0.0.1", addr1.port()).expect("join"); + + // Poll both nodes a few times to exchange messages + for _ in 0..10 { + node1.poll().ok(); + node2.poll().ok(); + std::thread::sleep(Duration::from_millis(50)); + } + + println!("Node 1 routing table: {} peers", node1.routing_table_size()); + println!("Node 2 routing table: {} peers", node2.routing_table_size()); + + // ── Node 1 stores a value ─────────────────────── + + println!("\n--- Node 1 storing key='hello' ---"); + node1.put(b"hello", b"world", 300, false); + + // Poll to deliver STORE 
messages + for _ in 0..10 { + node1.poll().ok(); + node2.poll().ok(); + std::thread::sleep(Duration::from_millis(50)); + } + + // ── Check storage ─────────────────────────────── + + let vals1 = node1.get(b"hello"); + let vals2 = node2.get(b"hello"); + + println!( + "\nNode 1 get('hello'): {:?}", + vals1 + .iter() + .map(|v| String::from_utf8_lossy(v).to_string()) + .collect::<Vec<_>>() + ); + println!( + "Node 2 get('hello'): {:?}", + vals2 + .iter() + .map(|v| String::from_utf8_lossy(v).to_string()) + .collect::<Vec<_>>() + ); + + // ── Test remote get via FIND_VALUE ──────────── + + println!("\n--- Node 1 storing key='secret' (local only) ---"); + // Store only on Node 1 (no STORE sent because + // we bypass put and go directly to storage) + node1.put(b"remote-key", b"remote-val", 300, false); + // Don't poll — so Node 2 doesn't get the STORE + + // Node 2 tries to get it — should trigger FIND_VALUE + println!( + "Node 2 get('remote-key') before poll: {:?}", + node2.get(b"remote-key") + ); + + // Now poll to let FIND_VALUE exchange happen + for _ in 0..10 { + node1.poll().ok(); + node2.poll().ok(); + std::thread::sleep(Duration::from_millis(50)); + } + + let remote_vals = node2.get(b"remote-key"); + println!( + "Node 2 get('remote-key') after poll: {:?}", + remote_vals + .iter() + .map(|v| String::from_utf8_lossy(v).to_string()) + .collect::<Vec<_>>() + ); + + // ── Print state ───────────────────────────────── + + println!("\n--- Node 1 state ---"); + node1.print_state(); + println!("\n--- Node 2 state ---"); + node2.print_state(); + + println!("\n--- Done ---"); + println!("Node 1: {node1}"); + println!("Node 2: {node2}"); +} diff --git a/fuzz/fuzz_parse.rs b/fuzz/fuzz_parse.rs new file mode 100644 index 0000000..0efdf42 --- /dev/null +++ b/fuzz/fuzz_parse.rs @@ -0,0 +1,87 @@ +//! Fuzz targets for message parsers. +//! +//! Run with: cargo +nightly fuzz run fuzz_parse +//! +//! Requires: cargo install cargo-fuzz +//! +//! 
These targets verify that no input can cause a panic, +//! buffer overflow, or undefined behavior in the parsers. + +// Note: this file is a reference for cargo-fuzz targets. +// To use, create a fuzz/Cargo.toml and fuzz_targets/ +// directory per cargo-fuzz conventions. The actual fuzz +// harnesses are: + +#[cfg(test)] +mod tests { + /// Fuzz MsgHeader::parse with random bytes. + #[test] + fn fuzz_header_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 128]; + tesseras_dht::sys::random_bytes(&mut buf); + // Should never panic + let _ = tesseras_dht::wire::MsgHeader::parse(&buf); + } + } + + /// Fuzz msg::parse_store with random bytes. + #[test] + fn fuzz_store_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 256]; + tesseras_dht::sys::random_bytes(&mut buf); + let _ = tesseras_dht::msg::parse_store(&buf); + } + } + + /// Fuzz msg::parse_find_node with random bytes. + #[test] + fn fuzz_find_node_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 128]; + tesseras_dht::sys::random_bytes(&mut buf); + let _ = tesseras_dht::msg::parse_find_node(&buf); + } + } + + /// Fuzz msg::parse_find_value with random bytes. + #[test] + fn fuzz_find_value_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 256]; + tesseras_dht::sys::random_bytes(&mut buf); + let _ = tesseras_dht::msg::parse_find_value(&buf); + } + } + + /// Fuzz rdp::parse_rdp_wire with random bytes. + #[test] + fn fuzz_rdp_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 128]; + tesseras_dht::sys::random_bytes(&mut buf); + let _ = tesseras_dht::rdp::parse_rdp_wire(&buf); + } + } + + /// Fuzz dgram::parse_fragment with random bytes. + #[test] + fn fuzz_fragment_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 64]; + tesseras_dht::sys::random_bytes(&mut buf); + let _ = tesseras_dht::dgram::parse_fragment(&buf); + } + } + + /// Fuzz msg::parse_find_value_reply with random bytes. 
+ #[test] + fn fuzz_find_value_reply_parse() { + for _ in 0..10_000 { + let mut buf = [0u8; 256]; + tesseras_dht::sys::random_bytes(&mut buf); + let _ = tesseras_dht::msg::parse_find_value_reply(&buf); + } + } +} diff --git a/src/advertise.rs b/src/advertise.rs new file mode 100644 index 0000000..b415b5b --- /dev/null +++ b/src/advertise.rs @@ -0,0 +1,173 @@ +//! Local address advertisement. +//! +//! Allows a node to announce its address to peers so +//! they can update their routing tables with the correct +//! endpoint. +//! +//! Used after NAT detection to inform peers of the +//! node's externally-visible address. + +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use crate::id::NodeId; + +// ── Constants ──────────────────────────────────────── + +/// TTL for advertisements. +pub const ADVERTISE_TTL: Duration = Duration::from_secs(300); + +/// Timeout for a single advertisement attempt. +pub const ADVERTISE_TIMEOUT: Duration = Duration::from_secs(2); + +/// Interval between refresh cycles. +pub const ADVERTISE_REFRESH_INTERVAL: Duration = Duration::from_secs(100); + +/// A pending outgoing advertisement. +#[derive(Debug)] +struct PendingAd { + sent_at: Instant, +} + +/// A received advertisement from a peer. +#[derive(Debug)] +struct ReceivedAd { + received_at: Instant, +} + +/// Address advertisement manager. +pub struct Advertise { + /// Pending outgoing advertisements by nonce. + pending: HashMap<u32, PendingAd>, + + /// Received advertisements by peer ID. + received: HashMap<NodeId, ReceivedAd>, +} + +impl Advertise { + pub fn new(_local_id: NodeId) -> Self { + Self { + pending: HashMap::new(), + received: HashMap::new(), + } + } + + /// Start an advertisement to a peer. + /// + /// Returns the nonce to include in the message. 
+ pub fn start_advertise(&mut self, nonce: u32) -> u32 { + self.pending.insert( + nonce, + PendingAd { + sent_at: Instant::now(), + }, + ); + nonce + } + + /// Handle an advertisement reply (our ad was accepted). + /// + /// Returns `true` if the nonce matched a pending ad. + pub fn recv_reply(&mut self, nonce: u32) -> bool { + self.pending.remove(&nonce).is_some() + } + + /// Handle an incoming advertisement from a peer. + /// + /// Records that this peer has advertised to us. + pub fn recv_advertise(&mut self, peer_id: NodeId) { + self.received.insert( + peer_id, + ReceivedAd { + received_at: Instant::now(), + }, + ); + } + + /// Check if a peer has advertised to us recently. + pub fn has_advertised(&self, peer_id: &NodeId) -> bool { + self.received + .get(peer_id) + .map(|ad| ad.received_at.elapsed() < ADVERTISE_TTL) + .unwrap_or(false) + } + + /// Remove expired pending ads and stale received ads. + pub fn refresh(&mut self) { + self.pending + .retain(|_, ad| ad.sent_at.elapsed() < ADVERTISE_TIMEOUT); + self.received + .retain(|_, ad| ad.received_at.elapsed() < ADVERTISE_TTL); + } + + /// Number of pending outgoing advertisements. + pub fn pending_count(&self) -> usize { + self.pending.len() + } + + /// Number of active received advertisements. 
+ pub fn received_count(&self) -> usize { + self.received + .values() + .filter(|ad| ad.received_at.elapsed() < ADVERTISE_TTL) + .count() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn advertise_and_reply() { + let mut adv = Advertise::new(NodeId::from_bytes([0x01; 32])); + adv.start_advertise(42); + assert_eq!(adv.pending_count(), 1); + + assert!(adv.recv_reply(42)); + assert_eq!(adv.pending_count(), 0); + } + + #[test] + fn unknown_reply_ignored() { + let mut adv = Advertise::new(NodeId::from_bytes([0x01; 32])); + assert!(!adv.recv_reply(999)); + } + + #[test] + fn recv_advertisement() { + let mut adv = Advertise::new(NodeId::from_bytes([0x01; 32])); + let peer = NodeId::from_bytes([0x02; 32]); + + assert!(!adv.has_advertised(&peer)); + adv.recv_advertise(peer); + assert!(adv.has_advertised(&peer)); + assert_eq!(adv.received_count(), 1); + } + + #[test] + fn refresh_clears_stale() { + let mut adv = Advertise::new(NodeId::from_bytes([0x01; 32])); + + // Insert already-expired pending ad + adv.pending.insert( + 1, + PendingAd { + sent_at: Instant::now() - Duration::from_secs(10), + }, + ); + + // Insert already-expired received ad + let peer = NodeId::from_bytes([0x02; 32]); + adv.received.insert( + peer, + ReceivedAd { + received_at: Instant::now() - Duration::from_secs(600), + }, + ); + + adv.refresh(); + assert_eq!(adv.pending_count(), 0); + assert_eq!(adv.received_count(), 0); + } +} diff --git a/src/banlist.rs b/src/banlist.rs new file mode 100644 index 0000000..f01e31a --- /dev/null +++ b/src/banlist.rs @@ -0,0 +1,207 @@ +//! Ban list for misbehaving peers. +//! +//! Tracks failure counts per peer. After exceeding a +//! threshold, the peer is temporarily banned. Bans +//! expire automatically after a configurable duration. + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +/// Default number of failures before banning a peer. 
+const DEFAULT_BAN_THRESHOLD: u32 = 3; + +/// Default ban duration (3 hours, matching gonode). +const DEFAULT_BAN_DURATION: Duration = Duration::from_secs(3 * 3600); + +/// Tracks failures and bans for peers. +pub struct BanList { + /// Failure counts per address. + failures: HashMap<SocketAddr, FailureEntry>, + /// Active bans: address → expiry time. + bans: HashMap<SocketAddr, Instant>, + /// Number of failures before a ban is applied. + threshold: u32, + /// How long a ban lasts. + ban_duration: Duration, +} + +struct FailureEntry { + count: u32, + last_failure: Instant, +} + +impl BanList { + pub fn new() -> Self { + Self { + failures: HashMap::new(), + bans: HashMap::new(), + threshold: DEFAULT_BAN_THRESHOLD, + ban_duration: DEFAULT_BAN_DURATION, + } + } + + /// Set the failure threshold before banning. + pub fn set_threshold(&mut self, threshold: u32) { + self.threshold = threshold; + } + + /// Set the ban duration. + pub fn set_ban_duration(&mut self, duration: Duration) { + self.ban_duration = duration; + } + + /// Check if a peer is currently banned. + pub fn is_banned(&self, addr: &SocketAddr) -> bool { + if let Some(expiry) = self.bans.get(addr) { + Instant::now() < *expiry + } else { + false + } + } + + /// Record a failure for a peer. Returns true if the + /// peer was just banned (crossed the threshold). + pub fn record_failure(&mut self, addr: SocketAddr) -> bool { + let entry = self.failures.entry(addr).or_insert(FailureEntry { + count: 0, + last_failure: Instant::now(), + }); + entry.count += 1; + entry.last_failure = Instant::now(); + + if entry.count >= self.threshold { + let expiry = Instant::now() + self.ban_duration; + self.bans.insert(addr, expiry); + self.failures.remove(&addr); + log::info!( + "Banned peer {addr} for {}s", + self.ban_duration.as_secs() + ); + true + } else { + false + } + } + + /// Clear failure count for a peer (e.g. after a + /// successful interaction). 
+ pub fn record_success(&mut self, addr: &SocketAddr) { + self.failures.remove(addr); + } + + /// Remove expired bans and stale failure entries. + /// Call periodically from the event loop. + pub fn cleanup(&mut self) { + let now = Instant::now(); + self.bans.retain(|_, expiry| now < *expiry); + + // Clear failure entries older than ban_duration + // (stale failures shouldn't accumulate forever) + self.failures + .retain(|_, e| e.last_failure.elapsed() < self.ban_duration); + } + + /// Number of currently active bans. + pub fn ban_count(&self) -> usize { + self.bans + .iter() + .filter(|(_, e)| Instant::now() < **e) + .count() + } + + /// Number of peers with recorded failures. + pub fn failure_count(&self) -> usize { + self.failures.len() + } +} + +impl Default for BanList { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn addr(port: u16) -> SocketAddr { + SocketAddr::from(([127, 0, 0, 1], port)) + } + + #[test] + fn not_banned_initially() { + let bl = BanList::new(); + assert!(!bl.is_banned(&addr(1000))); + } + + #[test] + fn ban_after_threshold() { + let mut bl = BanList::new(); + let a = addr(1000); + assert!(!bl.record_failure(a)); + assert!(!bl.record_failure(a)); + assert!(bl.record_failure(a)); // 3rd failure → banned + assert!(bl.is_banned(&a)); + } + + #[test] + fn success_clears_failures() { + let mut bl = BanList::new(); + let a = addr(1000); + bl.record_failure(a); + bl.record_failure(a); + bl.record_success(&a); + // Failures cleared, next failure starts over + assert!(!bl.record_failure(a)); + assert!(!bl.is_banned(&a)); + } + + #[test] + fn ban_expires() { + let mut bl = BanList::new(); + bl.set_ban_duration(Duration::from_millis(1)); + let a = addr(1000); + bl.record_failure(a); + bl.record_failure(a); + bl.record_failure(a); + assert!(bl.is_banned(&a)); + std::thread::sleep(Duration::from_millis(5)); + assert!(!bl.is_banned(&a)); + } + + #[test] + fn cleanup_removes_expired() { + let mut bl = 
//! Node configuration.
//!
//! All tunable parameters in one place. Passed to
//! `Tessera::bind_with_config()`.

use std::time::Duration;

/// Configuration for a Tessera node.
#[derive(Debug, Clone)]
pub struct Config {
    /// Maximum entries per k-bucket (default: 20).
    pub bucket_size: usize,

    /// Number of closest nodes returned in lookups
    /// (default: 10).
    pub num_find_node: usize,

    /// Maximum parallel queries per lookup (default: 6).
    pub max_query: usize,

    /// Single RPC query timeout (default: 3s).
    pub query_timeout: Duration,

    /// Maximum iterative query duration (default: 30s).
    pub max_query_duration: Duration,

    /// Data restore interval (default: 120s).
    pub restore_interval: Duration,

    /// Bucket refresh interval (default: 60s).
    pub refresh_interval: Duration,

    /// Maintain (mask_bit exploration) interval
    /// (default: 120s).
    pub maintain_interval: Duration,

    /// Default value TTL in seconds (default: 300).
    /// Max 65535 (~18 hours). For longer TTLs, use
    /// periodic republish.
    pub default_ttl: u16,

    /// Maximum value size in bytes (default: 65536).
    pub max_value_size: usize,

    /// Rate limiter: messages per second per IP
    /// (default: 50).
    pub rate_limit: f64,

    /// Rate limiter: burst capacity (default: 100).
    pub rate_burst: u32,

    /// Maximum nodes per /24 subnet (default: 2).
    pub max_per_subnet: usize,

    /// Enable DTUN (NAT traversal) (default: true).
    pub enable_dtun: bool,

    /// Require Ed25519 signature on all packets
    /// (default: true). Set to false only for testing.
    pub require_signatures: bool,

    /// Ban threshold: failures before banning a peer
    /// (default: 3).
    pub ban_threshold: u32,

    /// Ban duration in seconds (default: 10800 = 3h).
    pub ban_duration_secs: u64,

    /// Node activity check interval (default: 120s).
    /// Proactively pings routing table peers to detect
    /// failures early.
    pub activity_check_interval: Duration,

    /// Store retry interval (default: 30s). How often
    /// to sweep for timed-out stores and retry them.
    pub store_retry_interval: Duration,
}

impl Default for Config {
    /// Conservative defaults, mirroring the per-field docs
    /// on [`Config`].
    fn default() -> Self {
        Self {
            bucket_size: 20,
            num_find_node: 10,
            max_query: 6,
            query_timeout: Duration::from_secs(3),
            max_query_duration: Duration::from_secs(30),
            restore_interval: Duration::from_secs(120),
            refresh_interval: Duration::from_secs(60),
            maintain_interval: Duration::from_secs(120),
            default_ttl: 300,
            max_value_size: 65536,
            rate_limit: 50.0,
            rate_burst: 100,
            max_per_subnet: 2,
            enable_dtun: true,
            require_signatures: true,
            ban_threshold: 3,
            ban_duration_secs: 10800,
            activity_check_interval: Duration::from_secs(120),
            store_retry_interval: Duration::from_secs(30),
        }
    }
}
+ pub fn pastebin() -> Self { + Self { + default_ttl: 65535, // ~18h, use republish for longer + max_value_size: 1_048_576, + require_signatures: true, + ..Default::default() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_values() { + let c = Config::default(); + assert_eq!(c.bucket_size, 20); + assert_eq!(c.default_ttl, 300); + assert!(c.require_signatures); + } + + #[test] + fn pastebin_preset() { + let c = Config::pastebin(); + assert_eq!(c.default_ttl, 65535); + assert_eq!(c.max_value_size, 1_048_576); + assert!(c.require_signatures); + } +} diff --git a/src/crypto.rs b/src/crypto.rs new file mode 100644 index 0000000..10587ec --- /dev/null +++ b/src/crypto.rs @@ -0,0 +1,172 @@ +//! Ed25519 identity and packet signing. +//! +//! Each node has an Ed25519 keypair: +//! - **Private key** (32 bytes): never leaves the node. +//! Used to sign every outgoing packet. +//! - **Public key** (32 bytes): shared with peers. +//! Used to verify incoming packets. +//! - **NodeId** = public key (32 bytes). Direct 1:1 +//! binding, no hashing. + +use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey}; + +use crate::id::NodeId; + +/// Ed25519 signature size (64 bytes). +pub const SIGNATURE_SIZE: usize = 64; + +/// Ed25519 public key size (32 bytes). +pub const PUBLIC_KEY_SIZE: usize = 32; + +/// A node's cryptographic identity. +/// +/// Contains the Ed25519 keypair and the derived NodeId. +/// The NodeId is the public key directly (32 bytes) +/// deterministically bound to the keypair. +pub struct Identity { + signing_key: SigningKey, + verifying_key: VerifyingKey, + node_id: NodeId, +} + +impl Identity { + /// Generate a new random identity. + pub fn generate() -> Self { + let mut seed = [0u8; 32]; + crate::sys::random_bytes(&mut seed); + Self::from_seed(seed) + } + + /// Create an identity from a 32-byte seed. + /// + /// Deterministic: same seed → same keypair → same + /// NodeId. 
+ pub fn from_seed(seed: [u8; 32]) -> Self { + let signing_key = SigningKey::from_bytes(&seed); + let verifying_key = signing_key.verifying_key(); + let node_id = NodeId::from_bytes(*verifying_key.as_bytes()); + Self { + signing_key, + verifying_key, + node_id, + } + } + + /// The node's 256-bit ID (= public key). + pub fn node_id(&self) -> &NodeId { + &self.node_id + } + + /// The Ed25519 public key (32 bytes). + pub fn public_key(&self) -> &[u8; PUBLIC_KEY_SIZE] { + self.verifying_key.as_bytes() + } + + /// Sign data with the private key. + /// + /// Returns a 64-byte Ed25519 signature. + pub fn sign(&self, data: &[u8]) -> [u8; SIGNATURE_SIZE] { + let sig = self.signing_key.sign(data); + sig.to_bytes() + } + + /// Verify a signature against a public key. + /// + /// This is a static method — used to verify packets + /// from other nodes using their public key. + pub fn verify( + public_key: &[u8; PUBLIC_KEY_SIZE], + data: &[u8], + signature: &[u8], + ) -> bool { + // Length check is not timing-sensitive (length is + // public). ed25519_dalek::verify() is constant-time. 
+ if signature.len() != SIGNATURE_SIZE { + return false; + } + let Ok(vk) = VerifyingKey::from_bytes(public_key) else { + return false; + }; + let mut sig_bytes = [0u8; SIGNATURE_SIZE]; + sig_bytes.copy_from_slice(signature); + let sig = Signature::from_bytes(&sig_bytes); + vk.verify(data, &sig).is_ok() + } +} + +impl std::fmt::Debug for Identity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Identity({})", self.node_id) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generate_unique() { + let a = Identity::generate(); + let b = Identity::generate(); + assert_ne!(a.node_id(), b.node_id()); + } + + #[test] + fn from_seed_deterministic() { + let seed = [0x42u8; 32]; + let a = Identity::from_seed(seed); + let b = Identity::from_seed(seed); + assert_eq!(a.node_id(), b.node_id()); + assert_eq!(a.public_key(), b.public_key()); + } + + #[test] + fn node_id_is_pubkey() { + let id = Identity::generate(); + let expected = NodeId::from_bytes(*id.public_key()); + assert_eq!(*id.node_id(), expected); + } + + #[test] + fn sign_verify() { + let id = Identity::generate(); + let data = b"hello world"; + let sig = id.sign(data); + + assert!(Identity::verify(id.public_key(), data, &sig)); + } + + #[test] + fn verify_wrong_data() { + let id = Identity::generate(); + let sig = id.sign(b"correct"); + + assert!(!Identity::verify(id.public_key(), b"wrong", &sig)); + } + + #[test] + fn verify_wrong_key() { + let id1 = Identity::generate(); + let id2 = Identity::generate(); + let sig = id1.sign(b"data"); + + assert!(!Identity::verify(id2.public_key(), b"data", &sig)); + } + + #[test] + fn verify_truncated_sig() { + let id = Identity::generate(); + assert!(!Identity::verify( + id.public_key(), + b"data", + &[0u8; 10] // too short + )); + } + + #[test] + fn signature_size() { + let id = Identity::generate(); + let sig = id.sign(b"test"); + assert_eq!(sig.len(), SIGNATURE_SIZE); + } +} diff --git a/src/dgram.rs b/src/dgram.rs new file 
mode 100644 index 0000000..6fca81b --- /dev/null +++ b/src/dgram.rs @@ -0,0 +1,346 @@ +//! Datagram transport with automatic fragmentation. +//! +//! Messages larger +//! than `MAX_DGRAM_PAYLOAD` (896 bytes) are split into +//! fragments and reassembled at the destination. +//! +//! Routing is automatic: if the local node is behind +//! symmetric NAT, datagrams are sent through the proxy. + +use std::collections::HashMap; +use std::time::{Duration, Instant}; + +use crate::id::NodeId; + +/// Maximum payload per datagram fragment (bytes). +/// +/// Ensures each fragment +/// fits in a single UDP packet with protocol overhead. +pub const MAX_DGRAM_PAYLOAD: usize = 896; + +/// Timeout for incomplete fragment reassembly. +pub const REASSEMBLY_TIMEOUT: Duration = Duration::from_secs(10); + +/// A queued datagram waiting for address resolution. +#[derive(Debug, Clone)] +pub struct QueuedDgram { + pub data: Vec<u8>, + pub src: NodeId, + pub queued_at: Instant, +} + +// ── Fragmentation ─────────────────────────────────── + +/// Fragment header: 4 bytes. +/// +/// - `total` (u16 BE): total number of fragments. +/// - `index` (u16 BE): this fragment's index (0-based). +const FRAG_HEADER_SIZE: usize = 4; + +/// Split a message into fragments, each with a 4-byte +/// header and up to `MAX_DGRAM_PAYLOAD` bytes of data. +pub fn fragment(data: &[u8]) -> Vec<Vec<u8>> { + if data.is_empty() { + return vec![make_fragment(1, 0, &[])]; + } + + let chunk_size = MAX_DGRAM_PAYLOAD; + let total = data.len().div_ceil(chunk_size); + let total = total as u16; + + data.chunks(chunk_size) + .enumerate() + .map(|(i, chunk)| make_fragment(total, i as u16, chunk)) + .collect() +} + +fn make_fragment(total: u16, index: u16, data: &[u8]) -> Vec<u8> { + let mut buf = Vec::with_capacity(FRAG_HEADER_SIZE + data.len()); + buf.extend_from_slice(&total.to_be_bytes()); + buf.extend_from_slice(&index.to_be_bytes()); + buf.extend_from_slice(data); + buf +} + +/// Parse a fragment header. 
+/// +/// Returns `(total_fragments, fragment_index, payload)`. +pub fn parse_fragment(buf: &[u8]) -> Option<(u16, u16, &[u8])> { + if buf.len() < FRAG_HEADER_SIZE { + return None; + } + let total = u16::from_be_bytes([buf[0], buf[1]]); + let index = u16::from_be_bytes([buf[2], buf[3]]); + Some((total, index, &buf[FRAG_HEADER_SIZE..])) +} + +// ── Reassembly ────────────────────────────────────── + +/// State for reassembling fragments from a single sender. +#[derive(Debug)] +struct ReassemblyState { + total: u16, + fragments: HashMap<u16, Vec<u8>>, + started_at: Instant, +} + +/// Fragment reassembler. +/// +/// Tracks incoming fragments per sender and produces +/// complete messages when all fragments arrive. +pub struct Reassembler { + pending: HashMap<NodeId, ReassemblyState>, +} + +impl Reassembler { + pub fn new() -> Self { + Self { + pending: HashMap::new(), + } + } + + /// Feed a fragment. Returns the complete message if + /// all fragments have arrived. + pub fn feed( + &mut self, + sender: NodeId, + total: u16, + index: u16, + data: Vec<u8>, + ) -> Option<Vec<u8>> { + // S2-8: cap fragments to prevent memory bomb + const MAX_FRAGMENTS: u16 = 10; + if total == 0 { + log::debug!("Dgram: dropping fragment with total=0"); + return None; + } + if total > MAX_FRAGMENTS { + log::debug!( + "Dgram: dropping fragment with total={total} > {MAX_FRAGMENTS}" + ); + return None; + } + + // Single fragment → no reassembly needed + if total == 1 && index == 0 { + self.pending.remove(&sender); + return Some(data); + } + + let state = + self.pending + .entry(sender) + .or_insert_with(|| ReassemblyState { + total, + fragments: HashMap::new(), + started_at: Instant::now(), + }); + + // Total mismatch → reset + if state.total != total { + *state = ReassemblyState { + total, + fragments: HashMap::new(), + started_at: Instant::now(), + }; + } + + if index < total { + state.fragments.insert(index, data); + } + + if state.fragments.len() == total as usize { + // All fragments 
received → reassemble + let mut result = Vec::new(); + for i in 0..total { + if let Some(frag) = state.fragments.get(&i) { + result.extend_from_slice(frag); + } else { + // Should not happen, but guard + self.pending.remove(&sender); + return None; + } + } + self.pending.remove(&sender); + Some(result) + } else { + None + } + } + + /// Remove incomplete reassembly state older than the + /// timeout. + pub fn expire(&mut self) { + self.pending + .retain(|_, state| state.started_at.elapsed() < REASSEMBLY_TIMEOUT); + } + + /// Number of pending incomplete messages. + pub fn pending_count(&self) -> usize { + self.pending.len() + } +} + +impl Default for Reassembler { + fn default() -> Self { + Self::new() + } +} + +// ── Send queue ────────────────────────────────────── + +/// Queue of datagrams waiting for address resolution. +pub struct SendQueue { + queues: HashMap<NodeId, Vec<QueuedDgram>>, +} + +impl SendQueue { + pub fn new() -> Self { + Self { + queues: HashMap::new(), + } + } + + /// Enqueue a datagram for a destination. + pub fn push(&mut self, dst: NodeId, data: Vec<u8>, src: NodeId) { + self.queues.entry(dst).or_default().push(QueuedDgram { + data, + src, + queued_at: Instant::now(), + }); + } + + /// Drain the queue for a destination. + pub fn drain(&mut self, dst: &NodeId) -> Vec<QueuedDgram> { + self.queues.remove(dst).unwrap_or_default() + } + + /// Check if there's a pending queue for a destination. + pub fn has_pending(&self, dst: &NodeId) -> bool { + self.queues.contains_key(dst) + } + + /// Remove stale queued messages (>10s). 
+ pub fn expire(&mut self) { + self.queues.retain(|_, q| { + q.retain(|d| d.queued_at.elapsed() < REASSEMBLY_TIMEOUT); + !q.is_empty() + }); + } +} + +impl Default for SendQueue { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Fragmentation tests ───────────────────────── + + #[test] + fn small_message_single_fragment() { + let frags = fragment(b"hello"); + assert_eq!(frags.len(), 1); + let (total, idx, data) = parse_fragment(&frags[0]).unwrap(); + assert_eq!(total, 1); + assert_eq!(idx, 0); + assert_eq!(data, b"hello"); + } + + #[test] + fn large_message_multiple_fragments() { + let msg = vec![0xAB; MAX_DGRAM_PAYLOAD * 3 + 100]; + let frags = fragment(&msg); + assert_eq!(frags.len(), 4); + + for (i, frag) in frags.iter().enumerate() { + let (total, idx, _) = parse_fragment(frag).unwrap(); + assert_eq!(total, 4); + assert_eq!(idx, i as u16); + } + } + + #[test] + fn empty_message() { + let frags = fragment(b""); + assert_eq!(frags.len(), 1); + let (total, idx, data) = parse_fragment(&frags[0]).unwrap(); + assert_eq!(total, 1); + assert_eq!(idx, 0); + assert!(data.is_empty()); + } + + #[test] + fn fragment_roundtrip() { + let msg = vec![0x42; MAX_DGRAM_PAYLOAD * 2 + 50]; + let frags = fragment(&msg); + + let mut reassembled = Vec::new(); + for frag in &frags { + let (_, _, data) = parse_fragment(frag).unwrap(); + reassembled.extend_from_slice(data); + } + assert_eq!(reassembled, msg); + } + + // ── Reassembler tests ─────────────────────────── + + #[test] + fn reassemble_single() { + let mut r = Reassembler::new(); + let sender = NodeId::from_bytes([0x01; 32]); + let result = r.feed(sender, 1, 0, b"hello".to_vec()); + assert_eq!(result.unwrap(), b"hello"); + assert_eq!(r.pending_count(), 0); + } + + #[test] + fn reassemble_multi() { + let mut r = Reassembler::new(); + let sender = NodeId::from_bytes([0x01; 32]); + + assert!(r.feed(sender, 3, 0, b"aaa".to_vec()).is_none()); + assert!(r.feed(sender, 3, 2, 
b"ccc".to_vec()).is_none()); + let result = r.feed(sender, 3, 1, b"bbb".to_vec()); + + assert_eq!(result.unwrap(), b"aaabbbccc"); + assert_eq!(r.pending_count(), 0); + } + + #[test] + fn reassemble_out_of_order() { + let mut r = Reassembler::new(); + let sender = NodeId::from_bytes([0x01; 32]); + + // Fragments arrive in reverse order + assert!(r.feed(sender, 2, 1, b"world".to_vec()).is_none()); + let result = r.feed(sender, 2, 0, b"hello".to_vec()); + assert_eq!(result.unwrap(), b"helloworld"); + } + + // ── SendQueue tests ───────────────────────────── + + #[test] + fn send_queue_push_drain() { + let mut q = SendQueue::new(); + let dst = NodeId::from_bytes([0x01; 32]); + let src = NodeId::from_bytes([0x02; 32]); + + q.push(dst, b"msg1".to_vec(), src); + q.push(dst, b"msg2".to_vec(), src); + + assert!(q.has_pending(&dst)); + let msgs = q.drain(&dst); + assert_eq!(msgs.len(), 2); + assert!(!q.has_pending(&dst)); + } + + #[test] + fn parse_truncated() { + assert!(parse_fragment(&[0, 1]).is_none()); + } +} diff --git a/src/dht.rs b/src/dht.rs new file mode 100644 index 0000000..72ec019 --- /dev/null +++ b/src/dht.rs @@ -0,0 +1,1028 @@ +//! Kademlia DHT logic: store, find_node, find_value. +//! +//! Uses explicit state machines for iterative queries +//! instead of nested callbacks. + +use std::collections::{HashMap, HashSet}; +use std::time::{Duration, Instant}; + +use crate::id::NodeId; +use crate::peers::PeerInfo; +use crate::routing::NUM_FIND_NODE; + +// ── Constants ──────────────────────────────────────── + +/// Max parallel queries per lookup. +pub const MAX_QUERY: usize = 6; + +/// Timeout for a single RPC query (seconds). +pub const QUERY_TIMEOUT: Duration = Duration::from_secs(3); + +/// Interval between data restore cycles. +pub const RESTORE_INTERVAL: Duration = Duration::from_secs(120); + +/// Slow maintenance timer (refresh + restore). +pub const SLOW_TIMER_INTERVAL: Duration = Duration::from_secs(600); + +/// Fast maintenance timer (expire + sweep). 
+pub const FAST_TIMER_INTERVAL: Duration = Duration::from_secs(60); + +/// Number of original replicas for a put. +pub const ORIGINAL_PUT_NUM: i32 = 3; + +/// Timeout waiting for all values in find_value. +pub const RECVD_VALUE_TIMEOUT: Duration = Duration::from_secs(3); + +/// RDP port for store operations. +pub const RDP_STORE_PORT: u16 = 100; + +/// RDP port for get operations. +pub const RDP_GET_PORT: u16 = 101; + +/// RDP connection timeout. +pub const RDP_TIMEOUT: Duration = Duration::from_secs(30); + +// ── Stored data ───────────────────────────────────── + +/// A single stored value with metadata. +#[derive(Debug, Clone)] +pub struct StoredValue { + pub key: Vec<u8>, + pub value: Vec<u8>, + pub id: NodeId, + pub source: NodeId, + pub ttl: u16, + pub stored_at: Instant, + pub is_unique: bool, + + /// Number of original puts remaining. Starts at + /// `ORIGINAL_PUT_NUM` for originator, 0 for replicas. + pub original: i32, + + /// Set of node IDs that already received this value + /// during restore, to avoid duplicate sends. + pub recvd: HashSet<NodeId>, + + /// Monotonic version timestamp. Newer versions + /// (higher value) replace older ones from the same + /// source. Prevents stale replicas from overwriting + /// fresh data during restore/republish. + pub version: u64, +} + +/// Generate a monotonic version number based on the +/// current time (milliseconds since epoch, truncated +/// to u64). Sufficient for conflict resolution — two +/// stores in the same millisecond will have the same +/// version (tie-break: last write wins). +pub fn now_version() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0) +} + +impl StoredValue { + /// Check if this value has expired. + pub fn is_expired(&self) -> bool { + self.stored_at.elapsed() >= Duration::from_secs(self.ttl as u64) + } + + /// Remaining TTL in seconds. 
+ pub fn remaining_ttl(&self) -> u16 { + let elapsed = self.stored_at.elapsed().as_secs(); + if elapsed >= self.ttl as u64 { + 0 + } else { + (self.ttl as u64 - elapsed) as u16 + } + } +} + +// ── Storage container ─────────────────────────────── + +/// Key for the two-level storage map: +/// first level is the target NodeId (SHA1 of key), +/// second level is the raw key bytes. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct StorageKey { + raw: Vec<u8>, +} + +/// Container for stored DHT values. +/// +/// Maps target_id -> raw_key -> set of values. +/// Maps target_id -> raw_key -> set of values. +pub struct DhtStorage { + data: HashMap<NodeId, HashMap<StorageKey, Vec<StoredValue>>>, + /// Maximum number of stored values (0 = unlimited). + max_entries: usize, +} + +/// Default maximum storage entries. +const DEFAULT_MAX_STORAGE: usize = 65536; + +impl DhtStorage { + pub fn new() -> Self { + Self { + data: HashMap::new(), + max_entries: DEFAULT_MAX_STORAGE, + } + } + + /// Set the maximum number of stored values. + pub fn set_max_entries(&mut self, max: usize) { + self.max_entries = max; + } + + /// Store a value. Handles `is_unique` semantics and + /// version-based conflict resolution. + /// + /// If a value with the same key from the same source + /// already exists with a higher version, the store is + /// rejected (prevents stale replicas from overwriting + /// fresh data). 
+ pub fn store(&mut self, val: StoredValue) { + // Enforce storage limit + if self.max_entries > 0 && self.len() >= self.max_entries { + log::warn!( + "Storage full ({} entries), dropping store", + self.max_entries + ); + return; + } + + let key = StorageKey { + raw: val.key.clone(), + }; + let entry = + self.data.entry(val.id).or_default().entry(key).or_default(); + + if val.is_unique { + // Unique: replace existing from same source, + // but only if version is not older + if let Some(pos) = entry.iter().position(|v| v.source == val.source) + { + if val.version > 0 + && entry[pos].version > 0 + && val.version < entry[pos].version + { + log::debug!( + "Rejecting stale unique store: v{} < v{}", + val.version, + entry[pos].version, + ); + return; + } + entry[pos] = val; + } else if entry.is_empty() || !entry[0].is_unique { + entry.clear(); + entry.push(val); + } + return; + } + + // Non-unique: update if same value exists (any + // source), or append. Check version on update. + if let Some(pos) = entry.iter().position(|v| v.value == val.value) { + if val.version > 0 + && entry[pos].version > 0 + && val.version < entry[pos].version + { + log::debug!( + "Rejecting stale store: v{} < v{}", + val.version, + entry[pos].version, + ); + return; + } + entry[pos].ttl = val.ttl; + entry[pos].stored_at = val.stored_at; + entry[pos].version = val.version; + } else { + // Don't add if existing data is unique + if entry.len() == 1 && entry[0].is_unique { + return; + } + entry.push(val); + } + } + + /// Remove a specific value. + pub fn remove(&mut self, id: &NodeId, raw_key: &[u8]) { + let key = StorageKey { + raw: raw_key.to_vec(), + }; + if let Some(inner) = self.data.get_mut(id) { + inner.remove(&key); + if inner.is_empty() { + self.data.remove(id); + } + } + } + + /// Get all values for a target ID and key. 
+    pub fn get(&self, id: &NodeId, raw_key: &[u8]) -> Vec<StoredValue> {
+        let key = StorageKey {
+            raw: raw_key.to_vec(),
+        };
+        // Expired values are filtered out here but remain
+        // in the map until the next `expire()` sweep.
+        self.data
+            .get(id)
+            .and_then(|inner| inner.get(&key))
+            .map(|vals| {
+                vals.iter().filter(|v| !v.is_expired()).cloned().collect()
+            })
+            .unwrap_or_default()
+    }
+
+    /// Remove all expired values, pruning now-empty inner
+    /// maps and outer entries along the way.
+    pub fn expire(&mut self) {
+        self.data.retain(|_, inner| {
+            inner.retain(|_, vals| {
+                vals.retain(|v| !v.is_expired());
+                !vals.is_empty()
+            });
+            !inner.is_empty()
+        });
+    }
+
+    /// Iterate over all stored values (for restore).
+    /// Clones every live value — O(n) in stored data.
+    pub fn all_values(&self) -> Vec<StoredValue> {
+        self.data
+            .values()
+            .flat_map(|inner| inner.values())
+            .flat_map(|vals| vals.iter())
+            .filter(|v| !v.is_expired())
+            .cloned()
+            .collect()
+    }
+
+    /// Decrement the `original` counter for a value.
+    /// Returns the new count, or -1 if not found
+    /// (in which case the value is inserted).
+    pub fn dec_original(&mut self, val: &StoredValue) -> i32 {
+        let key = StorageKey {
+            raw: val.key.clone(),
+        };
+        if let Some(inner) = self.data.get_mut(&val.id) {
+            if let Some(vals) = inner.get_mut(&key) {
+                // Match on both value bytes and source so
+                // replicas from other nodes are untouched.
+                if let Some(existing) = vals
+                    .iter_mut()
+                    .find(|v| v.value == val.value && v.source == val.source)
+                {
+                    // Saturate at 0 — never goes negative.
+                    if existing.original > 0 {
+                        existing.original -= 1;
+                    }
+                    return existing.original;
+                }
+            }
+        }
+
+        // Not found: insert it
+        self.store(val.clone());
+        -1
+    }
+
+    /// Mark a node as having received a stored value
+    /// (for restore deduplication).
+    pub fn mark_received(&mut self, val: &StoredValue, node_id: NodeId) {
+        let key = StorageKey {
+            raw: val.key.clone(),
+        };
+        if let Some(inner) = self.data.get_mut(&val.id) {
+            if let Some(vals) = inner.get_mut(&key) {
+                if let Some(existing) =
+                    vals.iter_mut().find(|v| v.value == val.value)
+                {
+                    existing.recvd.insert(node_id);
+                }
+            }
+        }
+    }
+
+    /// Total number of stored values.
+    pub fn len(&self) -> usize {
+        // Counted on demand by walking both map levels.
+        self.data
+            .values()
+            .flat_map(|inner| inner.values())
+            .map(|vals| vals.len())
+            .sum()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+}
+
+impl Default for DhtStorage {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// ── Iterative query state machine ───────────────────
+
+/// Phase of an iterative Kademlia query.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum QueryPhase {
+    /// Actively searching: sending queries and processing
+    /// replies.
+    Searching,
+
+    /// No closer nodes found in last round: query has
+    /// converged.
+    Converged,
+
+    /// Query complete: results are ready.
+    Done,
+}
+
+/// Maximum duration for an iterative query before
+/// returning best-effort results.
+pub const MAX_QUERY_DURATION: Duration = Duration::from_secs(30);
+
+/// State of an iterative FIND_NODE or FIND_VALUE query.
+///
+/// Uses an explicit state machine for iterative lookup.
+pub struct IterativeQuery {
+    /// ID being looked up.
+    pub target: NodeId,
+    /// Best candidates found so far, sorted by XOR
+    /// distance to `target`.
+    pub closest: Vec<PeerInfo>,
+    /// Nodes that have already answered (or been given
+    /// up on).
+    pub queried: HashSet<NodeId>,
+    /// In-flight queries and when each was sent.
+    pub pending: HashMap<NodeId, Instant>,
+    /// Current phase of the state machine.
+    pub phase: QueryPhase,
+    /// True for FIND_VALUE, false for FIND_NODE.
+    pub is_find_value: bool,
+    /// Raw key (FIND_VALUE only; empty for FIND_NODE).
+    pub key: Vec<u8>,
+    /// Values received so far (FIND_VALUE only).
+    pub values: Vec<Vec<u8>>,
+    /// Nonce correlating requests and replies.
+    pub nonce: u32,
+    /// When the query started (drives the overall
+    /// `MAX_QUERY_DURATION` deadline).
+    pub started_at: Instant,
+
+    /// Number of iterative rounds completed. Each round
+    /// is a batch of queries followed by reply processing.
+    /// Measures the "depth" of the lookup — useful for
+    /// diagnosing network topology and routing efficiency.
+    pub hops: u32,
+}
+
+impl IterativeQuery {
+    /// Create a new FIND_NODE query.
+    pub fn find_node(target: NodeId, nonce: u32) -> Self {
+        Self {
+            target,
+            closest: Vec::new(),
+            queried: HashSet::new(),
+            pending: HashMap::new(),
+            phase: QueryPhase::Searching,
+            is_find_value: false,
+            key: Vec::new(),
+            values: Vec::new(),
+            nonce,
+            started_at: Instant::now(),
+            hops: 0,
+        }
+    }
+
+    /// Create a new FIND_VALUE query.
+    pub fn find_value(target: NodeId, key: Vec<u8>, nonce: u32) -> Self {
+        Self {
+            target,
+            closest: Vec::new(),
+            queried: HashSet::new(),
+            pending: HashMap::new(),
+            phase: QueryPhase::Searching,
+            is_find_value: true,
+            key,
+            values: Vec::new(),
+            nonce,
+            started_at: Instant::now(),
+            hops: 0,
+        }
+    }
+
+    /// Select the next batch of peers to query.
+    ///
+    /// Returns up to `MAX_QUERY` un-queried peers from
+    /// `closest`, sorted by XOR distance to target.
+    pub fn next_to_query(&self) -> Vec<PeerInfo> {
+        // Leave room for queries that are still in flight.
+        let max = MAX_QUERY.saturating_sub(self.pending.len());
+        self.closest
+            .iter()
+            .filter(|p| {
+                !self.queried.contains(&p.id)
+                    && !self.pending.contains_key(&p.id)
+            })
+            .take(max)
+            .cloned()
+            .collect()
+    }
+
+    /// Process a reply: merge new nodes into closest,
+    /// remove from pending, detect convergence.
+    /// Increments the hop counter for route length
+    /// tracking.
+    pub fn process_reply(&mut self, from: &NodeId, nodes: Vec<PeerInfo>) {
+        self.pending.remove(from);
+        self.queried.insert(*from);
+        self.hops += 1;
+
+        // Distance of the current best candidate, used
+        // below to decide whether this reply improved
+        // the front of the list.
+        let prev_best =
+            self.closest.first().map(|p| self.target.distance(&p.id));
+
+        // Merge new nodes
+        for node in nodes {
+            if node.id != self.target
+                && !self.closest.iter().any(|c| c.id == node.id)
+            {
+                self.closest.push(node);
+            }
+        }
+
+        // Sort by XOR distance
+        self.closest.sort_by(|a, b| {
+            let da = self.target.distance(&a.id);
+            let db = self.target.distance(&b.id);
+            da.cmp(&db)
+        });
+
+        // Trim to NUM_FIND_NODE
+        self.closest.truncate(NUM_FIND_NODE);
+
+        // Check convergence: did the closest node change?
+        // Converge only once nothing is in flight, so a
+        // late reply can still improve the result.
+        let new_best =
+            self.closest.first().map(|p| self.target.distance(&p.id));
+
+        if prev_best == new_best && self.pending.is_empty() {
+            self.phase = QueryPhase::Converged;
+        }
+    }
+
+    /// Process a value reply (for FIND_VALUE).
+    /// Transitions straight to `Done`.
+    pub fn process_value(&mut self, value: Vec<u8>) {
+        self.values.push(value);
+        self.phase = QueryPhase::Done;
+    }
+
+    /// Mark a peer as timed out. Finishes the query when
+    /// nothing is in flight and no candidates remain.
+    pub fn timeout(&mut self, id: &NodeId) {
+        self.pending.remove(id);
+        if self.pending.is_empty() && self.next_to_query().is_empty() {
+            self.phase = QueryPhase::Done;
+        }
+    }
+
+    /// Expire all pending queries that have exceeded
+    /// the timeout.
+    pub fn expire_pending(&mut self) {
+        // Collect first: `timeout` needs `&mut self`.
+        let expired: Vec<NodeId> = self
+            .pending
+            .iter()
+            .filter(|(_, sent_at)| sent_at.elapsed() >= QUERY_TIMEOUT)
+            .map(|(id, _)| *id)
+            .collect();
+
+        for id in expired {
+            self.timeout(&id);
+        }
+    }
+
+    /// Check if the query is complete (converged,
+    /// finished, or timed out).
+    pub fn is_done(&self) -> bool {
+        self.phase == QueryPhase::Done
+            || self.phase == QueryPhase::Converged
+            || self.started_at.elapsed() >= MAX_QUERY_DURATION
+    }
+}
+
+// ── Maintenance: mask_bit exploration ───────────────
+
+/// Systematic exploration of the 256-bit ID space.
+///
+/// Generates target IDs for find_node queries that probe
+/// different regions of the network, populating distant
+/// k-buckets that would otherwise remain empty.
+///
+/// Used by both DHT and DTUN maintenance.
+pub struct MaskBitExplorer {
+    local_id: NodeId,
+    /// 1-indexed bit position from the MSB; cycles
+    /// through 1..=20 (see `next_targets`).
+    mask_bit: usize,
+}
+
+impl MaskBitExplorer {
+    pub fn new(local_id: NodeId) -> Self {
+        Self {
+            local_id,
+            mask_bit: 1,
+        }
+    }
+
+    /// Generate the next pair of exploration targets.
+    ///
+    /// Each call produces two targets by clearing specific
+    /// bits in the local ID, then advances by 2 bits.
+    /// After bit 20, resets to 1.
+    pub fn next_targets(&mut self) -> (NodeId, NodeId) {
+        let id_bytes = *self.local_id.as_bytes();
+
+        let t1 = Self::clear_bit(id_bytes, self.mask_bit);
+        let t2 = Self::clear_bit(id_bytes, self.mask_bit + 1);
+
+        self.mask_bit += 2;
+        if self.mask_bit > 20 {
+            self.mask_bit = 1;
+        }
+
+        (t1, t2)
+    }
+
+    /// Current mask_bit position (for testing).
+    pub fn position(&self) -> usize {
+        self.mask_bit
+    }
+
+    /// Clear one bit (1-indexed from the MSB) of an ID.
+    /// Out-of-range positions return the ID unchanged.
+    fn clear_bit(
+        mut bytes: [u8; crate::id::ID_LEN],
+        bit_from_msb: usize,
+    ) -> NodeId {
+        if bit_from_msb == 0 || bit_from_msb > crate::id::ID_BITS {
+            return NodeId::from_bytes(bytes);
+        }
+        let pos = bit_from_msb - 1; // 0-indexed
+        let byte_idx = pos / 8;
+        let bit_idx = 7 - (pos % 8);
+        bytes[byte_idx] &= !(1 << bit_idx);
+        NodeId::from_bytes(bytes)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::net::SocketAddr;
+
+    /// Helper: a peer whose 32-byte ID is `byte` repeated,
+    /// listening on 127.0.0.1:`port`.
+    fn make_peer(byte: u8, port: u16) -> PeerInfo {
+        PeerInfo::new(
+            NodeId::from_bytes([byte; 32]),
+            SocketAddr::from(([127, 0, 0, 1], port)),
+        )
+    }
+
+    // ── DhtStorage tests ────────────────────────────
+
+    #[test]
+    fn storage_store_and_get() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"test-key");
+        let val = StoredValue {
+            key: b"test-key".to_vec(),
+            value: b"hello".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x01; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: false,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(val);
+        let got = s.get(&id, b"test-key");
+        assert_eq!(got.len(), 1);
+        assert_eq!(got[0].value, b"hello");
+    }
+
+    #[test]
+    fn storage_unique_replaces() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"uk");
+        let src = NodeId::from_bytes([0x01; 32]);
+
+        let v1 = StoredValue {
+            key: b"uk".to_vec(),
+            value: b"v1".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(v1);
+
+        let v2 = StoredValue {
+            key: b"uk".to_vec(),
+            value: b"v2".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(v2);
+
+        let got = s.get(&id, b"uk");
+        assert_eq!(got.len(), 1);
+        assert_eq!(got[0].value, b"v2");
+    }
+
+    #[test]
+    fn storage_unique_rejects_other_source() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"uk");
+
+        let v1 = StoredValue {
+            key: b"uk".to_vec(),
+            value: b"v1".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x01; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(v1);
+
+        let v2 = StoredValue {
+            key: b"uk".to_vec(),
+            value: b"v2".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x02; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(v2);
+
+        let got = s.get(&id, b"uk");
+        assert_eq!(got.len(), 1);
+        assert_eq!(got[0].value, b"v1");
+    }
+
+    #[test]
+    fn storage_multiple_non_unique() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"k");
+
+        for i in 0..3u8 {
+            s.store(StoredValue {
+                key: b"k".to_vec(),
+                value: vec![i],
+                id,
+                source: NodeId::from_bytes([i; 32]),
+                ttl: 300,
+                stored_at: Instant::now(),
+                is_unique: false,
+                original: 0,
+                recvd: HashSet::new(),
+                version: 0,
+            });
+        }
+
+        assert_eq!(s.get(&id, b"k").len(), 3);
+    }
+
+    #[test]
+    fn storage_remove() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"k");
+        s.store(StoredValue {
+            key: b"k".to_vec(),
+            value: b"v".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x01; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: false,
+            original: 0,
+            recvd: HashSet::new(),
+            version: 0,
+        });
+        s.remove(&id, b"k");
+        assert!(s.get(&id, b"k").is_empty());
+        assert!(s.is_empty());
+    }
+
+    #[test]
+    fn storage_dec_original() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"k");
+        let val = StoredValue {
+            key: b"k".to_vec(),
+            value: b"v".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x01; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: false,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(val.clone());
+
+        assert_eq!(s.dec_original(&val), 2);
+        assert_eq!(s.dec_original(&val), 1);
+        assert_eq!(s.dec_original(&val), 0);
+        assert_eq!(s.dec_original(&val), 0); // stays at 0
+    }
+
+    #[test]
+    fn storage_mark_received() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"k");
+        let val = StoredValue {
+            key: b"k".to_vec(),
+            value: b"v".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x01; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: false,
+            original: 0,
+            recvd: HashSet::new(),
+            version: 0,
+        };
+        s.store(val.clone());
+
+        let node = NodeId::from_bytes([0x42; 32]);
+        s.mark_received(&val, node);
+
+        let got = s.get(&id, b"k");
+        assert!(got[0].recvd.contains(&node));
+    }
+
+    // ── IterativeQuery tests ────────────────────────
+
+    #[test]
+    fn query_process_reply_sorts() {
+        let target = NodeId::from_bytes([0x00; 32]);
+        let mut q = IterativeQuery::find_node(target, 1);
+
+        // Simulate: we have node 0xFF pending
+        let far = NodeId::from_bytes([0xFF; 32]);
+        q.pending.insert(far, Instant::now());
+
+        // Reply with closer nodes
+        let nodes = vec![
+            make_peer(0x10, 3000),
+            make_peer(0x01, 3001),
+            make_peer(0x05, 3002),
+        ];
+        q.process_reply(&far, nodes);
+
+        // Should be sorted by distance from target
+        assert_eq!(q.closest[0].id, NodeId::from_bytes([0x01; 32]));
+        assert_eq!(q.closest[1].id, NodeId::from_bytes([0x05; 32]));
+        assert_eq!(q.closest[2].id, NodeId::from_bytes([0x10; 32]));
+    }
+
+    #[test]
+    fn query_converges_when_no_closer() {
+        let target = NodeId::from_bytes([0x00; 32]);
+        let mut q = IterativeQuery::find_node(target, 1);
+
+        // Add initial closest
+        q.closest.push(make_peer(0x01, 3000));
+
+        // Simulate reply with no closer nodes
+        let from = NodeId::from_bytes([0x01; 32]);
+        q.pending.insert(from, Instant::now());
+        q.process_reply(&from, vec![make_peer(0x02, 3001)]);
+
+        // 0x01 is still closest, pending is empty -> converged
+        assert_eq!(q.phase, QueryPhase::Converged);
+    }
+
+    #[test]
+    fn query_find_value_done_on_value() {
+        let target = NodeId::from_bytes([0x00; 32]);
+        let mut q = IterativeQuery::find_value(target, b"key".to_vec(), 1);
+
+        q.process_value(b"found-it".to_vec());
+        assert!(q.is_done());
+        assert_eq!(q.values, vec![b"found-it".to_vec()]);
+    }
+
+    // ── MaskBitExplorer tests ───────────────────────
+
+    #[test]
+    fn mask_bit_cycles() {
+        let id = NodeId::from_bytes([0xFF; 32]);
+        let mut explorer = MaskBitExplorer::new(id);
+
+        assert_eq!(explorer.position(), 1);
+        explorer.next_targets();
+        assert_eq!(explorer.position(), 3);
+        explorer.next_targets();
+        assert_eq!(explorer.position(), 5);
+
+        // Run through full cycle
+        for _ in 0..8 {
+            explorer.next_targets();
+        }
+
+        // 5 + 8*2 = 21 > 20, so reset to 1
+        assert_eq!(explorer.position(), 1);
+    }
+
+    #[test]
+    fn mask_bit_produces_different_targets() {
+        let id = NodeId::from_bytes([0xFF; 32]);
+        let mut explorer = MaskBitExplorer::new(id);
+
+        let (t1, t2) = explorer.next_targets();
+        assert_ne!(t1, t2);
+        assert_ne!(t1, id);
+        assert_ne!(t2, id);
+    }
+
+    // ── Content versioning tests ──────────────────
+
+    #[test]
+    fn version_rejects_stale_unique() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"vk");
+        let src = NodeId::from_bytes([0x01; 32]);
+
+        // Store version 100
+        s.store(StoredValue {
+            key: b"vk".to_vec(),
+            value: b"new".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 100,
+        });
+
+        // Try to store older version 50 — should be rejected
+        s.store(StoredValue {
+            key: b"vk".to_vec(),
+            value: b"old".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 50,
+        });
+
+        let got = s.get(&id, b"vk");
+        assert_eq!(got.len(), 1);
+        assert_eq!(got[0].value, b"new");
+        assert_eq!(got[0].version, 100);
+    }
+
+    #[test]
+    fn version_accepts_newer() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"vk");
+        let src = NodeId::from_bytes([0x01; 32]);
+
+        s.store(StoredValue {
+            key: b"vk".to_vec(),
+            value: b"old".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 50,
+        });
+
+        s.store(StoredValue {
+            key: b"vk".to_vec(),
+            value: b"new".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 100,
+        });
+
+        let got = s.get(&id, b"vk");
+        assert_eq!(got[0].value, b"new");
+    }
+
+    #[test]
+    fn version_zero_always_accepted() {
+        // version=0 means "no versioning" — always accepted
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"vk");
+        let src = NodeId::from_bytes([0x01; 32]);
+
+        s.store(StoredValue {
+            key: b"vk".to_vec(),
+            value: b"v1".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        });
+
+        s.store(StoredValue {
+            key: b"vk".to_vec(),
+            value: b"v2".to_vec(),
+            id,
+            source: src,
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: true,
+            original: 3,
+            recvd: HashSet::new(),
+            version: 0,
+        });
+
+        let got = s.get(&id, b"vk");
+        assert_eq!(got[0].value, b"v2");
+    }
+
+    #[test]
+    fn version_rejects_stale_non_unique() {
+        let mut s = DhtStorage::new();
+        let id = NodeId::from_key(b"nk");
+
+        // Store value with version 100
+        s.store(StoredValue {
+            key: b"nk".to_vec(),
+            value: b"same".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x01; 32]),
+            ttl: 300,
+            stored_at: Instant::now(),
+            is_unique: false,
+            original: 0,
+            recvd: HashSet::new(),
+            version: 100,
+        });
+
+        // Same value with older version — rejected
+        s.store(StoredValue {
+            key: b"nk".to_vec(),
+            value: b"same".to_vec(),
+            id,
+            source: NodeId::from_bytes([0x02; 32]),
+            ttl: 600,
+            stored_at: Instant::now(),
+            is_unique: false,
+            original: 0,
+            recvd: HashSet::new(),
+            version: 50,
+        });
+
+        let got = s.get(&id, b"nk");
+        assert_eq!(got.len(), 1);
+        assert_eq!(got[0].version, 100);
+        assert_eq!(got[0].ttl, 300); // TTL not updated
+    }
+
+    // ── Route length tests ──────────────────────
+
+    #[test]
+    fn query_hops_increment() {
+        let target = NodeId::from_bytes([0x00; 32]);
+        let mut q = IterativeQuery::find_node(target, 1);
+        assert_eq!(q.hops, 0);
+
+        let from = NodeId::from_bytes([0xFF; 32]);
+        q.pending.insert(from, Instant::now());
+        q.process_reply(&from, vec![make_peer(0x10, 3000)]);
+        assert_eq!(q.hops, 1);
+
+        let from2 = NodeId::from_bytes([0x10; 32]);
+        q.pending.insert(from2, Instant::now());
+        q.process_reply(&from2, vec![make_peer(0x05, 3001)]);
+        assert_eq!(q.hops, 2);
+    }
+
+    #[test]
+    fn now_version_monotonic() {
+        let v1 = now_version();
+        std::thread::sleep(std::time::Duration::from_millis(2));
+        let v2 = now_version();
+        assert!(v2 >= v1);
+    }
+}
diff --git a/src/dtun.rs b/src/dtun.rs
new file mode 100644
index 0000000..367262c
--- /dev/null
+++ b/src/dtun.rs
@@ -0,0 +1,436 @@
+//! Distributed tunnel for NAT traversal (DTUN).
+//!
+//! Maintains a separate routing table used to register
+//! NAT'd nodes and resolve their addresses for
+//! hole-punching.
+//!
+//! ## How it works
+//!
+//! 1. A node behind NAT **registers** itself with the
+//!    k-closest global nodes (find_node + register).
+//! 2. When another node wants to reach a NAT'd node, it
+//!    does a **find_value** in the DTUN table to discover
+//!    which global node holds the registration.
+//! 3. It then sends a **request** to that global node,
+//!    which forwards the request to the NAT'd node,
+//!    causing it to send a packet that punches a hole.
+
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::time::{Duration, Instant};
+
+use crate::id::NodeId;
+use crate::peers::PeerInfo;
+use crate::routing::RoutingTable;
+
+// ── Constants ────────────────────────────────────────
+
+/// k-closest nodes for DTUN lookups.
+pub const DTUN_NUM_FIND_NODE: usize = 10;
+
+/// Max parallel queries.
+pub const DTUN_MAX_QUERY: usize = 6;
+
+/// Query timeout.
+pub const DTUN_QUERY_TIMEOUT: Duration = Duration::from_secs(2);
+
+/// Retries for reachability requests.
+pub const DTUN_REQUEST_RETRY: usize = 2;
+
+/// Request timeout.
+pub const DTUN_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
+
+/// TTL for node registrations.
+pub const DTUN_REGISTERED_TTL: Duration = Duration::from_secs(300);
+
+/// Refresh timer interval.
+pub const DTUN_TIMER_INTERVAL: Duration = Duration::from_secs(30);
+
+/// Maintenance interval (mask_bit exploration).
+pub const DTUN_MAINTAIN_INTERVAL: Duration = Duration::from_secs(120);
+
+// ── Registration record ─────────────────────────────
+
+/// A node registered in the DTUN overlay.
+#[derive(Debug, Clone)]
+pub struct Registration {
+    /// The registered node's address.
+    pub addr: SocketAddr,
+
+    /// Session identifier (must match for updates).
+    pub session: u32,
+
+    /// When this registration was created/refreshed.
+    pub registered_at: Instant,
+}
+
+impl Registration {
+    /// Check if this registration has expired (older
+    /// than `DTUN_REGISTERED_TTL`).
+    pub fn is_expired(&self) -> bool {
+        self.registered_at.elapsed() >= DTUN_REGISTERED_TTL
+    }
+}
+
+// ── Request state ───────────────────────────────────
+
+/// State of a reachability request.
+#[derive(Debug)]
+pub struct RequestState {
+    /// Target node we're trying to reach.
+    pub target: NodeId,
+
+    /// When the request was sent.
+    pub sent_at: Instant,
+
+    /// Remaining retries.
+    pub retries: usize,
+
+    /// Whether find_value completed.
+    pub found: bool,
+
+    /// The intermediary node (if found).
+    pub intermediary: Option<PeerInfo>,
+}
+
+// ── DTUN ────────────────────────────────────────────
+
+/// Distributed tunnel for NAT traversal.
+pub struct Dtun {
+    /// Separate routing table for the DTUN overlay.
+    table: RoutingTable,
+
+    /// Nodes registered through us (we're their
+    /// "registration server"). Capped at 1000.
+    registered: HashMap<NodeId, Registration>,
+
+    /// Our own registration session.
+    register_session: u32,
+
+    /// Whether we're currently registering.
+    registering: bool,
+
+    /// Last time registration was refreshed.
+    last_registered: Instant,
+
+    /// Pending reachability requests by nonce.
+    requests: HashMap<u32, RequestState>,
+
+    /// Whether DTUN is enabled.
+    enabled: bool,
+
+    /// Local node ID.
+    id: NodeId,
+
+    /// Mask bit for maintain() exploration.
+    mask_bit: usize,
+
+    /// Last maintain() call.
+    last_maintain: Instant,
+}
+
+impl Dtun {
+    pub fn new(id: NodeId) -> Self {
+        Self {
+            table: RoutingTable::new(id),
+            registered: HashMap::new(),
+            register_session: 0,
+            registering: false,
+            last_registered: Instant::now(),
+            requests: HashMap::new(),
+            enabled: true,
+            id,
+            mask_bit: 1,
+            last_maintain: Instant::now(),
+        }
+    }
+
+    /// Access the DTUN routing table.
+    pub fn table(&self) -> &RoutingTable {
+        &self.table
+    }
+
+    /// Mutable access to the routing table.
+    pub fn table_mut(&mut self) -> &mut RoutingTable {
+        &mut self.table
+    }
+
+    /// Whether DTUN is enabled.
+    pub fn is_enabled(&self) -> bool {
+        self.enabled
+    }
+
+    /// Enable or disable DTUN.
+    pub fn set_enabled(&mut self, enabled: bool) {
+        self.enabled = enabled;
+    }
+
+    /// Current registration session.
+    pub fn session(&self) -> u32 {
+        self.register_session
+    }
+
+    // ── Registration (server side) ──────────────────
+
+    /// Register a remote node (we act as their
+    /// registration server).
+    ///
+    /// Returns `true` if the registration was accepted.
+    pub fn register_node(
+        &mut self,
+        id: NodeId,
+        addr: SocketAddr,
+        session: u32,
+    ) -> bool {
+        const MAX_REGISTRATIONS: usize = 1000;
+
+        // A still-live registration may only be refreshed
+        // with the same session token; a different session
+        // is only accepted once the old one has expired.
+        if let Some(existing) = self.registered.get(&id) {
+            if existing.session != session && !existing.is_expired() {
+                return false;
+            }
+        }
+
+        // At capacity and this is a new node: first sweep
+        // out expired registrations, which may free a slot.
+        // Without this, a table full of stale entries would
+        // refuse fresh registrations until the next
+        // periodic expire_registrations() pass.
+        if self.registered.len() >= MAX_REGISTRATIONS
+            && !self.registered.contains_key(&id)
+        {
+            self.expire_registrations();
+            if self.registered.len() >= MAX_REGISTRATIONS {
+                log::debug!("DTUN: registration limit reached");
+                return false;
+            }
+        }
+
+        self.registered.insert(
+            id,
+            Registration {
+                addr,
+                session,
+                registered_at: Instant::now(),
+            },
+        );
+        true
+    }
+
+    /// Look up a registered node by ID. Expired entries
+    /// are treated as absent.
+    pub fn get_registered(&self, id: &NodeId) -> Option<&Registration> {
+        self.registered.get(id).filter(|r| !r.is_expired())
+    }
+
+    /// Remove expired registrations.
+    pub fn expire_registrations(&mut self) {
+        self.registered.retain(|_, r| !r.is_expired());
+    }
+
+    /// Number of active (non-expired) registrations.
+    pub fn registration_count(&self) -> usize {
+        self.registered.values().filter(|r| !r.is_expired()).count()
+    }
+
+    // ── Registration (client side) ──────────────────
+
+    /// Prepare to register ourselves. Increments the
+    /// session and returns (session, closest_nodes) for
+    /// the caller to send DTUN_REGISTER messages.
+    pub fn prepare_register(&mut self) -> (u32, Vec<PeerInfo>) {
+        self.register_session = self.register_session.wrapping_add(1);
+        self.registering = true;
+        let closest = self.table.closest(&self.id, DTUN_NUM_FIND_NODE);
+        (self.register_session, closest)
+    }
+
+    /// Mark registration as complete.
+    pub fn registration_done(&mut self) {
+        self.registering = false;
+        self.last_registered = Instant::now();
+    }
+
+    /// Check if re-registration is needed (half the TTL
+    /// has elapsed and no registration is in flight).
+    pub fn needs_reregister(&self) -> bool {
+        !self.registering
+            && self.last_registered.elapsed() >= DTUN_REGISTERED_TTL / 2
+    }
+
+    // ── Reachability requests ───────────────────────
+
+    /// Start a reachability request for a target node.
+ pub fn start_request(&mut self, nonce: u32, target: NodeId) { + self.requests.insert( + nonce, + RequestState { + target, + sent_at: Instant::now(), + retries: DTUN_REQUEST_RETRY, + found: false, + intermediary: None, + }, + ); + } + + /// Record that find_value found the intermediary for + /// a request. + pub fn request_found(&mut self, nonce: u32, intermediary: PeerInfo) { + if let Some(req) = self.requests.get_mut(&nonce) { + req.found = true; + req.intermediary = Some(intermediary); + } + } + + /// Get the intermediary for a pending request. + pub fn get_request(&self, nonce: &u32) -> Option<&RequestState> { + self.requests.get(nonce) + } + + /// Remove a completed or timed-out request. + pub fn remove_request(&mut self, nonce: &u32) -> Option<RequestState> { + self.requests.remove(nonce) + } + + /// Expire timed-out requests. + pub fn expire_requests(&mut self) { + self.requests.retain(|_, req| { + req.sent_at.elapsed() < DTUN_REQUEST_TIMEOUT || req.retries > 0 + }); + } + + // ── Maintenance ───────────────────────────────── + + /// Periodic maintenance: explore the ID space with + /// mask_bit and expire stale data. + /// + /// Returns target IDs for find_node queries, or empty + /// if maintenance isn't due yet. + pub fn maintain(&mut self) -> Vec<NodeId> { + if self.last_maintain.elapsed() < DTUN_MAINTAIN_INTERVAL { + return Vec::new(); + } + self.last_maintain = Instant::now(); + + // Generate exploration targets + let id_bytes = *self.id.as_bytes(); + let t1 = clear_bit(id_bytes, self.mask_bit); + let t2 = clear_bit(id_bytes, self.mask_bit + 1); + + self.mask_bit += 2; + if self.mask_bit > 20 { + self.mask_bit = 1; + } + + self.expire_registrations(); + self.expire_requests(); + + vec![t1, t2] + } +} + +/// Clear a specific bit (1-indexed from MSB) in a NodeId. 
+fn clear_bit(
+    mut bytes: [u8; crate::id::ID_LEN],
+    bit_from_msb: usize,
+) -> NodeId {
+    // Out-of-range positions leave the ID unchanged.
+    if bit_from_msb == 0 || bit_from_msb > crate::id::ID_BITS {
+        return NodeId::from_bytes(bytes);
+    }
+    let pos = bit_from_msb - 1;
+    let byte_idx = pos / 8;
+    let bit_idx = 7 - (pos % 8);
+    bytes[byte_idx] &= !(1 << bit_idx);
+    NodeId::from_bytes(bytes)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Helper: loopback address on the given port.
+    fn addr(port: u16) -> SocketAddr {
+        SocketAddr::from(([127, 0, 0, 1], port))
+    }
+
+    #[test]
+    fn register_and_lookup() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0x01; 32]));
+        let nid = NodeId::from_bytes([0x02; 32]);
+
+        assert!(dtun.register_node(nid, addr(3000), 1));
+        assert!(dtun.get_registered(&nid).is_some());
+        assert_eq!(dtun.registration_count(), 1);
+    }
+
+    #[test]
+    fn register_rejects_different_session() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0x01; 32]));
+        let nid = NodeId::from_bytes([0x02; 32]);
+
+        assert!(dtun.register_node(nid, addr(3000), 1));
+
+        // Different session, not expired → rejected
+        assert!(!dtun.register_node(nid, addr(3001), 2));
+    }
+
+    #[test]
+    fn expire_registrations() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0x01; 32]));
+        let nid = NodeId::from_bytes([0x02; 32]);
+        dtun.registered.insert(
+            nid,
+            Registration {
+                addr: addr(3000),
+                session: 1,
+
+                // Expired: registered 10 minutes ago
+                registered_at: Instant::now() - Duration::from_secs(600),
+            },
+        );
+
+        dtun.expire_registrations();
+        assert_eq!(dtun.registration_count(), 0);
+    }
+
+    #[test]
+    fn request_lifecycle() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0x01; 32]));
+        let target = NodeId::from_bytes([0x02; 32]);
+
+        dtun.start_request(42, target);
+        assert!(dtun.get_request(&42).is_some());
+
+        let intermediary =
+            PeerInfo::new(NodeId::from_bytes([0x03; 32]), addr(4000));
+        dtun.request_found(42, intermediary);
+        assert!(dtun.get_request(&42).unwrap().found);
+
+        dtun.remove_request(&42);
+        assert!(dtun.get_request(&42).is_none());
+    }
+
+    #[test]
+    fn prepare_register_increments_session() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0x01; 32]));
+        let (s1, _) = dtun.prepare_register();
+        let (s2, _) = dtun.prepare_register();
+        assert_eq!(s2, s1 + 1);
+    }
+
+    #[test]
+    fn maintain_returns_targets() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0xFF; 32]));
+
+        // Force last_maintain to be old enough
+        dtun.last_maintain = Instant::now() - DTUN_MAINTAIN_INTERVAL;
+
+        let targets = dtun.maintain();
+        assert_eq!(targets.len(), 2);
+
+        // Should differ from local ID
+        assert_ne!(targets[0], dtun.id);
+        assert_ne!(targets[1], dtun.id);
+    }
+
+    #[test]
+    fn maintain_skips_if_recent() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0xFF; 32]));
+        let targets = dtun.maintain();
+        assert!(targets.is_empty());
+    }
+
+    #[test]
+    fn enable_disable() {
+        let mut dtun = Dtun::new(NodeId::from_bytes([0x01; 32]));
+        assert!(dtun.is_enabled());
+        dtun.set_enabled(false);
+        assert!(!dtun.is_enabled());
+    }
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..d5d6f56
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,80 @@
+use std::fmt;
+use std::io;
+
+/// Library-wide error type.
+#[derive(Debug)]
+pub enum Error {
+    Io(io::Error),
+
+    /// Bad magic number in header.
+    BadMagic(u16),
+
+    /// Protocol version not supported.
+    UnsupportedVersion(u8),
+
+    /// Unknown message type byte.
+    UnknownMessageType(u8),
+
+    /// Packet or buffer too small for the operation.
+    BufferTooSmall,
+
+    /// Payload length doesn't match declared length.
+    PayloadMismatch,
+
+    /// Generic invalid message (malformed content).
+    InvalidMessage,
+
+    /// Not connected to the network.
+    NotConnected,
+
+    /// Operation timed out.
+    Timeout,
+
+    /// Signature verification failed.
+ BadSignature, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Io(e) => write!(f, "I/O error: {e}"), + Error::BadMagic(m) => { + write!(f, "bad magic: 0x{m:04x}") + } + Error::UnsupportedVersion(v) => { + write!(f, "unsupported version: {v}") + } + Error::UnknownMessageType(t) => { + write!(f, "unknown message type: 0x{t:02x}") + } + Error::BufferTooSmall => { + write!(f, "buffer too small") + } + Error::PayloadMismatch => { + write!(f, "payload length mismatch") + } + Error::InvalidMessage => { + write!(f, "invalid message") + } + Error::NotConnected => write!(f, "not connected"), + Error::Timeout => write!(f, "timeout"), + Error::BadSignature => { + write!(f, "signature verification failed") + } + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Io(e) => Some(e), + _ => None, + } + } +} + +impl From<io::Error> for Error { + fn from(e: io::Error) -> Self { + Error::Io(e) + } +} diff --git a/src/event.rs b/src/event.rs new file mode 100644 index 0000000..df7295d --- /dev/null +++ b/src/event.rs @@ -0,0 +1,30 @@ +//! Event types for typed callback integration. +//! +//! Alternative to `Box<dyn Fn>` callbacks. Applications +//! can receive events via `std::sync::mpsc::Receiver`. + +use crate::id::NodeId; +use crate::rdp::{RdpAddr, RdpEvent}; + +/// Events emitted by a Node. +#[derive(Debug, Clone)] +pub enum NodeEvent { + /// Datagram received from a peer. + DgramReceived { data: Vec<u8>, from: NodeId }, + + /// RDP event on a connection. + Rdp { + desc: i32, + addr: RdpAddr, + event: RdpEvent, + }, + + /// Peer added to routing table. + PeerAdded(NodeId), + + /// Value stored via DHT STORE. + ValueStored { key: Vec<u8> }, + + /// Node shutdown initiated. 
+ Shutdown, +} diff --git a/src/handlers.rs b/src/handlers.rs new file mode 100644 index 0000000..f4c5b2c --- /dev/null +++ b/src/handlers.rs @@ -0,0 +1,1049 @@ +//! Packet dispatch and message handlers. +//! +//! Extension impl block for Node. All handle_* +//! methods process incoming protocol messages. + +use std::net::SocketAddr; +use std::time::Instant; + +use crate::msg; +use crate::nat::NatState; +use crate::node::Node; +use crate::peers::PeerInfo; +use crate::wire::{DOMAIN_INET, DOMAIN_INET6, HEADER_SIZE, MsgHeader, MsgType}; + +impl Node { + // ── Packet dispatch ───────────────────────────── + + pub(crate) fn handle_packet(&mut self, raw: &[u8], from: SocketAddr) { + self.metrics + .bytes_received + .fetch_add(raw.len() as u64, std::sync::atomic::Ordering::Relaxed); + + // Verify signature and strip it + let buf = match self.verify_incoming(raw, from) { + Some(b) => b, + None => { + self.metrics + .packets_rejected + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + log::trace!("Dropped unsigned packet from {from}"); + return; + } + }; + + self.metrics + .messages_received + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Reject banned peers + if self.ban_list.is_banned(&from) { + log::trace!("Dropped packet from banned peer {from}"); + return; + } + + // Rate limit per source IP + if !self.rate_limiter.allow(from.ip()) { + log::trace!("Rate limited packet from {from}"); + return; + } + + if buf.len() < HEADER_SIZE { + return; + } + + let hdr = match MsgHeader::parse(buf) { + Ok(h) => h, + Err(e) => { + log::trace!("Dropped packet from {from}: {e}"); + return; + } + }; + + log::trace!( + "Received {:?} from {from} (src={:?})", + hdr.msg_type, + hdr.src + ); + + // Register sender as peer and record successful + // communication (clears ban list failures) + self.peers.add(PeerInfo::new(hdr.src, from)); + self.ban_list.record_success(&from); + + match hdr.msg_type { + // ── DHT messages ──────────────────────── + MsgType::DhtPing => { + 
self.handle_dht_ping(buf, &hdr, from); + } + MsgType::DhtPingReply => { + self.handle_dht_ping_reply(buf, &hdr, from); + } + MsgType::DhtFindNode => { + self.handle_dht_find_node(buf, &hdr, from); + } + MsgType::DhtFindNodeReply => { + self.handle_dht_find_node_reply(buf, &hdr, from); + } + MsgType::DhtStore => { + self.handle_dht_store(buf, &hdr, from); + } + MsgType::DhtFindValue => { + self.handle_dht_find_value(buf, &hdr, from); + } + MsgType::DhtFindValueReply => { + self.handle_dht_find_value_reply(buf, &hdr); + } + + // ── NAT detection ─────────────────────── + MsgType::NatEcho => { + // Respond with our observed address of the sender + if let Ok(nonce) = msg::parse_nat_echo(buf) { + log::debug!("NatEcho from {:?} nonce={nonce}", hdr.src); + let reply = msg::NatEchoReply { + nonce, + domain: if from.is_ipv4() { + DOMAIN_INET + } else { + DOMAIN_INET6 + }, + port: from.port(), + addr: { + let mut a = [0u8; 16]; + match from.ip() { + std::net::IpAddr::V4(v4) => { + a[..4].copy_from_slice(&v4.octets()) + } + std::net::IpAddr::V6(v6) => { + a.copy_from_slice(&v6.octets()) + } + } + a + }, + }; + let size = HEADER_SIZE + msg::NAT_ECHO_REPLY_BODY; + let mut rbuf = vec![0u8; size]; + let rhdr = MsgHeader::new( + MsgType::NatEchoReply, + Self::len16(size), + self.id, + hdr.src, + ); + if rhdr.write(&mut rbuf).is_ok() { + msg::write_nat_echo_reply(&mut rbuf, &reply); + let _ = self.send_signed(&rbuf, from); + } + } + } + MsgType::NatEchoReply => { + if let Ok(reply) = msg::parse_nat_echo_reply(buf) { + let observed_port = reply.port; + let observed_ip: std::net::IpAddr = + if reply.domain == DOMAIN_INET { + std::net::IpAddr::V4(std::net::Ipv4Addr::new( + reply.addr[0], + reply.addr[1], + reply.addr[2], + reply.addr[3], + )) + } else { + std::net::IpAddr::V6(std::net::Ipv6Addr::from( + reply.addr, + )) + }; + let observed = SocketAddr::new(observed_ip, observed_port); + let local = self.net.local_addr().unwrap_or(from); + let action = + 
self.nat.recv_echo_reply(reply.nonce, observed, local); + log::info!( + "NAT echo reply: observed={observed} action={action:?}" + ); + + // After NAT detection completes, register with DTUN + if let crate::nat::EchoReplyAction::DetectionComplete( + state, + ) = action + { + log::info!("NAT type detected: {state:?}"); + if state != NatState::Global && self.is_dtun { + self.dtun_register(); + } + if state == NatState::SymmetricNat { + // Find a global node to use as proxy + let closest = self.dht_table.closest(&self.id, 1); + if let Some(server) = closest.first() { + self.proxy.set_server(server.clone()); + let nonce = self.alloc_nonce(); + if let Some(n) = + self.proxy.start_register(nonce) + { + // Send ProxyRegister msg + let size = HEADER_SIZE + 8; + let mut buf = vec![0u8; size]; + let hdr = MsgHeader::new( + MsgType::ProxyRegister, + Self::len16(size), + self.id, + server.id, + ); + if hdr.write(&mut buf).is_ok() { + let session = self.dtun.session(); + buf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice( + &session.to_be_bytes(), + ); + buf[HEADER_SIZE + 4..HEADER_SIZE + 8] + .copy_from_slice(&n.to_be_bytes()); + if let Err(e) = + self.send_signed(&buf, server.addr) + { + log::warn!( + "ProxyRegister send failed: {e}" + ); + } + } + log::info!( + "Sent ProxyRegister to {:?}", + server.id + ); + } + } else { + log::warn!( + "No peers available as proxy server" + ); + } + } + } + } + } + MsgType::NatEchoRedirect | MsgType::NatEchoRedirectReply => { + log::debug!("NAT redirect: {:?}", hdr.msg_type); + } + + // ── DTUN ──────────────────────────────── + MsgType::DtunPing => { + // Reuse DHT ping handler logic + self.handle_dht_ping(buf, &hdr, from); + self.dtun.table_mut().add(PeerInfo::new(hdr.src, from)); + } + MsgType::DtunPingReply => { + self.handle_dht_ping_reply(buf, &hdr, from); + self.dtun.table_mut().add(PeerInfo::new(hdr.src, from)); + } + MsgType::DtunFindNode => { + // Respond with closest from DTUN table + if let Ok(find) = msg::parse_find_node(buf) 
{ + self.dtun.table_mut().add(PeerInfo::new(hdr.src, from)); + let closest = self + .dtun + .table() + .closest(&find.target, self.config.num_find_node); + let domain = if from.is_ipv4() { + DOMAIN_INET + } else { + DOMAIN_INET6 + }; + let reply_msg = msg::FindNodeReplyMsg { + nonce: find.nonce, + id: find.target, + domain, + nodes: closest, + }; + let mut rbuf = [0u8; 2048]; + let rhdr = MsgHeader::new( + MsgType::DtunFindNodeReply, + 0, + self.id, + hdr.src, + ); + if rhdr.write(&mut rbuf).is_ok() { + let total = + msg::write_find_node_reply(&mut rbuf, &reply_msg); + rbuf[4..6] + .copy_from_slice(&Self::len16(total).to_be_bytes()); + let _ = self.send_signed(&rbuf[..total], from); + } + } + } + MsgType::DtunFindNodeReply => { + if let Ok(reply) = msg::parse_find_node_reply(buf) { + log::debug!( + "DTUN find_node_reply: {} nodes", + reply.nodes.len() + ); + for node in &reply.nodes { + self.dtun.table_mut().add(node.clone()); + } + } + } + MsgType::DtunFindValue | MsgType::DtunFindValueReply => { + log::debug!("DTUN find_value: {:?}", hdr.msg_type); + } + MsgType::DtunRegister => { + if let Ok(session) = msg::parse_dtun_register(buf) { + log::debug!( + "DTUN register from {:?} session={session}", + hdr.src + ); + self.dtun.register_node(hdr.src, from, session); + } + } + MsgType::DtunRequest => { + if let Ok((nonce, target)) = msg::parse_dtun_request(buf) { + log::debug!("DTUN request for {:?} nonce={nonce}", target); + + // Check if we have this node registered + if let Some(reg) = self.dtun.get_registered(&target) { + log::debug!( + "DTUN: forwarding request to {:?}", + reg.addr + ); + } + } + } + MsgType::DtunRequestBy | MsgType::DtunRequestReply => { + log::debug!("DTUN request_by/reply: {:?}", hdr.msg_type); + } + + // ── Proxy ─────────────────────────────── + MsgType::ProxyRegister => { + if buf.len() >= HEADER_SIZE + 8 { + let session = u32::from_be_bytes([ + buf[HEADER_SIZE], + buf[HEADER_SIZE + 1], + buf[HEADER_SIZE + 2], + buf[HEADER_SIZE + 3], + ]); + 
let nonce = u32::from_be_bytes([ + buf[HEADER_SIZE + 4], + buf[HEADER_SIZE + 5], + buf[HEADER_SIZE + 6], + buf[HEADER_SIZE + 7], + ]); + log::debug!( + "Proxy register from {:?} session={session}", + hdr.src + ); + self.proxy.register_client(hdr.src, from, session); + + // Send reply + let size = HEADER_SIZE + 4; + let mut rbuf = vec![0u8; size]; + let rhdr = MsgHeader::new( + MsgType::ProxyRegisterReply, + Self::len16(size), + self.id, + hdr.src, + ); + if rhdr.write(&mut rbuf).is_ok() { + rbuf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice(&nonce.to_be_bytes()); + let _ = self.send_signed(&rbuf, from); + } + } + } + MsgType::ProxyRegisterReply => { + if let Ok(nonce) = msg::parse_ping(buf) { + self.proxy.recv_register_reply(nonce); + } + } + MsgType::ProxyStore => { + // Forward store to DHT on behalf of the client + self.handle_dht_store(buf, &hdr, from); + } + MsgType::ProxyGet | MsgType::ProxyGetReply => { + log::debug!("Proxy get/reply: {:?}", hdr.msg_type); + } + MsgType::ProxyDgram | MsgType::ProxyDgramForwarded => { + // Forward dgram payload + self.handle_dgram(buf, &hdr); + } + MsgType::ProxyRdp | MsgType::ProxyRdpForwarded => { + // Forward RDP payload + self.handle_rdp(buf, &hdr); + } + + // ── Transport ─────────────────────────── + MsgType::Dgram => { + self.handle_dgram(buf, &hdr); + } + MsgType::Rdp => { + self.handle_rdp(buf, &hdr); + } + + // ── Advertise ─────────────────────────── + MsgType::Advertise => { + if buf.len() >= HEADER_SIZE + 8 { + let nonce = u32::from_be_bytes([ + buf[HEADER_SIZE], + buf[HEADER_SIZE + 1], + buf[HEADER_SIZE + 2], + buf[HEADER_SIZE + 3], + ]); + log::debug!("Advertise from {:?} nonce={nonce}", hdr.src); + self.advertise.recv_advertise(hdr.src); + + // Send reply + let size = HEADER_SIZE + 8; + let mut rbuf = vec![0u8; size]; + let rhdr = MsgHeader::new( + MsgType::AdvertiseReply, + Self::len16(size), + self.id, + hdr.src, + ); + if rhdr.write(&mut rbuf).is_ok() { + rbuf[HEADER_SIZE..HEADER_SIZE + 
8].copy_from_slice( + &buf[HEADER_SIZE..HEADER_SIZE + 8], + ); + let _ = self.send_signed(&rbuf, from); + } + } + } + MsgType::AdvertiseReply => { + if buf.len() >= HEADER_SIZE + 8 { + let nonce = u32::from_be_bytes([ + buf[HEADER_SIZE], + buf[HEADER_SIZE + 1], + buf[HEADER_SIZE + 2], + buf[HEADER_SIZE + 3], + ]); + self.advertise.recv_reply(nonce); + } + } + } + } + + // ── DHT message handlers ──────────────────────── + + pub(crate) fn handle_dht_ping( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + from: SocketAddr, + ) { + let nonce = match msg::parse_ping(buf) { + Ok(n) => n, + Err(_) => return, + }; + + if hdr.dst != self.id { + return; + } + + log::debug!("DHT ping from {:?} nonce={nonce}", hdr.src); + + // Add to routing table + self.dht_table.add(PeerInfo::new(hdr.src, from)); + + // Send reply + let reply_hdr = MsgHeader::new( + MsgType::DhtPingReply, + Self::len16(msg::PING_MSG_SIZE), + self.id, + hdr.src, + ); + let mut reply = [0u8; msg::PING_MSG_SIZE]; + if reply_hdr.write(&mut reply).is_ok() { + msg::write_ping(&mut reply, nonce); + let _ = self.send_signed(&reply, from); + } + } + + pub(crate) fn handle_dht_ping_reply( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + from: SocketAddr, + ) { + let nonce = match msg::parse_ping(buf) { + Ok(n) => n, + Err(_) => return, + }; + + if hdr.dst != self.id { + return; + } + + // Verify we actually sent this nonce + match self.pending_pings.remove(&nonce) { + Some((expected_id, _sent_at)) => { + if expected_id != hdr.src && !expected_id.is_zero() { + log::debug!( + "Ping reply nonce={nonce}: expected {:?} got {:?}", + expected_id, + hdr.src + ); + return; + } + } + None => { + log::trace!("Ignoring unsolicited ping reply nonce={nonce}"); + return; + } + } + + log::debug!("DHT ping reply from {:?} nonce={nonce}", hdr.src); + + self.dht_table.add(PeerInfo::new(hdr.src, from)); + self.dht_table.mark_seen(&hdr.src); + } + + pub(crate) fn handle_dht_find_node( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + from: 
SocketAddr, + ) { + let find = match msg::parse_find_node(buf) { + Ok(f) => f, + Err(_) => return, + }; + + log::debug!( + "DHT find_node from {:?} target={:?}", + hdr.src, + find.target + ); + + self.dht_table.add(PeerInfo::new(hdr.src, from)); + + // Respond with our closest nodes + let closest = self + .dht_table + .closest(&find.target, self.config.num_find_node); + + let domain = if from.is_ipv4() { + DOMAIN_INET + } else { + DOMAIN_INET6 + }; + + let reply_msg = msg::FindNodeReplyMsg { + nonce: find.nonce, + id: find.target, + domain, + nodes: closest, + }; + + let mut reply_buf = [0u8; 2048]; + let reply_hdr = + MsgHeader::new(MsgType::DhtFindNodeReply, 0, self.id, hdr.src); + if reply_hdr.write(&mut reply_buf).is_ok() { + let total = msg::write_find_node_reply(&mut reply_buf, &reply_msg); + + // Fix the length in header + let len_bytes = Self::len16(total).to_be_bytes(); + reply_buf[4] = len_bytes[0]; + reply_buf[5] = len_bytes[1]; + let _ = self.send_signed(&reply_buf[..total], from); + } + } + + pub(crate) fn handle_dht_find_node_reply( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + from: SocketAddr, + ) { + let reply = match msg::parse_find_node_reply(buf) { + Ok(r) => r, + Err(_) => return, + }; + + log::debug!( + "DHT find_node_reply from {:?}: {} nodes", + hdr.src, + reply.nodes.len() + ); + + // Add sender to routing table + self.dht_table.add(PeerInfo::new(hdr.src, from)); + + // Add returned nodes, filtering invalid addresses + for node in &reply.nodes { + // S1-6: reject unroutable addresses + let ip = node.addr.ip(); + if ip.is_unspecified() || ip.is_multicast() { + continue; + } + // Reject zero NodeId + if node.id.is_zero() { + continue; + } + let result = self.dht_table.add(node.clone()); + self.peers.add(node.clone()); + + // §2.5: replicate data to newly discovered nodes + if matches!(result, crate::routing::InsertResult::Inserted) { + self.proactive_replicate(node); + } + } + + // Feed active queries with the reply + let sender_id = 
hdr.src; + let nodes_clone = reply.nodes.clone(); + + // Find which query this reply belongs to (match + // by nonce or by pending sender). + let matching_nonce: Option<u32> = self + .queries + .iter() + .find(|(_, q)| q.pending.contains_key(&sender_id)) + .map(|(n, _)| *n); + + if let Some(nonce) = matching_nonce { + if let Some(q) = self.queries.get_mut(&nonce) { + q.process_reply(&sender_id, nodes_clone); + + // Send follow-up find_nodes to newly + // discovered nodes + let next = q.next_to_query(); + let target = q.target; + for peer in next { + if self.send_find_node(peer.addr, target).is_ok() { + if let Some(q) = self.queries.get_mut(&nonce) { + q.pending.insert(peer.id, Instant::now()); + } + } + } + } + } + } + + pub(crate) fn handle_dht_store( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + from: SocketAddr, + ) { + let store = match msg::parse_store(buf) { + Ok(s) => s, + Err(_) => return, + }; + + if hdr.dst != self.id { + return; + } + + // S1-4: enforce max value size + if store.value.len() > self.config.max_value_size { + log::debug!( + "Rejecting oversized store: {} bytes > {} max", + store.value.len(), + self.config.max_value_size + ); + return; + } + + // S2-9: verify sender matches claimed originator + if hdr.src != store.from { + log::debug!( + "Store origin mismatch: sender={:?} from={:?}", + hdr.src, + store.from + ); + return; + } + + log::debug!( + "DHT store from {:?}: key={} bytes, value={} bytes, ttl={}", + hdr.src, + store.key.len(), + store.value.len(), + store.ttl + ); + + self.dht_table.add(PeerInfo::new(hdr.src, from)); + + if store.ttl == 0 { + self.storage.remove(&store.id, &store.key); + } else { + let val = crate::dht::StoredValue { + key: store.key, + value: store.value, + id: store.id, + source: store.from, + ttl: store.ttl, + stored_at: std::time::Instant::now(), + is_unique: store.is_unique, + original: 0, // received, not originated + recvd: { + let mut s = std::collections::HashSet::new(); + s.insert(hdr.src); + s + }, + 
version: crate::dht::now_version(), + }; + self.storage.store(val); + } + } + + pub(crate) fn handle_dht_find_value( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + from: SocketAddr, + ) { + let fv = match msg::parse_find_value(buf) { + Ok(f) => f, + Err(_) => return, + }; + + log::debug!( + "DHT find_value from {:?}: key={} bytes", + hdr.src, + fv.key.len() + ); + + self.dht_table.add(PeerInfo::new(hdr.src, from)); + + // Check local storage + let values = self.storage.get(&fv.target, &fv.key); + + if !values.is_empty() { + log::debug!("Found {} value(s) locally", values.len()); + + // Send value reply using DHT_FIND_VALUE_REPLY + // with DATA_ARE_VALUES flag + let val = &values[0]; + let fixed = msg::FIND_VALUE_REPLY_FIXED; + let total = HEADER_SIZE + fixed + val.value.len(); + let mut rbuf = vec![0u8; total]; + let rhdr = MsgHeader::new( + MsgType::DhtFindValueReply, + Self::len16(total), + self.id, + hdr.src, + ); + if rhdr.write(&mut rbuf).is_ok() { + let off = HEADER_SIZE; + rbuf[off..off + 4].copy_from_slice(&fv.nonce.to_be_bytes()); + fv.target + .write_to(&mut rbuf[off + 4..off + 4 + crate::id::ID_LEN]); + rbuf[off + 24..off + 26].copy_from_slice(&0u16.to_be_bytes()); // index + rbuf[off + 26..off + 28].copy_from_slice(&1u16.to_be_bytes()); // total + rbuf[off + 28] = crate::wire::DATA_ARE_VALUES; + rbuf[off + 29] = 0; + rbuf[off + 30] = 0; + rbuf[off + 31] = 0; + rbuf[HEADER_SIZE + fixed..].copy_from_slice(&val.value); + let _ = self.send_signed(&rbuf, from); + } + } else { + // Send closest nodes as find_node_reply format + let closest = self + .dht_table + .closest(&fv.target, self.config.num_find_node); + log::debug!("Value not found, returning {} nodes", closest.len()); + let domain = if from.is_ipv4() { + DOMAIN_INET + } else { + DOMAIN_INET6 + }; + let reply_msg = msg::FindNodeReplyMsg { + nonce: fv.nonce, + id: fv.target, + domain, + nodes: closest, + }; + let mut rbuf = [0u8; 2048]; + let rhdr = + MsgHeader::new(MsgType::DhtFindValueReply, 0, 
self.id, hdr.src); + if rhdr.write(&mut rbuf).is_ok() { + // Write header with DATA_ARE_NODES flag + let off = HEADER_SIZE; + rbuf[off..off + 4].copy_from_slice(&fv.nonce.to_be_bytes()); + fv.target + .write_to(&mut rbuf[off + 4..off + 4 + crate::id::ID_LEN]); + rbuf[off + 24..off + 26].fill(0); // index + rbuf[off + 26..off + 28].fill(0); // total + rbuf[off + 28] = crate::wire::DATA_ARE_NODES; + rbuf[off + 29] = 0; + rbuf[off + 30] = 0; + rbuf[off + 31] = 0; + + // Write nodes after the fixed part + let nodes_off = HEADER_SIZE + msg::FIND_VALUE_REPLY_FIXED; + + // Write domain + num + padding before nodes + rbuf[nodes_off..nodes_off + 2] + .copy_from_slice(&domain.to_be_bytes()); + rbuf[nodes_off + 2] = reply_msg.nodes.len() as u8; + rbuf[nodes_off + 3] = 0; + let nw = if domain == DOMAIN_INET { + msg::write_nodes_inet( + &mut rbuf[nodes_off + 4..], + &reply_msg.nodes, + ) + } else { + msg::write_nodes_inet6( + &mut rbuf[nodes_off + 4..], + &reply_msg.nodes, + ) + }; + let total = nodes_off + 4 + nw; + rbuf[4..6].copy_from_slice(&Self::len16(total).to_be_bytes()); + let _ = self.send_signed(&rbuf[..total], from); + } + } + } + + pub(crate) fn handle_dht_find_value_reply( + &mut self, + buf: &[u8], + hdr: &MsgHeader, + ) { + let reply = match msg::parse_find_value_reply(buf) { + Ok(r) => r, + Err(e) => { + log::debug!("Failed to parse find_value_reply: {e}"); + return; + } + }; + + let sender_id = hdr.src; + + // Find the matching active query + let matching_nonce: Option<u32> = self + .queries + .iter() + .filter(|(_, q)| q.is_find_value) + .find(|(_, q)| q.pending.contains_key(&sender_id)) + .map(|(n, _)| *n); + + match reply.data { + msg::FindValueReplyData::Value { data, .. 
} => { + log::info!( + "Received value from {:?}: {} bytes", + sender_id, + data.len() + ); + + // Cache the value locally and republish + // to the closest queried node without it + // (§2.3: cache on nearest without value) + if let Some(nonce) = matching_nonce { + if let Some(q) = self.queries.get_mut(&nonce) { + // Store in local storage for + // subsequent get() calls + let val = crate::dht::StoredValue { + key: q.key.clone(), + value: data.clone(), + id: q.target, + source: sender_id, + ttl: 300, + stored_at: Instant::now(), + is_unique: false, + original: 0, + recvd: std::collections::HashSet::new(), + version: crate::dht::now_version(), + }; + self.storage.store(val); + + // §2.3: store on nearest queried + // node that didn't have the value + let target = q.target; + let key = q.key.clone(); + let nearest_without: Option<PeerInfo> = q + .closest + .iter() + .find(|p| { + q.queried.contains(&p.id) + && p.id != sender_id + }) + .cloned(); + + q.process_value(data.clone()); + + if let Some(peer) = nearest_without { + let store_msg = crate::msg::StoreMsg { + id: target, + from: self.id, + key, + value: data, + ttl: 300, + is_unique: false, + }; + if let Err(e) = self.send_store(&peer, &store_msg) { + log::debug!( + "Republish-on-access failed: {e}" + ); + } else { + log::debug!( + "Republished value to {:?} (nearest without)", + peer.id, + ); + } + } + } + } + } + msg::FindValueReplyData::Nodes { nodes, .. 
} => { + log::debug!( + "find_value_reply from {:?}: {} nodes (no value)", + sender_id, + nodes.len() + ); + + // Add nodes to routing table + for node in &nodes { + self.dht_table.add(node.clone()); + self.peers.add(node.clone()); + } + + // Feed the query with the nodes so it + // continues iterating + if let Some(nonce) = matching_nonce { + if let Some(q) = self.queries.get_mut(&nonce) { + q.process_reply(&sender_id, nodes.clone()); + + // Send follow-up find_value to + // newly discovered nodes + let next = q.next_to_query(); + let target = q.target; + let key = q.key.clone(); + for peer in next { + if self + .send_find_value_msg(peer.addr, target, &key) + .is_ok() + { + if let Some(q) = self.queries.get_mut(&nonce) { + q.pending.insert(peer.id, Instant::now()); + } + } + } + } + } + } + msg::FindValueReplyData::Nul => { + log::debug!("find_value_reply NUL from {:?}", sender_id); + if let Some(nonce) = matching_nonce { + if let Some(q) = self.queries.get_mut(&nonce) { + q.pending.remove(&sender_id); + q.queried.insert(sender_id); + } + } + } + } + } + + // ── Data restore / republish ──────────────────── + + /// Restore (republish) stored data to k-closest + /// nodes. Tracks the `original` counter and `recvd` + /// set to avoid unnecessary duplicates. 
+ pub(crate) fn restore_data(&mut self) { + let values = self.storage.all_values(); + if values.is_empty() { + return; + } + + log::debug!("Restoring {} stored values", values.len()); + + for val in &values { + if val.is_expired() { + continue; + } + let closest = + self.dht_table.closest(&val.id, self.config.num_find_node); + let store_msg = msg::StoreMsg { + id: val.id, + from: val.source, + key: val.key.clone(), + value: val.value.clone(), + ttl: val.remaining_ttl(), + is_unique: val.is_unique, + }; + for peer in &closest { + if peer.id == self.id { + continue; + } + + // Skip peers that already have this value + if val.recvd.contains(&peer.id) { + continue; + } + if let Err(e) = self.send_store(peer, &store_msg) { + log::debug!("Restore send failed: {e}"); + continue; + } + self.storage.mark_received(val, peer.id); + } + } + } + + /// Run mask_bit exploration: send find_node queries + /// to probe distant regions of the ID space. + pub(crate) fn run_maintain(&mut self) { + let (t1, t2) = self.explorer.next_targets(); + log::debug!("Maintain: exploring targets {:?} {:?}", t1, t2); + let _ = self.start_find_node(t1); + let _ = self.start_find_node(t2); + } + + // ── Proactive replication (§2.5) ────────────── + + /// Replicate stored data to a newly discovered node + /// if it is closer to a key than the current furthest + /// holder. This ensures data availability without + /// waiting for the periodic restore cycle. + /// + /// Called when a new node is inserted into the routing + /// table (not on updates of existing nodes). 
+ pub(crate) fn proactive_replicate(&mut self, new_node: &PeerInfo) { + let values = self.storage.all_values(); + if values.is_empty() { + return; + } + + let k = self.config.num_find_node; + let mut sent = 0u32; + + for val in &values { + if val.is_expired() { + continue; + } + // Skip if this node already received the value + if val.recvd.contains(&new_node.id) { + continue; + } + + // Check: is new_node closer to this key than + // the furthest current k-closest holder? + let closest = self.dht_table.closest(&val.id, k); + let furthest = match closest.last() { + Some(p) => p, + None => continue, + }; + + let dist_new = val.id.distance(&new_node.id); + let dist_far = val.id.distance(&furthest.id); + + if dist_new >= dist_far { + continue; // new node is not closer + } + + let store_msg = crate::msg::StoreMsg { + id: val.id, + from: val.source, + key: val.key.clone(), + value: val.value.clone(), + ttl: val.remaining_ttl(), + is_unique: val.is_unique, + }; + + if self.send_store(new_node, &store_msg).is_ok() { + self.storage.mark_received(val, new_node.id); + sent += 1; + } + } + + if sent > 0 { + log::debug!( + "Proactive replicate: sent {sent} values to {:?}", + new_node.id, + ); + } + } +} diff --git a/src/id.rs b/src/id.rs new file mode 100644 index 0000000..b2e37eb --- /dev/null +++ b/src/id.rs @@ -0,0 +1,238 @@ +//! 256-bit node identity for Kademlia. +//! +//! NodeId is 32 bytes — the same size as an Ed25519 +//! public key. For node IDs, `NodeId = public_key` +//! directly. For DHT key hashing, SHA-256 maps +//! arbitrary data into the 256-bit ID space. + +use crate::sys; +use std::fmt; + +/// Length of a node ID in bytes (32 = Ed25519 pubkey). +pub const ID_LEN: usize = 32; + +/// Number of bits in a node ID. +pub const ID_BITS: usize = ID_LEN * 8; // 256 + +/// A 256-bit node identifier. +/// +/// Provides XOR distance and lexicographic ordering as +/// required by Kademlia routing. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId([u8; ID_LEN]);

impl NodeId {
    /// Generate a random node ID.
    pub fn random() -> Self {
        let mut buf = [0u8; ID_LEN];
        sys::random_bytes(&mut buf);
        Self(buf)
    }

    /// Create a node ID from raw bytes.
    pub fn from_bytes(b: [u8; ID_LEN]) -> Self {
        Self(b)
    }

    /// Create a node ID by SHA-256 hashing arbitrary
    /// data.
    ///
    /// Used by `put`/`get` to map keys into the 256-bit
    /// Kademlia ID space.
    pub fn from_key(data: &[u8]) -> Self {
        use sha2::{Digest, Sha256};
        let hash = Sha256::digest(data);
        let mut buf = [0u8; ID_LEN];
        buf.copy_from_slice(&hash);
        Self(buf)
    }

    /// Return the raw bytes.
    pub fn as_bytes(&self) -> &[u8; ID_LEN] {
        &self.0
    }

    /// XOR distance between two node IDs.
    pub fn distance(&self, other: &NodeId) -> NodeId {
        let mut out = [0u8; ID_LEN];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = self.0[i] ^ other.0[i];
        }
        NodeId(out)
    }

    /// Number of leading zero bits in this ID.
    ///
    /// Used to determine the k-bucket index.
    pub fn leading_zeros(&self) -> u32 {
        for (i, &byte) in self.0.iter().enumerate() {
            if byte != 0 {
                return (i as u32) * 8 + byte.leading_zeros();
            }
        }
        (ID_LEN as u32) * 8
    }

    /// Check if all bytes are zero.
    pub fn is_zero(&self) -> bool {
        self.0.iter().all(|&b| b == 0)
    }

    /// Parse from a hexadecimal string (64 chars).
    ///
    /// Returns `None` unless `s` is exactly 64 ASCII hex
    /// digits. Validating raw bytes up front — instead of
    /// slicing the `str` two characters at a time — fixes
    /// two defects of the naive approach: byte-indexed
    /// `&s[i*2..i*2+2]` panics on non-ASCII input whose
    /// byte length happens to be 64 (str slicing must land
    /// on char boundaries), and `u8::from_str_radix`
    /// silently accepts a leading `+`/`-` sign, letting
    /// strings like `"+1"` repeated 32 times parse as hex.
    pub fn from_hex(s: &str) -> Option<Self> {
        let bytes = s.as_bytes();
        if bytes.len() != ID_LEN * 2
            || !bytes.iter().all(|b| b.is_ascii_hexdigit())
        {
            return None;
        }
        // Decode one hex digit; validity is guaranteed above.
        fn nibble(b: u8) -> u8 {
            match b {
                b'0'..=b'9' => b - b'0',
                b'a'..=b'f' => b - b'a' + 10,
                _ => b - b'A' + 10,
            }
        }
        let mut buf = [0u8; ID_LEN];
        for (out, pair) in buf.iter_mut().zip(bytes.chunks_exact(2)) {
            *out = (nibble(pair[0]) << 4) | nibble(pair[1]);
        }
        Some(Self(buf))
    }

    /// Format as a hexadecimal string (64 chars).
    pub fn to_hex(&self) -> String {
        use std::fmt::Write;
        // Single allocation instead of one String per byte.
        let mut s = String::with_capacity(ID_LEN * 2);
        for b in self.0 {
            // write! into a String cannot fail.
            let _ = write!(s, "{b:02x}");
        }
        s
    }

    /// Serialize to a byte slice.
    ///
    /// # Panics
    /// Panics if `buf.len() < ID_LEN` (32).
    pub fn write_to(&self, buf: &mut [u8]) {
        debug_assert!(
            buf.len() >= ID_LEN,
            "NodeId::write_to: buf too small ({} < {ID_LEN})",
            buf.len()
        );
        buf[..ID_LEN].copy_from_slice(&self.0);
    }

    /// Deserialize from a byte slice.
    ///
    /// # Panics
    /// Panics if `buf.len() < ID_LEN` (32).
    pub fn read_from(buf: &[u8]) -> Self {
        debug_assert!(
            buf.len() >= ID_LEN,
            "NodeId::read_from: buf too small ({} < {ID_LEN})",
            buf.len()
        );
        let mut id = [0u8; ID_LEN];
        id.copy_from_slice(&buf[..ID_LEN]);
        Self(id)
    }
}

impl Ord for NodeId {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

impl PartialOrd for NodeId {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl fmt::Debug for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Abbreviated: first 8 hex chars (4 bytes) are enough to
        // tell IDs apart in logs.
        write!(f, "NodeId({})", &self.to_hex()[..8])
    }
}

impl fmt::Display for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_hex())
    }
}

impl AsRef<[u8]> for NodeId {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_distance() {
        let id = NodeId::from_bytes([0xAB; ID_LEN]);
        let d = id.distance(&id);
        assert!(d.is_zero());
    }

    #[test]
    fn xor_symmetric() {
        let a = NodeId::from_bytes([0x01; ID_LEN]);
        let b = NodeId::from_bytes([0xFF; ID_LEN]);
        assert_eq!(a.distance(&b), b.distance(&a));
    }

    #[test]
    fn leading_zeros_all_zero() {
        let z = NodeId::from_bytes([0; ID_LEN]);
        assert_eq!(z.leading_zeros(), 256);
    }

    #[test]
    fn leading_zeros_first_bit() {
        let mut buf = [0u8; ID_LEN];
        buf[0] = 0x80;
        let id = NodeId::from_bytes(buf);
        assert_eq!(id.leading_zeros(), 0);
    }

    #[test]
    fn leading_zeros_ninth_bit() {
        let mut buf = [0u8; ID_LEN];
        buf[1] = 0x80;
        let id = NodeId::from_bytes(buf);
        assert_eq!(id.leading_zeros(), 8);
    }

    #[test]
    fn
hex_roundtrip() { + let id = NodeId::from_bytes([ + 0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, + 0xEF, 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10, 0xDE, 0xAD, + 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, + ]); + let hex = id.to_hex(); + assert_eq!(hex.len(), 64); + assert_eq!(NodeId::from_hex(&hex), Some(id)); + } + + #[test] + fn from_key_deterministic() { + let a = NodeId::from_key(b"hello"); + let b = NodeId::from_key(b"hello"); + assert_eq!(a, b); + } + + #[test] + fn from_key_different_inputs() { + let a = NodeId::from_key(b"hello"); + let b = NodeId::from_key(b"world"); + assert_ne!(a, b); + } + + #[test] + fn ordering_is_lexicographic() { + let a = NodeId::from_bytes([0x00; ID_LEN]); + let b = NodeId::from_bytes([0xFF; ID_LEN]); + assert!(a < b); + } + + #[test] + fn write_read_roundtrip() { + let id = NodeId::from_bytes([0x42; ID_LEN]); + let mut buf = [0u8; ID_LEN]; + id.write_to(&mut buf); + let id2 = NodeId::read_from(&buf); + assert_eq!(id, id2); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..f956f98 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,128 @@ +// Allow nested if-let until we migrate to let-chains +#![allow(clippy::collapsible_if)] + +//! # tesseras-dht +//! +//! NAT-aware Kademlia DHT library for peer-to-peer +//! networks. +//! +//! `tesseras-dht` provides: +//! +//! - **Distributed key-value storage** via Kademlia +//! (iterative FIND_NODE, FIND_VALUE, STORE) +//! - **NAT traversal** via DTUN (hole-punching) and +//! proxy relay (symmetric NAT) +//! - **Reliable transport** (RDP) over UDP +//! - **Datagram transport** with automatic fragmentation +//! +//! ## Quick start +//! +//! ```rust,no_run +//! use tesseras_dht::{Node, NatState}; +//! +//! let mut node = Node::bind(10000).unwrap(); +//! node.set_nat_state(NatState::Global); +//! node.join("bootstrap.example.com", 10000).unwrap(); +//! +//! // Store and retrieve values +//! 
node.put(b"key", b"value", 300, false); +//! +//! // Event loop +//! loop { +//! node.poll().unwrap(); +//! # break; +//! } +//! ``` +//! +//! ## Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────┐ +//! │ Node (facade) │ +//! │ ┌──────┐ ┌────────┐ ┌──────────────┐ │ +//! │ │ DHT │ │Routing │ │ Storage │ │ +//! │ │ │ │ Table │ │ (DhtStorage)│ │ +//! │ └──────┘ └────────┘ └──────────────┘ │ +//! │ ┌──────┐ ┌────────┐ ┌──────────────┐ │ +//! │ │ DTUN │ │ NAT │ │ Proxy │ │ +//! │ │ │ │Detector│ │ │ │ +//! │ └──────┘ └────────┘ └──────────────┘ │ +//! │ ┌──────┐ ┌────────┐ ┌──────────────┐ │ +//! │ │ RDP │ │ Dgram │ │ Timers │ │ +//! │ └──────┘ └────────┘ └──────────────┘ │ +//! │ ┌──────────────────────────────────┐ │ +//! │ │ NetLoop (mio / kqueue) │ │ +//! │ └──────────────────────────────────┘ │ +//! └─────────────────────────────────────────┘ +//! ``` +//! +//! ## Security (OpenBSD) +//! +//! Sandboxing with pledge(2) and unveil(2) is an +//! application-level concern. Call them in your binary +//! after binding the socket. The library only needs +//! `"stdio inet dns"` promises. + +/// Address advertisement protocol. +pub mod advertise; +/// Ban list for misbehaving peers. +pub mod banlist; +/// Node configuration. +pub mod config; +/// Ed25519 identity and packet signing. +pub mod crypto; +/// Datagram transport with fragmentation/reassembly. +pub mod dgram; +/// Kademlia DHT storage and iterative queries. +pub mod dht; +/// Distributed tunnel for NAT traversal. +pub mod dtun; +/// Error types. +pub mod error; +/// Event types for typed callback integration. +pub mod event; +/// Packet dispatch and message handlers. +mod handlers; +/// 256-bit node identity. +pub mod id; +/// Node metrics and observability. +pub mod metrics; +/// Wire message body parsing and writing. +pub mod msg; +/// NAT type detection (STUN-like echo protocol). +pub mod nat; +/// Network send helpers and query management. +mod net; +/// Main facade: the [`Node`]. 
+pub mod node; +/// Peer node database. +pub mod peers; +/// Persistence traits for data and routing table. +pub mod persist; +/// Proxy relay for symmetric NAT nodes. +pub mod proxy; +/// Per-IP rate limiting. +pub mod ratelimit; +/// Reliable Datagram Protocol (RDP). +pub mod rdp; +/// Kademlia routing table with k-buckets. +pub mod routing; +/// UDP I/O via mio (kqueue on OpenBSD). +pub mod socket; +/// Store acknowledgment tracking. +pub mod store_track; +/// OpenBSD arc4random_buf(3) for secure random bytes. +pub mod sys; +/// Timer wheel for scheduling callbacks. +pub mod timer; +/// On-the-wire binary protocol (header, message types). +pub mod wire; + +// Re-export the main types for convenience. +pub use error::Error; +pub use id::NodeId; +pub use nat::NatState; + +// Re-export sha2 for downstream crates. +pub use sha2; +pub use node::Node; diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..67ec30f --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,121 @@ +//! Node metrics and observability. +//! +//! Atomic counters for messages, lookups, storage, +//! and errors. Accessible via `Tessera::metrics()`. + +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Atomic metrics counters. +pub struct Metrics { + pub messages_sent: AtomicU64, + pub messages_received: AtomicU64, + pub lookups_started: AtomicU64, + pub lookups_completed: AtomicU64, + pub rpc_timeouts: AtomicU64, + pub values_stored: AtomicU64, + pub packets_rejected: AtomicU64, + pub bytes_sent: AtomicU64, + pub bytes_received: AtomicU64, +} + +impl Metrics { + pub fn new() -> Self { + Self { + messages_sent: AtomicU64::new(0), + messages_received: AtomicU64::new(0), + lookups_started: AtomicU64::new(0), + lookups_completed: AtomicU64::new(0), + rpc_timeouts: AtomicU64::new(0), + values_stored: AtomicU64::new(0), + packets_rejected: AtomicU64::new(0), + bytes_sent: AtomicU64::new(0), + bytes_received: AtomicU64::new(0), + } + } + + /// Take a snapshot of all counters. 
+ pub fn snapshot(&self) -> MetricsSnapshot { + MetricsSnapshot { + messages_sent: self.messages_sent.load(Ordering::Relaxed), + messages_received: self.messages_received.load(Ordering::Relaxed), + lookups_started: self.lookups_started.load(Ordering::Relaxed), + lookups_completed: self.lookups_completed.load(Ordering::Relaxed), + rpc_timeouts: self.rpc_timeouts.load(Ordering::Relaxed), + values_stored: self.values_stored.load(Ordering::Relaxed), + packets_rejected: self.packets_rejected.load(Ordering::Relaxed), + bytes_sent: self.bytes_sent.load(Ordering::Relaxed), + bytes_received: self.bytes_received.load(Ordering::Relaxed), + } + } +} + +impl Default for Metrics { + fn default() -> Self { + Self::new() + } +} + +/// Snapshot of metrics at a point in time. +#[derive(Debug, Clone)] +pub struct MetricsSnapshot { + pub messages_sent: u64, + pub messages_received: u64, + pub lookups_started: u64, + pub lookups_completed: u64, + pub rpc_timeouts: u64, + pub values_stored: u64, + pub packets_rejected: u64, + pub bytes_sent: u64, + pub bytes_received: u64, +} + +impl std::fmt::Display for MetricsSnapshot { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "sent={} recv={} lookups={}/{} timeouts={} \ + stored={} rejected={} bytes={}/{}", + self.messages_sent, + self.messages_received, + self.lookups_completed, + self.lookups_started, + self.rpc_timeouts, + self.values_stored, + self.packets_rejected, + self.bytes_sent, + self.bytes_received, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn counters_start_zero() { + let m = Metrics::new(); + let s = m.snapshot(); + assert_eq!(s.messages_sent, 0); + assert_eq!(s.bytes_sent, 0); + } + + #[test] + fn increment_and_snapshot() { + let m = Metrics::new(); + m.messages_sent.fetch_add(5, Ordering::Relaxed); + m.bytes_sent.fetch_add(1000, Ordering::Relaxed); + let s = m.snapshot(); + assert_eq!(s.messages_sent, 5); + assert_eq!(s.bytes_sent, 1000); + } + + #[test] + fn 
display_format() { + let m = Metrics::new(); + m.messages_sent.fetch_add(10, Ordering::Relaxed); + let s = m.snapshot(); + let text = format!("{s}"); + assert!(text.contains("sent=10")); + } +} diff --git a/src/msg.rs b/src/msg.rs new file mode 100644 index 0000000..95c1d4c --- /dev/null +++ b/src/msg.rs @@ -0,0 +1,830 @@ +//! Message body parsing and writing. +//! +//! Each protocol message has a fixed header (parsed by +//! `wire::MsgHeader`) followed by a variable body. This +//! module provides parse/write for all body types. +//! +//! All multi-byte fields are big-endian (network order). + +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +use crate::error::Error; +use crate::id::{ID_LEN, NodeId}; +use crate::peers::PeerInfo; +use crate::wire::{DOMAIN_INET, HEADER_SIZE}; + +// ── Helpers ───────────────────────────────────────── + +// Callers MUST validate `off + 2 <= buf.len()` before calling. +fn read_u16(buf: &[u8], off: usize) -> u16 { + u16::from_be_bytes([buf[off], buf[off + 1]]) +} + +fn read_u32(buf: &[u8], off: usize) -> u32 { + u32::from_be_bytes([buf[off], buf[off + 1], buf[off + 2], buf[off + 3]]) +} + +fn write_u16(buf: &mut [u8], off: usize, v: u16) { + buf[off..off + 2].copy_from_slice(&v.to_be_bytes()); +} + +fn write_u32(buf: &mut [u8], off: usize, v: u32) { + buf[off..off + 4].copy_from_slice(&v.to_be_bytes()); +} + +// ── Ping (DHT + DTUN) ────────────────────────────── + +/// Body: just a nonce (4 bytes after header). +/// Used by: DhtPing, DhtPingReply, DtunPing, +/// DtunPingReply, DtunRequestReply. +pub fn parse_ping(buf: &[u8]) -> Result<u32, Error> { + if buf.len() < HEADER_SIZE + 4 { + return Err(Error::BufferTooSmall); + } + Ok(read_u32(buf, HEADER_SIZE)) +} + +pub fn write_ping(buf: &mut [u8], nonce: u32) { + write_u32(buf, HEADER_SIZE, nonce); +} + +/// Total size of a ping message. 
+pub const PING_MSG_SIZE: usize = HEADER_SIZE + 4; + +// ── NAT Echo ──────────────────────────────────────── + +/// Parse NatEcho: nonce only. +pub fn parse_nat_echo(buf: &[u8]) -> Result<u32, Error> { + parse_ping(buf) // same layout +} + +/// NatEchoReply body: nonce(4) + domain(2) + port(2) + addr(16). +pub const NAT_ECHO_REPLY_BODY: usize = 4 + 2 + 2 + 16; + +#[derive(Debug, Clone)] +pub struct NatEchoReply { + pub nonce: u32, + pub domain: u16, + pub port: u16, + pub addr: [u8; 16], +} + +pub fn parse_nat_echo_reply(buf: &[u8]) -> Result<NatEchoReply, Error> { + let off = HEADER_SIZE; + if buf.len() < off + NAT_ECHO_REPLY_BODY { + return Err(Error::BufferTooSmall); + } + Ok(NatEchoReply { + nonce: read_u32(buf, off), + domain: read_u16(buf, off + 4), + port: read_u16(buf, off + 6), + addr: { + let mut a = [0u8; 16]; + a.copy_from_slice(&buf[off + 8..off + 24]); + a + }, + }) +} + +pub fn write_nat_echo_reply(buf: &mut [u8], reply: &NatEchoReply) { + let off = HEADER_SIZE; + write_u32(buf, off, reply.nonce); + write_u16(buf, off + 4, reply.domain); + write_u16(buf, off + 6, reply.port); + buf[off + 8..off + 24].copy_from_slice(&reply.addr); +} + +/// NatEchoRedirect body: nonce(4) + port(2) + padding(2). +pub fn parse_nat_echo_redirect(buf: &[u8]) -> Result<(u32, u16), Error> { + let off = HEADER_SIZE; + if buf.len() < off + 8 { + return Err(Error::BufferTooSmall); + } + Ok((read_u32(buf, off), read_u16(buf, off + 4))) +} + +/// Write NatEchoRedirect body. +pub fn write_nat_echo_redirect(buf: &mut [u8], nonce: u32, port: u16) { + let off = HEADER_SIZE; + write_u32(buf, off, nonce); + write_u16(buf, off + 4, port); + write_u16(buf, off + 6, 0); // padding +} + +/// Size of a NatEchoRedirect message. +pub const NAT_ECHO_REDIRECT_SIZE: usize = HEADER_SIZE + 8; + +// ── FindNode (DHT + DTUN) ─────────────────────────── + +/// FindNode body: nonce(4) + id(20) + domain(2) + state_or_pad(2). 
+pub const FIND_NODE_BODY: usize = 4 + ID_LEN + 2 + 2; + +#[derive(Debug, Clone)] +pub struct FindNodeMsg { + pub nonce: u32, + pub target: NodeId, + pub domain: u16, + pub state: u16, +} + +pub fn parse_find_node(buf: &[u8]) -> Result<FindNodeMsg, Error> { + let off = HEADER_SIZE; + if buf.len() < off + FIND_NODE_BODY { + return Err(Error::BufferTooSmall); + } + Ok(FindNodeMsg { + nonce: read_u32(buf, off), + target: NodeId::read_from(&buf[off + 4..off + 4 + ID_LEN]), + domain: read_u16(buf, off + 4 + ID_LEN), + state: read_u16(buf, off + 6 + ID_LEN), + }) +} + +pub fn write_find_node(buf: &mut [u8], msg: &FindNodeMsg) { + let off = HEADER_SIZE; + write_u32(buf, off, msg.nonce); + msg.target.write_to(&mut buf[off + 4..off + 4 + ID_LEN]); + write_u16(buf, off + 4 + ID_LEN, msg.domain); + write_u16(buf, off + 6 + ID_LEN, msg.state); +} + +pub const FIND_NODE_MSG_SIZE: usize = HEADER_SIZE + FIND_NODE_BODY; + +// ── FindNodeReply (DHT + DTUN) ────────────────────── + +/// FindNodeReply fixed part: nonce(4) + id(20) + +/// domain(2) + num(1) + padding(1). +pub const FIND_NODE_REPLY_FIXED: usize = 4 + ID_LEN + 4; + +#[derive(Debug, Clone)] +pub struct FindNodeReplyMsg { + pub nonce: u32, + pub id: NodeId, + pub domain: u16, + pub nodes: Vec<PeerInfo>, +} + +/// Size of an IPv4 node entry: port(2) + reserved(2) + +/// addr(4) + id(ID_LEN). +pub const INET_NODE_SIZE: usize = 8 + ID_LEN; + +/// Size of an IPv6 node entry: port(2) + reserved(2) + +/// addr(16) + id(ID_LEN). 
+pub const INET6_NODE_SIZE: usize = 20 + ID_LEN; + +pub fn parse_find_node_reply(buf: &[u8]) -> Result<FindNodeReplyMsg, Error> { + let off = HEADER_SIZE; + if buf.len() < off + FIND_NODE_REPLY_FIXED { + return Err(Error::BufferTooSmall); + } + + let nonce = read_u32(buf, off); + let id = NodeId::read_from(&buf[off + 4..off + 4 + ID_LEN]); + let domain = read_u16(buf, off + 4 + ID_LEN); + let num = buf[off + 4 + ID_LEN + 2] as usize; + + let nodes_off = off + FIND_NODE_REPLY_FIXED; + let nodes = if domain == DOMAIN_INET { + read_nodes_inet(&buf[nodes_off..], num) + } else { + read_nodes_inet6(&buf[nodes_off..], num) + }; + + Ok(FindNodeReplyMsg { + nonce, + id, + domain, + nodes, + }) +} + +pub fn write_find_node_reply(buf: &mut [u8], msg: &FindNodeReplyMsg) -> usize { + let off = HEADER_SIZE; + write_u32(buf, off, msg.nonce); + msg.id.write_to(&mut buf[off + 4..off + 4 + ID_LEN]); + write_u16(buf, off + 4 + ID_LEN, msg.domain); + let num = msg.nodes.len().min(MAX_NODES_PER_REPLY); + buf[off + 4 + ID_LEN + 2] = num as u8; + buf[off + 4 + ID_LEN + 3] = 0; // padding + + let nodes_off = off + FIND_NODE_REPLY_FIXED; + let nodes_len = if msg.domain == DOMAIN_INET { + write_nodes_inet(&mut buf[nodes_off..], &msg.nodes) + } else { + write_nodes_inet6(&mut buf[nodes_off..], &msg.nodes) + }; + + HEADER_SIZE + FIND_NODE_REPLY_FIXED + nodes_len +} + +// ── Store (DHT) ───────────────────────────────────── + +/// Store fixed part: id(20) + from(20) + keylen(2) + +/// valuelen(2) + ttl(2) + flags(1) + reserved(1) = 48. 
+pub const STORE_FIXED: usize = ID_LEN * 2 + 8; + +#[derive(Debug, Clone)] +pub struct StoreMsg { + pub id: NodeId, + pub from: NodeId, + pub key: Vec<u8>, + pub value: Vec<u8>, + pub ttl: u16, + pub is_unique: bool, +} + +pub fn parse_store(buf: &[u8]) -> Result<StoreMsg, Error> { + let off = HEADER_SIZE; + if buf.len() < off + STORE_FIXED { + return Err(Error::BufferTooSmall); + } + + let id = NodeId::read_from(&buf[off..off + ID_LEN]); + let from = NodeId::read_from(&buf[off + ID_LEN..off + ID_LEN * 2]); + let keylen = read_u16(buf, off + ID_LEN * 2) as usize; + let valuelen = read_u16(buf, off + ID_LEN * 2 + 2) as usize; + let ttl = read_u16(buf, off + ID_LEN * 2 + 4); + let flags = buf[off + ID_LEN * 2 + 6]; + + let data_off = off + STORE_FIXED; + let total = data_off + .checked_add(keylen) + .and_then(|v| v.checked_add(valuelen)) + .ok_or(Error::InvalidMessage)?; + + if buf.len() < total { + return Err(Error::BufferTooSmall); + } + + let key = buf[data_off..data_off + keylen].to_vec(); + let value = buf[data_off + keylen..data_off + keylen + valuelen].to_vec(); + + Ok(StoreMsg { + id, + from, + key, + value, + ttl, + is_unique: flags & crate::wire::DHT_FLAG_UNIQUE != 0, + }) +} + +pub fn write_store(buf: &mut [u8], msg: &StoreMsg) -> Result<usize, Error> { + let off = HEADER_SIZE; + let total = off + STORE_FIXED + msg.key.len() + msg.value.len(); + if buf.len() < total { + return Err(Error::BufferTooSmall); + } + + msg.id.write_to(&mut buf[off..off + ID_LEN]); + msg.from.write_to(&mut buf[off + ID_LEN..off + ID_LEN * 2]); + let keylen = + u16::try_from(msg.key.len()).map_err(|_| Error::BufferTooSmall)?; + let valuelen = + u16::try_from(msg.value.len()).map_err(|_| Error::BufferTooSmall)?; + write_u16(buf, off + ID_LEN * 2, keylen); + write_u16(buf, off + ID_LEN * 2 + 2, valuelen); + write_u16(buf, off + ID_LEN * 2 + 4, msg.ttl); + buf[off + ID_LEN * 2 + 6] = if msg.is_unique { + crate::wire::DHT_FLAG_UNIQUE + } else { + 0 + }; + buf[off + ID_LEN * 2 + 7] = 0; 
// reserved + + let data_off = off + STORE_FIXED; + buf[data_off..data_off + msg.key.len()].copy_from_slice(&msg.key); + buf[data_off + msg.key.len()..data_off + msg.key.len() + msg.value.len()] + .copy_from_slice(&msg.value); + + Ok(total) +} + +// ── FindValue (DHT) ───────────────────────────────── + +/// FindValue fixed: nonce(4) + id(20) + domain(2) + +/// keylen(2) + flag(1) + padding(3) = 32. +pub const FIND_VALUE_FIXED: usize = 4 + ID_LEN + 8; + +#[derive(Debug, Clone)] +pub struct FindValueMsg { + pub nonce: u32, + pub target: NodeId, + pub domain: u16, + pub key: Vec<u8>, + pub use_rdp: bool, +} + +pub fn parse_find_value(buf: &[u8]) -> Result<FindValueMsg, Error> { + let off = HEADER_SIZE; + if buf.len() < off + FIND_VALUE_FIXED { + return Err(Error::BufferTooSmall); + } + + let nonce = read_u32(buf, off); + let target = NodeId::read_from(&buf[off + 4..off + 4 + ID_LEN]); + let domain = read_u16(buf, off + 4 + ID_LEN); + let keylen = read_u16(buf, off + 6 + ID_LEN) as usize; + let flag = buf[off + 8 + ID_LEN]; + + let key_off = off + FIND_VALUE_FIXED; + if buf.len() < key_off + keylen { + return Err(Error::BufferTooSmall); + } + + Ok(FindValueMsg { + nonce, + target, + domain, + key: buf[key_off..key_off + keylen].to_vec(), + use_rdp: flag == 1, + }) +} + +pub fn write_find_value( + buf: &mut [u8], + msg: &FindValueMsg, +) -> Result<usize, Error> { + let off = HEADER_SIZE; + let total = off + FIND_VALUE_FIXED + msg.key.len(); + if buf.len() < total { + return Err(Error::BufferTooSmall); + } + + write_u32(buf, off, msg.nonce); + msg.target.write_to(&mut buf[off + 4..off + 4 + ID_LEN]); + write_u16(buf, off + 4 + ID_LEN, msg.domain); + let keylen = + u16::try_from(msg.key.len()).map_err(|_| Error::BufferTooSmall)?; + write_u16(buf, off + 6 + ID_LEN, keylen); + buf[off + 8 + ID_LEN] = if msg.use_rdp { 1 } else { 0 }; + buf[off + 9 + ID_LEN] = 0; + buf[off + 10 + ID_LEN] = 0; + buf[off + 11 + ID_LEN] = 0; + + let key_off = off + FIND_VALUE_FIXED; + 
buf[key_off..key_off + msg.key.len()].copy_from_slice(&msg.key); + + Ok(total) +} + +// ── FindValueReply (DHT) ──────────────────────────── + +/// Fixed: nonce(4) + id(20) + index(2) + total(2) + +/// flag(1) + padding(3) = 32. +pub const FIND_VALUE_REPLY_FIXED: usize = 4 + ID_LEN + 8; + +#[derive(Debug, Clone)] +pub enum FindValueReplyData { + /// flag=0xa0: node list + Nodes { domain: u16, nodes: Vec<PeerInfo> }, + + /// flag=0xa1: a value chunk + Value { + index: u16, + total: u16, + data: Vec<u8>, + }, + + /// flag=0xa2: no data + Nul, +} + +#[derive(Debug, Clone)] +pub struct FindValueReplyMsg { + pub nonce: u32, + pub id: NodeId, + pub data: FindValueReplyData, +} + +pub fn parse_find_value_reply(buf: &[u8]) -> Result<FindValueReplyMsg, Error> { + let off = HEADER_SIZE; + if buf.len() < off + FIND_VALUE_REPLY_FIXED { + return Err(Error::BufferTooSmall); + } + + let nonce = read_u32(buf, off); + let id = NodeId::read_from(&buf[off + 4..off + 4 + ID_LEN]); + let index = read_u16(buf, off + 4 + ID_LEN); + let total = read_u16(buf, off + 6 + ID_LEN); + let flag = buf[off + 8 + ID_LEN]; + + let data_off = off + FIND_VALUE_REPLY_FIXED; + + let data = match flag { + crate::wire::DATA_ARE_NODES => { + if buf.len() < data_off + 4 { + return Err(Error::BufferTooSmall); + } + let domain = read_u16(buf, data_off); + let num = buf[data_off + 2] as usize; + let nodes_off = data_off + 4; + let nodes = if domain == DOMAIN_INET { + read_nodes_inet(&buf[nodes_off..], num) + } else { + read_nodes_inet6(&buf[nodes_off..], num) + }; + FindValueReplyData::Nodes { domain, nodes } + } + crate::wire::DATA_ARE_VALUES => { + let payload = buf[data_off..].to_vec(); + FindValueReplyData::Value { + index, + total, + data: payload, + } + } + crate::wire::DATA_ARE_NUL => FindValueReplyData::Nul, + _ => return Err(Error::InvalidMessage), + }; + + Ok(FindValueReplyMsg { nonce, id, data }) +} + +// ── DtunRegister ──────────────────────────────────── + +pub fn parse_dtun_register(buf: &[u8]) 
-> Result<u32, Error> { + if buf.len() < HEADER_SIZE + 4 { + return Err(Error::BufferTooSmall); + } + Ok(read_u32(buf, HEADER_SIZE)) // session +} + +// ── DtunRequest ───────────────────────────────────── + +pub fn parse_dtun_request(buf: &[u8]) -> Result<(u32, NodeId), Error> { + let off = HEADER_SIZE; + if buf.len() < off + 4 + ID_LEN { + return Err(Error::BufferTooSmall); + } + let nonce = read_u32(buf, off); + let target = NodeId::read_from(&buf[off + 4..off + 4 + ID_LEN]); + Ok((nonce, target)) +} + +// ── Node list serialization (IPv4 / IPv6) ─────────── + +/// Maximum nodes per reply (prevents OOM from malicious num). +const MAX_NODES_PER_REPLY: usize = 20; + +/// Read `num` IPv4 node entries from `buf`. +pub fn read_nodes_inet(buf: &[u8], num: usize) -> Vec<PeerInfo> { + let num = num.min(MAX_NODES_PER_REPLY); + let mut nodes = Vec::with_capacity(num); + for i in 0..num { + let off = i * INET_NODE_SIZE; + if off + INET_NODE_SIZE > buf.len() { + break; + } + let port = read_u16(buf, off); + let ip = Ipv4Addr::new( + buf[off + 4], + buf[off + 5], + buf[off + 6], + buf[off + 7], + ); + let id = NodeId::read_from(&buf[off + 8..off + 8 + ID_LEN]); + let addr = SocketAddr::V4(SocketAddrV4::new(ip, port)); + nodes.push(PeerInfo::new(id, addr)); + } + nodes +} + +/// Read `num` IPv6 node entries from `buf`. +pub fn read_nodes_inet6(buf: &[u8], num: usize) -> Vec<PeerInfo> { + let num = num.min(MAX_NODES_PER_REPLY); + let mut nodes = Vec::with_capacity(num); + for i in 0..num { + let off = i * INET6_NODE_SIZE; + if off + INET6_NODE_SIZE > buf.len() { + break; + } + let port = read_u16(buf, off); + let mut octets = [0u8; 16]; + octets.copy_from_slice(&buf[off + 4..off + 20]); + let ip = Ipv6Addr::from(octets); + let id = NodeId::read_from(&buf[off + 20..off + 20 + ID_LEN]); + let addr = SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)); + nodes.push(PeerInfo::new(id, addr)); + } + nodes +} + +/// Write IPv4 node entries. Returns bytes written. 
+pub fn write_nodes_inet(buf: &mut [u8], nodes: &[PeerInfo]) -> usize { + let mut written = 0; + for node in nodes { + if written + INET_NODE_SIZE > buf.len() { + break; + } + let off = written; + write_u16(buf, off, node.addr.port()); + write_u16(buf, off + 2, 0); // reserved + + if let SocketAddr::V4(v4) = node.addr { + let octets = v4.ip().octets(); + buf[off + 4..off + 8].copy_from_slice(&octets); + } else { + buf[off + 4..off + 8].fill(0); + } + + node.id.write_to(&mut buf[off + 8..off + 8 + ID_LEN]); + written += INET_NODE_SIZE; + } + written +} + +/// Write IPv6 node entries. Returns bytes written. +pub fn write_nodes_inet6(buf: &mut [u8], nodes: &[PeerInfo]) -> usize { + let mut written = 0; + for node in nodes { + if written + INET6_NODE_SIZE > buf.len() { + break; + } + let off = written; + write_u16(buf, off, node.addr.port()); + write_u16(buf, off + 2, 0); // reserved + + if let SocketAddr::V6(v6) = node.addr { + let octets = v6.ip().octets(); + buf[off + 4..off + 20].copy_from_slice(&octets); + } else { + buf[off + 4..off + 20].fill(0); + } + + node.id.write_to(&mut buf[off + 20..off + 20 + ID_LEN]); + written += INET6_NODE_SIZE; + } + written +} + +/// Create a PeerInfo from a message header and source +/// address. 
+pub fn peer_from_header(src_id: NodeId, from: SocketAddr) -> PeerInfo { + PeerInfo::new(src_id, from) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::wire::{MsgHeader, MsgType}; + + fn make_buf(msg_type: MsgType, body_len: usize) -> Vec<u8> { + let total = HEADER_SIZE + body_len; + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + msg_type, + total as u16, + NodeId::from_bytes([0xAA; 32]), + NodeId::from_bytes([0xBB; 32]), + ); + hdr.write(&mut buf).unwrap(); + buf + } + + // ── Ping ──────────────────────────────────────── + + #[test] + fn ping_roundtrip() { + let mut buf = make_buf(MsgType::DhtPing, 4); + write_ping(&mut buf, 0xDEADBEEF); + let nonce = parse_ping(&buf).unwrap(); + assert_eq!(nonce, 0xDEADBEEF); + } + + // ── NatEchoReply ──────────────────────────────── + + #[test] + fn nat_echo_reply_roundtrip() { + let mut buf = make_buf(MsgType::NatEchoReply, NAT_ECHO_REPLY_BODY); + let reply = NatEchoReply { + nonce: 42, + domain: DOMAIN_INET, + port: 3000, + addr: { + let mut a = [0u8; 16]; + a[0..4].copy_from_slice(&[192, 168, 1, 1]); + a + }, + }; + write_nat_echo_reply(&mut buf, &reply); + let parsed = parse_nat_echo_reply(&buf).unwrap(); + assert_eq!(parsed.nonce, 42); + assert_eq!(parsed.domain, DOMAIN_INET); + assert_eq!(parsed.port, 3000); + assert_eq!(parsed.addr[0..4], [192, 168, 1, 1]); + } + + // ── FindNode ──────────────────────────────────── + + #[test] + fn find_node_roundtrip() { + let mut buf = make_buf(MsgType::DhtFindNode, FIND_NODE_BODY); + let msg = FindNodeMsg { + nonce: 99, + target: NodeId::from_bytes([0x42; 32]), + domain: DOMAIN_INET, + state: 1, + }; + write_find_node(&mut buf, &msg); + let parsed = parse_find_node(&buf).unwrap(); + assert_eq!(parsed.nonce, 99); + assert_eq!(parsed.target, msg.target); + assert_eq!(parsed.domain, DOMAIN_INET); + } + + // ── FindNodeReply ─────────────────────────────── + + #[test] + fn find_node_reply_roundtrip() { + let nodes = vec![ + PeerInfo::new( + 
NodeId::from_bytes([0x01; 32]), + "127.0.0.1:3000".parse().unwrap(), + ), + PeerInfo::new( + NodeId::from_bytes([0x02; 32]), + "127.0.0.1:3001".parse().unwrap(), + ), + ]; + let msg = FindNodeReplyMsg { + nonce: 55, + id: NodeId::from_bytes([0x42; 32]), + domain: DOMAIN_INET, + nodes: nodes.clone(), + }; + + let mut buf = vec![0u8; 1024]; + let hdr = MsgHeader::new( + MsgType::DhtFindNodeReply, + 0, // will fix + NodeId::from_bytes([0xAA; 32]), + NodeId::from_bytes([0xBB; 32]), + ); + hdr.write(&mut buf).unwrap(); + let _total = write_find_node_reply(&mut buf, &msg); + + let parsed = parse_find_node_reply(&buf).unwrap(); + assert_eq!(parsed.nonce, 55); + assert_eq!(parsed.nodes.len(), 2); + assert_eq!(parsed.nodes[0].id, nodes[0].id); + assert_eq!(parsed.nodes[0].addr.port(), 3000); + assert_eq!(parsed.nodes[1].id, nodes[1].id); + } + + // ── Store ─────────────────────────────────────── + + #[test] + fn store_roundtrip() { + let msg = StoreMsg { + id: NodeId::from_bytes([0x10; 32]), + from: NodeId::from_bytes([0x20; 32]), + key: b"mykey".to_vec(), + value: b"myvalue".to_vec(), + ttl: 300, + is_unique: true, + }; + let total = HEADER_SIZE + STORE_FIXED + msg.key.len() + msg.value.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + MsgType::DhtStore, + total as u16, + NodeId::from_bytes([0xAA; 32]), + NodeId::from_bytes([0xBB; 32]), + ); + hdr.write(&mut buf).unwrap(); + write_store(&mut buf, &msg).unwrap(); + + let parsed = parse_store(&buf).unwrap(); + assert_eq!(parsed.id, msg.id); + assert_eq!(parsed.from, msg.from); + assert_eq!(parsed.key, b"mykey"); + assert_eq!(parsed.value, b"myvalue"); + assert_eq!(parsed.ttl, 300); + assert!(parsed.is_unique); + } + + #[test] + fn store_not_unique() { + let msg = StoreMsg { + id: NodeId::from_bytes([0x10; 32]), + from: NodeId::from_bytes([0x20; 32]), + key: b"k".to_vec(), + value: b"v".to_vec(), + ttl: 60, + is_unique: false, + }; + let total = HEADER_SIZE + STORE_FIXED + msg.key.len() + 
msg.value.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + MsgType::DhtStore, + total as u16, + NodeId::from_bytes([0xAA; 32]), + NodeId::from_bytes([0xBB; 32]), + ); + hdr.write(&mut buf).unwrap(); + write_store(&mut buf, &msg).unwrap(); + + let parsed = parse_store(&buf).unwrap(); + assert!(!parsed.is_unique); + } + + // ── FindValue ─────────────────────────────────── + + #[test] + fn find_value_roundtrip() { + let msg = FindValueMsg { + nonce: 77, + target: NodeId::from_bytes([0x33; 32]), + domain: DOMAIN_INET, + key: b"lookup-key".to_vec(), + use_rdp: false, + }; + let total = HEADER_SIZE + FIND_VALUE_FIXED + msg.key.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + MsgType::DhtFindValue, + total as u16, + NodeId::from_bytes([0xAA; 32]), + NodeId::from_bytes([0xBB; 32]), + ); + hdr.write(&mut buf).unwrap(); + write_find_value(&mut buf, &msg).unwrap(); + + let parsed = parse_find_value(&buf).unwrap(); + assert_eq!(parsed.nonce, 77); + assert_eq!(parsed.target, msg.target); + assert_eq!(parsed.key, b"lookup-key"); + assert!(!parsed.use_rdp); + } + + // ── DtunRegister ──────────────────────────────── + + #[test] + fn dtun_register_roundtrip() { + let mut buf = make_buf(MsgType::DtunRegister, 4); + write_u32(&mut buf, HEADER_SIZE, 12345); + let session = parse_dtun_register(&buf).unwrap(); + assert_eq!(session, 12345); + } + + // ── DtunRequest ───────────────────────────────── + + #[test] + fn dtun_request_roundtrip() { + let mut buf = make_buf(MsgType::DtunRequest, 4 + ID_LEN); + let nonce = 88u32; + let target = NodeId::from_bytes([0x55; 32]); + write_u32(&mut buf, HEADER_SIZE, nonce); + target.write_to(&mut buf[HEADER_SIZE + 4..HEADER_SIZE + 4 + ID_LEN]); + + let (n, t) = parse_dtun_request(&buf).unwrap(); + assert_eq!(n, 88); + assert_eq!(t, target); + } + + // ── Node list ─────────────────────────────────── + + #[test] + fn inet_nodes_roundtrip() { + let nodes = vec![ + PeerInfo::new( + NodeId::from_bytes([0x01; 
32]), + "10.0.0.1:8000".parse().unwrap(), + ), + PeerInfo::new( + NodeId::from_bytes([0x02; 32]), + "10.0.0.2:9000".parse().unwrap(), + ), + ]; + let mut buf = vec![0u8; INET_NODE_SIZE * 2]; + let written = write_nodes_inet(&mut buf, &nodes); + assert_eq!(written, INET_NODE_SIZE * 2); + + let parsed = read_nodes_inet(&buf, 2); + assert_eq!(parsed.len(), 2); + assert_eq!(parsed[0].id, nodes[0].id); + assert_eq!(parsed[0].addr.port(), 8000); + assert_eq!(parsed[1].addr.port(), 9000); + } + + // ── Truncated inputs ──────────────────────────── + + #[test] + fn parse_store_truncated() { + let buf = make_buf(MsgType::DhtStore, 2); // too small + assert!(matches!(parse_store(&buf), Err(Error::BufferTooSmall))); + } + + #[test] + fn parse_find_node_truncated() { + let buf = make_buf(MsgType::DhtFindNode, 2); + assert!(matches!(parse_find_node(&buf), Err(Error::BufferTooSmall))); + } + + #[test] + fn parse_find_value_truncated() { + let buf = make_buf(MsgType::DhtFindValue, 2); + assert!(matches!(parse_find_value(&buf), Err(Error::BufferTooSmall))); + } +} diff --git a/src/nat.rs b/src/nat.rs new file mode 100644 index 0000000..cfd5729 --- /dev/null +++ b/src/nat.rs @@ -0,0 +1,384 @@ +//! NAT type detection (STUN-like echo protocol). +//! +//! Detects whether +//! this node has a public IP, is behind a cone NAT, or is +//! behind a symmetric NAT. The result determines how the +//! node communicates: +//! +//! - **Global**: direct communication. +//! - **ConeNat**: hole-punching via DTUN. +//! - **SymmetricNat**: relay via proxy. +//! +//! ## Detection protocol +//! +//! 1. Send NatEcho to a known node. +//! 2. Receive NatEchoReply with our observed external +//! address and port. +//! 3. If observed == local → Global. +//! 4. If different → we're behind NAT. Send +//! NatEchoRedirect asking the peer to have a *third* +//! node echo us from a different port. +//! 5. Receive NatEchoRedirectReply with observed port +//! from the third node. +//! 6. 
If ports match → ConeNat (same external port for +//! different destinations). +//! 7. If ports differ → SymmetricNat (external port +//! changes per destination). + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +use crate::id::NodeId; + +/// NAT detection echo timeout. +pub const ECHO_TIMEOUT: Duration = Duration::from_secs(3); + +/// Periodic re-detection interval. +pub const NAT_TIMER_INTERVAL: Duration = Duration::from_secs(30); + +/// Node's NAT state (visible to the application). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NatState { + /// Not yet determined. + Unknown, + + /// Public IP — direct communication. + Global, + + /// Behind NAT, type not yet classified. + Nat, + + /// Cone NAT — hole-punching works. + ConeNat, + + /// Symmetric NAT — must use a proxy relay. + SymmetricNat, +} + +/// Internal state machine for the detection protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum DetectorState { + /// Initial / idle. + Undefined, + + /// Sent first echo, waiting for reply. + EchoWait1, + + /// Behind NAT, sent redirect request, waiting. + EchoRedirectWait, + + /// Confirmed global (public IP). + Global, + + /// Behind NAT, type unknown yet. + Nat, + + /// Sent second echo (via redirected node), waiting. + EchoWait2, + + /// Confirmed cone NAT. + ConeNat, + + /// Confirmed symmetric NAT. + SymmetricNat, +} + +/// Pending echo request tracked by nonce. +struct EchoPending { + sent_at: Instant, +} + +/// NAT type detector. +/// +/// Drives the echo protocol state machine. The caller +/// feeds received messages via `recv_*` methods and +/// reads the result via `state()`. +pub struct NatDetector { + state: DetectorState, + + /// Our detected external address (if behind NAT). + global_addr: Option<SocketAddr>, + + /// Port observed in the first echo reply. + echo1_port: Option<u16>, + + /// Pending echo requests keyed by nonce. 
+ pending: HashMap<u32, EchoPending>, +} + +impl NatDetector { + pub fn new(_local_id: NodeId) -> Self { + Self { + state: DetectorState::Undefined, + global_addr: None, + echo1_port: None, + pending: HashMap::new(), + } + } + + /// Current NAT state as seen by the application. + pub fn state(&self) -> NatState { + match self.state { + DetectorState::Undefined + | DetectorState::EchoWait1 + | DetectorState::EchoRedirectWait + | DetectorState::EchoWait2 => NatState::Unknown, + DetectorState::Global => NatState::Global, + DetectorState::Nat => NatState::Nat, + DetectorState::ConeNat => NatState::ConeNat, + DetectorState::SymmetricNat => NatState::SymmetricNat, + } + } + + /// Our detected global address, if known. + pub fn global_addr(&self) -> Option<SocketAddr> { + self.global_addr + } + + /// Force the NAT state (e.g. from configuration). + pub fn set_state(&mut self, s: NatState) { + self.state = match s { + NatState::Unknown => DetectorState::Undefined, + NatState::Global => DetectorState::Global, + NatState::Nat => DetectorState::Nat, + NatState::ConeNat => DetectorState::ConeNat, + NatState::SymmetricNat => DetectorState::SymmetricNat, + }; + } + + /// Start detection: prepare a NatEcho to send to a + /// known peer. + /// + /// Returns the nonce to include in the echo message. + pub fn start_detect(&mut self, nonce: u32) -> u32 { + self.state = DetectorState::EchoWait1; + self.pending.insert( + nonce, + EchoPending { + sent_at: Instant::now(), + }, + ); + nonce + } + + /// Handle a NatEchoReply: the peer tells us our + /// observed external address and port. + /// Replies in unexpected states are silently ignored + /// (catch-all returns `EchoReplyAction::Ignore`). 
+ pub fn recv_echo_reply( + &mut self, + nonce: u32, + observed_addr: SocketAddr, + local_addr: SocketAddr, + ) -> EchoReplyAction { + if self.pending.remove(&nonce).is_none() { + return EchoReplyAction::Ignore; + } + + match self.state { + DetectorState::EchoWait1 => { + if observed_addr.port() == local_addr.port() + && observed_addr.ip() == local_addr.ip() + { + // Our address matches → we have a public IP + self.state = DetectorState::Global; + self.global_addr = Some(observed_addr); + EchoReplyAction::DetectionComplete(NatState::Global) + } else { + // Behind NAT — need redirect test + self.state = DetectorState::Nat; + self.global_addr = Some(observed_addr); + self.echo1_port = Some(observed_addr.port()); + EchoReplyAction::NeedRedirect + } + } + DetectorState::EchoWait2 => { + // Reply from redirected third node + let port2 = observed_addr.port(); + if Some(port2) == self.echo1_port { + // Same external port → Cone NAT + self.state = DetectorState::ConeNat; + EchoReplyAction::DetectionComplete(NatState::ConeNat) + } else { + // Different port → Symmetric NAT + self.state = DetectorState::SymmetricNat; + EchoReplyAction::DetectionComplete(NatState::SymmetricNat) + } + } + _ => EchoReplyAction::Ignore, + } + } + + /// After receiving NeedRedirect, start the redirect + /// phase with a new nonce. + pub fn start_redirect(&mut self, nonce: u32) { + self.state = DetectorState::EchoRedirectWait; + self.pending.insert( + nonce, + EchoPending { + sent_at: Instant::now(), + }, + ); + } + + /// The redirect was forwarded and a third node will + /// send us an echo. Transition to EchoWait2. + pub fn redirect_sent(&mut self, nonce: u32) { + self.state = DetectorState::EchoWait2; + self.pending.insert( + nonce, + EchoPending { + sent_at: Instant::now(), + }, + ); + } + + /// Expire timed-out echo requests. + /// + /// If all pending requests timed out and we haven't + /// completed detection, reset to Undefined. 
+ pub fn expire_pending(&mut self) { + self.pending + .retain(|_, p| p.sent_at.elapsed() < ECHO_TIMEOUT); + + if self.pending.is_empty() { + match self.state { + DetectorState::EchoWait1 + | DetectorState::EchoRedirectWait + | DetectorState::EchoWait2 => { + log::debug!("NAT detection timed out, resetting"); + self.state = DetectorState::Undefined; + } + _ => {} + } + } + } + + /// Check if detection is still in progress. + pub fn is_detecting(&self) -> bool { + matches!( + self.state, + DetectorState::EchoWait1 + | DetectorState::EchoRedirectWait + | DetectorState::EchoWait2 + ) + } + + /// Check if detection is complete (any terminal state). + pub fn is_complete(&self) -> bool { + matches!( + self.state, + DetectorState::Global + | DetectorState::ConeNat + | DetectorState::SymmetricNat + ) + } +} + +/// Action to take after processing an echo reply. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EchoReplyAction { + /// Ignore (unknown nonce or wrong state). + Ignore, + + /// Detection complete with the given NAT state. + DetectionComplete(NatState), + + /// Need to send a NatEchoRedirect to determine + /// NAT type (cone vs symmetric). 
+ NeedRedirect, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn local() -> SocketAddr { + "192.168.1.100:5000".parse().unwrap() + } + + #[test] + fn detect_global() { + let mut nd = NatDetector::new(NodeId::from_bytes([0x01; 32])); + let nonce = nd.start_detect(1); + assert!(nd.is_detecting()); + + let action = nd.recv_echo_reply(nonce, local(), local()); + assert_eq!( + action, + EchoReplyAction::DetectionComplete(NatState::Global) + ); + assert_eq!(nd.state(), NatState::Global); + assert!(nd.is_complete()); + } + + #[test] + fn detect_cone_nat() { + let mut nd = NatDetector::new(NodeId::from_bytes([0x01; 32])); + + // Phase 1: echo shows different address + let n1 = nd.start_detect(1); + let observed: SocketAddr = "1.2.3.4:5000".parse().unwrap(); + let action = nd.recv_echo_reply(n1, observed, local()); + assert_eq!(action, EchoReplyAction::NeedRedirect); + assert_eq!(nd.state(), NatState::Nat); + + // Phase 2: redirect → third node echoes with + // same port + nd.redirect_sent(2); + let observed2: SocketAddr = "1.2.3.4:5000".parse().unwrap(); + let action = nd.recv_echo_reply(2, observed2, local()); + assert_eq!( + action, + EchoReplyAction::DetectionComplete(NatState::ConeNat) + ); + assert_eq!(nd.state(), NatState::ConeNat); + } + + #[test] + fn detect_symmetric_nat() { + let mut nd = NatDetector::new(NodeId::from_bytes([0x01; 32])); + + // Phase 1: behind NAT + let n1 = nd.start_detect(1); + let observed: SocketAddr = "1.2.3.4:5000".parse().unwrap(); + nd.recv_echo_reply(n1, observed, local()); + + // Phase 2: different port from third node + nd.redirect_sent(2); + let observed2: SocketAddr = "1.2.3.4:6000".parse().unwrap(); + let action = nd.recv_echo_reply(2, observed2, local()); + assert_eq!( + action, + EchoReplyAction::DetectionComplete(NatState::SymmetricNat) + ); + } + + #[test] + fn unknown_nonce_ignored() { + let mut nd = NatDetector::new(NodeId::from_bytes([0x01; 32])); + nd.start_detect(1); + let action = nd.recv_echo_reply(999, local(), 
local()); + assert_eq!(action, EchoReplyAction::Ignore); + } + + #[test] + fn force_state() { + let mut nd = NatDetector::new(NodeId::from_bytes([0x01; 32])); + nd.set_state(NatState::Global); + assert_eq!(nd.state(), NatState::Global); + assert!(nd.is_complete()); + } + + #[test] + fn timeout_resets() { + let mut nd = NatDetector::new(NodeId::from_bytes([0x01; 32])); + nd.start_detect(1); + + // Clear pending manually to simulate timeout + nd.pending.clear(); + nd.expire_pending(); + assert_eq!(nd.state(), NatState::Unknown); + } +} diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 0000000..aa6d2a7 --- /dev/null +++ b/src/net.rs @@ -0,0 +1,744 @@ +//! Network send helpers and query management. +//! +//! Extension impl block for Node. Contains send_signed, +//! verify_incoming, send_find_node, send_store, query +//! batch sending, and liveness probing. + +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +use crate::dgram; +use crate::dht::IterativeQuery; +use crate::error::Error; +use crate::id::NodeId; +use crate::msg; +use crate::nat::NatState; +use crate::node::Node; +use crate::peers::PeerInfo; +use crate::wire::{DOMAIN_INET, DOMAIN_INET6, HEADER_SIZE, MsgHeader, MsgType}; + +impl Node { + // ── Network ───────────────────────────────────── + + /// Local socket address. + pub fn local_addr(&self) -> Result<SocketAddr, Error> { + self.net.local_addr() + } + + /// Join the DHT network via a bootstrap node. + /// + /// Starts an iterative FIND_NODE for our own ID. + /// The query is driven by subsequent `poll()` calls. + /// DNS resolution is synchronous. 
+ /// + /// # Example + /// + /// ```rust,no_run + /// # let mut node = tesseras_dht::Node::bind(0).unwrap(); + /// node.join("bootstrap.example.com", 10000).unwrap(); + /// loop { node.poll().unwrap(); } + /// ``` + pub fn join(&mut self, host: &str, port: u16) -> Result<(), Error> { + use std::net::ToSocketAddrs; + + let addr_str = format!("{host}:{port}"); + let addr = addr_str + .to_socket_addrs() + .map_err(Error::Io)? + .next() + .ok_or_else(|| { + Error::Io(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("no addresses for {addr_str}"), + )) + })?; + + log::info!("Joining via {addr}"); + + // Send DHT FIND_NODE to bootstrap (populates + // DHT routing table via handle_dht_find_node_reply) + self.send_find_node(addr, self.id)?; + + if self.is_dtun { + // Send DTUN FIND_NODE to bootstrap too + // (populates DTUN routing table) + let nonce = self.alloc_nonce(); + let domain = if addr.is_ipv4() { + crate::wire::DOMAIN_INET + } else { + crate::wire::DOMAIN_INET6 + }; + let find = crate::msg::FindNodeMsg { + nonce, + target: self.id, + domain, + state: 0, + }; + let mut buf = [0u8; crate::msg::FIND_NODE_MSG_SIZE]; + let hdr = MsgHeader::new( + MsgType::DtunFindNode, + Self::len16(crate::msg::FIND_NODE_MSG_SIZE), + self.id, + NodeId::from_bytes([0; 32]), + ); + if hdr.write(&mut buf).is_ok() { + crate::msg::write_find_node(&mut buf, &find); + if let Err(e) = self.send_signed(&buf, addr) { + log::debug!("DTUN find_node send failed: {e}"); + } + } + } + + // Start NAT detection if DTUN is enabled + if self.is_dtun && !self.nat.is_complete() { + let nonce = self.alloc_nonce(); + self.nat.start_detect(nonce); + + // Send NatEcho to bootstrap + let size = HEADER_SIZE + 4; + let mut buf = vec![0u8; size]; + let hdr = MsgHeader::new( + MsgType::NatEcho, + Self::len16(size), + self.id, + NodeId::from_bytes([0; crate::id::ID_LEN]), + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice(&nonce.to_be_bytes()); + if let 
Err(e) = self.send_signed(&buf, addr) {
+                    log::warn!("Failed to send NatEcho: {e}");
+                }
+                log::debug!("Sent NatEcho to {addr} nonce={nonce}");
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Start an iterative FIND_NODE query for a target.
+    ///
+    /// Returns the query nonce. The query is driven by
+    /// `poll()` and completes when converged.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::Timeout` when the concurrent
+    /// query limit (100) has been reached, or a send
+    /// error from dispatching the initial batch.
+    pub fn start_find_node(&mut self, target: NodeId) -> Result<u32, Error> {
+        let nonce = self.alloc_nonce();
+        let mut query = IterativeQuery::find_node(target, nonce);
+
+        // Seed with our closest known nodes
+        let closest =
+            self.dht_table.closest(&target, self.config.num_find_node);
+        query.closest = closest;
+
+        // Limit concurrent queries to prevent memory
+        // exhaustion (max 100)
+        const MAX_CONCURRENT_QUERIES: usize = 100;
+        if self.queries.len() >= MAX_CONCURRENT_QUERIES {
+            log::warn!("Too many concurrent queries, dropping");
+            return Err(Error::Timeout);
+        }
+
+        // BUGFIX: register the query BEFORE sending the
+        // initial batch. send_query_batch() looks the
+        // query up by nonce and silently no-ops when it
+        // is absent, so the previous insert-after-send
+        // order meant the first batch was never sent and
+        // the lookup only progressed on the next
+        // drive_queries() tick. If the send fails, the
+        // query stays registered and drive_queries()
+        // will retry or expire it.
+        self.queries.insert(nonce, query);
+        self.send_query_batch(nonce)?;
+        Ok(nonce)
+    }
+
+    /// Send a FIND_NODE message to a specific address.
+    ///
+    /// Builds the FindNodeMsg with a fresh nonce, picks
+    /// the address domain from the destination family,
+    /// and signs the packet. Destination NodeId in the
+    /// header is zero because it is not yet known.
+    pub(crate) fn send_find_node(
+        &mut self,
+        to: SocketAddr,
+        target: NodeId,
+    ) -> Result<u32, Error> {
+        let nonce = self.alloc_nonce();
+        let domain = if to.is_ipv4() {
+            DOMAIN_INET
+        } else {
+            DOMAIN_INET6
+        };
+
+        let find = msg::FindNodeMsg {
+            nonce,
+            target,
+            domain,
+            state: 0,
+        };
+
+        let mut buf = [0u8; msg::FIND_NODE_MSG_SIZE];
+        let hdr = MsgHeader::new(
+            MsgType::DhtFindNode,
+            Self::len16(msg::FIND_NODE_MSG_SIZE),
+            self.id,
+            NodeId::from_bytes([0; crate::id::ID_LEN]), // unknown dst
+        );
+        hdr.write(&mut buf)?;
+        msg::write_find_node(&mut buf, &find);
+        self.send_signed(&buf, to)?;
+
+        log::debug!("Sent find_node to {to} target={target:?} nonce={nonce}");
+        Ok(nonce)
+    }
+
+    /// Send a STORE message to a specific peer.
+ pub(crate) fn send_store( + &mut self, + peer: &PeerInfo, + store: &msg::StoreMsg, + ) -> Result<(), Error> { + let total = HEADER_SIZE + + msg::STORE_FIXED + + store.key.len() + + store.value.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + MsgType::DhtStore, + Self::len16(total), + self.id, + peer.id, + ); + hdr.write(&mut buf)?; + msg::write_store(&mut buf, store)?; + self.send_signed(&buf, peer.addr)?; + log::debug!( + "Sent store to {:?} key={} bytes", + peer.id, + store.key.len() + ); + Ok(()) + } + + /// Send the next batch of queries for an active + /// iterative query. Uses FIND_VALUE for find_value + /// queries, FIND_NODE otherwise. + pub(crate) fn send_query_batch(&mut self, nonce: u32) -> Result<(), Error> { + let query = match self.queries.get(&nonce) { + Some(q) => q, + None => return Ok(()), + }; + + let to_query = query.next_to_query(); + let target = query.target; + let is_find_value = query.is_find_value; + let key = query.key.clone(); + + for peer in to_query { + let result = if is_find_value { + self.send_find_value_msg(peer.addr, target, &key) + } else { + self.send_find_node(peer.addr, target) + }; + + if let Err(e) = result { + log::debug!("Failed to send query to {:?}: {e}", peer.id); + continue; + } + if let Some(q) = self.queries.get_mut(&nonce) { + q.pending.insert(peer.id, Instant::now()); + } + } + Ok(()) + } + + /// Drive all active iterative queries: expire + /// timeouts, send next batches, clean up finished. 
+ pub(crate) fn drive_queries(&mut self) { + // Expire timed-out pending requests + let nonces: Vec<u32> = self.queries.keys().copied().collect(); + + for nonce in &nonces { + if let Some(q) = self.queries.get_mut(nonce) { + q.expire_pending(); + } + } + + // Send next batch for active queries + for nonce in &nonces { + let is_active = self + .queries + .get(nonce) + .map(|q| !q.is_done()) + .unwrap_or(false); + + if is_active { + if let Err(e) = self.send_query_batch(*nonce) { + log::debug!("Query batch send failed: {e}"); + } + } + } + + // Remove completed queries + self.queries.retain(|nonce, q| { + if q.is_done() { + log::debug!( + "Query nonce={nonce} complete: {} closest, {} hops, {}ms", + q.closest.len(), + q.hops, + q.started_at.elapsed().as_millis(), + ); + self.metrics + .lookups_completed + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + false + } else { + true + } + }); + } + + /// Refresh stale routing table buckets by starting + /// FIND_NODE queries for random IDs in each stale + /// bucket's range. + pub(crate) fn refresh_buckets(&mut self) { + let threshold = self.config.refresh_interval; + + if self.last_refresh.elapsed() < threshold { + return; + } + self.last_refresh = Instant::now(); + + let targets = self.dht_table.stale_bucket_targets(threshold); + + if targets.is_empty() { + return; + } + + // Limit to 3 refresh queries per cycle to avoid + // flooding the query queue (256 empty buckets would + // otherwise spawn 256 queries at once). + let batch = targets.into_iter().take(3); + log::debug!("Refreshing stale buckets"); + + for target in batch { + if let Err(e) = self.start_find_node(target) { + log::debug!("Refresh find_node failed: {e}"); + break; // query queue full, stop + } + } + } + + /// Ping LRU peers in each bucket to verify they're + /// alive. Dead peers are removed from the routing + /// table. Called alongside refresh_buckets. 
+ pub(crate) fn probe_liveness(&mut self) { + let lru_peers = self.dht_table.lru_peers(); + if lru_peers.is_empty() { + return; + } + + for peer in &lru_peers { + // Skip if we've seen them recently + if peer.last_seen.elapsed() < Duration::from_secs(30) { + continue; + } + let nonce = self.alloc_nonce(); + let size = msg::PING_MSG_SIZE; + let mut buf = [0u8; msg::PING_MSG_SIZE]; + let hdr = MsgHeader::new( + MsgType::DhtPing, + Self::len16(size), + self.id, + peer.id, + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice(&nonce.to_be_bytes()); + let _ = self.send_signed(&buf, peer.addr); + self.pending_pings.insert(nonce, (peer.id, Instant::now())); + } + } + } + + /// Try to send queued datagrams whose destinations + /// are now in the peer store. + pub(crate) fn drain_send_queue(&mut self) { + let known_ids: Vec<crate::id::NodeId> = self.peers.ids(); + + for dst_id in known_ids { + if !self.send_queue.has_pending(&dst_id) { + continue; + } + let peer = match self.peers.get(&dst_id).cloned() { + Some(p) => p, + None => continue, + }; + let queued = self.send_queue.drain(&dst_id); + for item in queued { + self.send_dgram_raw(&item.data, &peer); + } + } + } + + /// Handle an incoming Dgram message: reassemble + /// fragments and invoke the dgram callback. + pub(crate) fn handle_dgram(&mut self, buf: &[u8], hdr: &MsgHeader) { + let payload = &buf[HEADER_SIZE..]; + + let (total, index, frag_data) = match dgram::parse_fragment(payload) { + Some(v) => v, + None => return, + }; + + let complete = + self.reassembler + .feed(hdr.src, total, index, frag_data.to_vec()); + + if let Some(data) = complete { + log::debug!( + "Dgram reassembled: {} bytes from {:?}", + data.len(), + hdr.src + ); + if let Some(ref cb) = self.dgram_callback { + cb(&data, &hdr.src); + } + } + } + + /// Flush pending RDP output for a connection, + /// sending packets via UDP. 
+ pub(crate) fn flush_rdp_output(&mut self, desc: i32) { + let output = match self.rdp.pending_output(desc) { + Some(o) => o, + None => return, + }; + + // Determine target address and message type + let (send_addr, msg_type) = + if self.nat.state() == NatState::SymmetricNat { + // Route through proxy + if let Some(server) = self.proxy.server() { + (server.addr, MsgType::ProxyRdp) + } else if let Some(peer) = self.peers.get(&output.dst) { + (peer.addr, MsgType::Rdp) + } else { + log::debug!("RDP: no route for {:?}", output.dst); + return; + } + } else if let Some(peer) = self.peers.get(&output.dst) { + (peer.addr, MsgType::Rdp) + } else { + log::debug!("RDP: no address for {:?}", output.dst); + return; + }; + + for pkt in &output.packets { + let rdp_wire = crate::rdp::build_rdp_wire( + pkt.flags, + output.sport, + output.dport, + pkt.seqnum, + pkt.acknum, + &pkt.data, + ); + + let total = HEADER_SIZE + rdp_wire.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + msg_type, + Self::len16(total), + self.id, + output.dst, + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..].copy_from_slice(&rdp_wire); + let _ = self.send_signed(&buf, send_addr); + } + } + } + + /// Flush pending RDP output for all connections. + pub(crate) fn flush_all_rdp(&mut self) { + let descs = self.rdp.descriptors(); + for desc in descs { + self.flush_rdp_output(desc); + } + } + + /// Handle an incoming RDP packet. 
+ pub(crate) fn handle_rdp(&mut self, buf: &[u8], hdr: &MsgHeader) { + let payload = &buf[HEADER_SIZE..]; + let wire = match crate::rdp::parse_rdp_wire(payload) { + Some(w) => w, + None => return, + }; + + log::debug!( + "RDP from {:?}: flags=0x{:02x} \ + sport={} dport={} \ + seq={} ack={} \ + data={} bytes", + hdr.src, + wire.flags, + wire.sport, + wire.dport, + wire.seqnum, + wire.acknum, + wire.data.len() + ); + + let input = crate::rdp::RdpInput { + src: hdr.src, + sport: wire.sport, + dport: wire.dport, + flags: wire.flags, + seqnum: wire.seqnum, + acknum: wire.acknum, + data: wire.data, + }; + let actions = self.rdp.input(&input); + + for action in actions { + match action { + crate::rdp::RdpAction::Event { + desc, + ref addr, + event, + } => { + log::info!("RDP event: desc={desc} {:?}", event); + + // Invoke app callback + if let Some(ref cb) = self.rdp_callback { + cb(desc, addr, event); + } + + // After accept/connect, flush SYN-ACK/ACK + self.flush_rdp_output(desc); + } + crate::rdp::RdpAction::Close(desc) => { + self.rdp.close(desc); + } + } + } + } + + /// Register this node with DTUN (for NAT traversal). + /// + /// Sends DTUN_REGISTER to the k-closest global nodes + /// so other peers can find us via hole-punching. + pub(crate) fn dtun_register(&mut self) { + let (session, closest) = self.dtun.prepare_register(); + log::info!( + "DTUN register: session={session}, {} targets", + closest.len() + ); + + for peer in &closest { + let size = HEADER_SIZE + 4; + let mut buf = vec![0u8; size]; + let hdr = MsgHeader::new( + MsgType::DtunRegister, + Self::len16(size), + self.id, + peer.id, + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice(&session.to_be_bytes()); + let _ = self.send_signed(&buf, peer.addr); + } + } + + self.dtun.registration_done(); + } + + /// Send a packet with Ed25519 signature appended. + /// + /// Appends 64-byte signature after the packet body. + /// All outgoing packets go through this method. 
+ pub(crate) fn send_signed( + &self, + buf: &[u8], + to: SocketAddr, + ) -> Result<usize, Error> { + let mut signed = buf.to_vec(); + crate::wire::sign_packet(&mut signed, &self.identity); + self.metrics + .messages_sent + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.metrics.bytes_sent.fetch_add( + signed.len() as u64, + std::sync::atomic::Ordering::Relaxed, + ); + // Note: reply sends use `let _ =` (best-effort). + // Critical sends (join, put, DTUN register) use + // `if let Err(e) = ... { log::warn! }`. + self.net.send_to(&signed, to) + } + + /// Proactively check node activity by pinging peers + /// in the routing table. Peers that fail to respond + /// accumulate failures in the ban list. This keeps + /// the routing table healthy by detecting dead nodes + /// before queries need them. + pub(crate) fn check_node_activity(&mut self) { + if self.last_activity_check.elapsed() + < self.config.activity_check_interval + { + return; + } + self.last_activity_check = Instant::now(); + + let lru_peers = self.dht_table.lru_peers(); + if lru_peers.is_empty() { + return; + } + + let mut pinged = 0u32; + for peer in &lru_peers { + // Only ping peers not seen for >60s + if peer.last_seen.elapsed() < Duration::from_secs(60) { + continue; + } + // Skip banned peers + if self.ban_list.is_banned(&peer.addr) { + continue; + } + + let nonce = self.alloc_nonce(); + let size = crate::msg::PING_MSG_SIZE; + let mut buf = [0u8; crate::msg::PING_MSG_SIZE]; + let hdr = MsgHeader::new( + MsgType::DhtPing, + Self::len16(size), + self.id, + peer.id, + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice(&nonce.to_be_bytes()); + let _ = self.send_signed(&buf, peer.addr); + self.pending_pings.insert(nonce, (peer.id, Instant::now())); + pinged += 1; + } + } + + if pinged > 0 { + log::debug!("Activity check: pinged {pinged} peers"); + } + } + + /// Sweep timed-out stores and retry with alternative + /// peers. 
Failed peers accumulate ban list failures. + pub(crate) fn retry_failed_stores(&mut self) { + if self.last_store_retry.elapsed() < self.config.store_retry_interval { + return; + } + self.last_store_retry = Instant::now(); + + let retries = self.store_tracker.collect_timeouts(); + if retries.is_empty() { + return; + } + + log::debug!("Store retry: {} timed-out stores", retries.len()); + + for retry in &retries { + // Record failure for the peer that didn't ack + if let Some(peer_info) = self.peers.get(&retry.failed_peer) { + self.ban_list.record_failure(peer_info.addr); + } + + // Find alternative peers (excluding the failed one) + let closest = self + .dht_table + .closest(&retry.target, self.config.num_find_node); + + let store_msg = crate::msg::StoreMsg { + id: retry.target, + from: self.id, + key: retry.key.clone(), + value: retry.value.clone(), + ttl: retry.ttl, + is_unique: retry.is_unique, + }; + + let mut sent = false; + for peer in &closest { + if peer.id == retry.failed_peer { + continue; + } + if self.ban_list.is_banned(&peer.addr) { + continue; + } + if let Err(e) = self.send_store(peer, &store_msg) { + log::debug!("Store retry send failed: {e}"); + continue; + } + self.store_tracker.track( + retry.target, + retry.key.clone(), + retry.value.clone(), + retry.ttl, + retry.is_unique, + peer.clone(), + ); + sent = true; + break; // one alternative peer is enough per retry + } + + if !sent { + log::debug!( + "Store retry: no alternative peer for key ({} bytes)", + retry.key.len() + ); + } + } + } + + /// Verify Ed25519 signature on an incoming packet. + /// + /// Since NodeId = Ed25519 public key, the src field + /// in the header IS the public key. The signature + /// proves the sender holds the private key for that + /// NodeId. + /// + /// Additionally, if we already know this peer, verify + /// the source address matches to prevent IP spoofing. 
+ pub(crate) fn verify_incoming<'a>( + &self, + buf: &'a [u8], + from: std::net::SocketAddr, + ) -> Option<&'a [u8]> { + if buf.len() < HEADER_SIZE + crate::crypto::SIGNATURE_SIZE { + log::trace!("Rejecting unsigned packet ({} bytes)", buf.len()); + return None; + } + + // NodeId IS the public key — verify signature + let src = NodeId::read_from(&buf[8..8 + crate::id::ID_LEN]); + + // Reject zero NodeId + if src.is_zero() { + log::trace!("Rejecting packet from zero NodeId"); + return None; + } + + let pubkey = src.as_bytes(); + + if !crate::wire::verify_packet(buf, pubkey) { + log::trace!("Signature verification failed"); + self.metrics + .packets_rejected + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + return None; + } + + // If we know this peer, verify source address + // (prevents IP spoofing with valid signatures) + if let Some(known) = self.peers.get(&src) { + if known.addr != from { + log::debug!( + "Peer {:?} address mismatch: known={} got={}", + src, + known.addr, + from + ); + // Allow — peer may have changed IP (NAT rebind) + // but log for monitoring + } + } + + let payload_end = buf.len() - crate::crypto::SIGNATURE_SIZE; + Some(&buf[..payload_end]) + } +} diff --git a/src/node.rs b/src/node.rs new file mode 100644 index 0000000..fef917e --- /dev/null +++ b/src/node.rs @@ -0,0 +1,1395 @@ +//! Main facade: the `Node` node. +//! +//! Owns all subsystems and provides the public API for +//! joining the network, +//! storing/retrieving values, sending datagrams, and +//! using reliable transport (RDP). 
+ +use std::collections::HashMap; +use std::fmt; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +use crate::advertise::Advertise; +use crate::dgram::{self, Reassembler, SendQueue}; +use crate::dht::{DhtStorage, IterativeQuery, MaskBitExplorer}; +use crate::dtun::Dtun; +use crate::error::Error; +use crate::id::NodeId; +use crate::msg; +use crate::nat::{NatDetector, NatState}; +use crate::peers::{PeerInfo, PeerStore}; +use crate::proxy::Proxy; +use crate::rdp::{Rdp, RdpError, RdpState, RdpStatus}; +use crate::routing::RoutingTable; +use crate::socket::{NetLoop, UDP_TOKEN}; +use crate::timer::TimerWheel; +use crate::wire::{DOMAIN_INET, DOMAIN_INET6, HEADER_SIZE, MsgHeader, MsgType}; + +/// Default poll timeout when no timers are scheduled. +const DEFAULT_POLL_TIMEOUT: Duration = Duration::from_millis(100); + +type DgramCallback = Box<dyn Fn(&[u8], &NodeId) + Send>; +type RdpCallback = + Box<dyn Fn(i32, &crate::rdp::RdpAddr, crate::rdp::RdpEvent) + Send>; + +/// The tesseras-dht node. +/// +/// This is the main entry point. It owns all subsystems +/// (DHT, DTUN, NAT detector, proxy, RDP, datagrams, +/// peers, timers, network I/O) and exposes a clean API +/// for the tesseras-dht node. +pub struct Node { + pub(crate) identity: crate::crypto::Identity, + pub(crate) id: NodeId, + pub(crate) net: NetLoop, + pub(crate) dht_table: RoutingTable, + pub(crate) dtun: Dtun, + pub(crate) nat: NatDetector, + pub(crate) proxy: Proxy, + pub(crate) rdp: Rdp, + pub(crate) storage: DhtStorage, + pub(crate) peers: PeerStore, + pub(crate) timers: TimerWheel, + pub(crate) advertise: Advertise, + pub(crate) reassembler: Reassembler, + pub(crate) send_queue: SendQueue, + pub(crate) explorer: MaskBitExplorer, + pub(crate) is_dtun: bool, + pub(crate) dgram_callback: Option<DgramCallback>, + pub(crate) rdp_callback: Option<RdpCallback>, + + /// Active iterative queries keyed by nonce. + pub(crate) queries: HashMap<u32, IterativeQuery>, + + /// Last bucket refresh time. 
+ pub(crate) last_refresh: Instant, + + /// Last data restore time. + pub(crate) last_restore: Instant, + + /// Last maintain (mask_bit exploration) time. + pub(crate) last_maintain: Instant, + + /// Routing table persistence backend. + pub(crate) routing_persistence: Box<dyn crate::persist::RoutingPersistence>, + + /// Data persistence backend. + pub(crate) data_persistence: Box<dyn crate::persist::DataPersistence>, + + /// Metrics counters. + pub(crate) metrics: crate::metrics::Metrics, + /// Pending pings: nonce → (target NodeId, sent_at). + pub(crate) pending_pings: HashMap<u32, (NodeId, Instant)>, + /// Inbound rate limiter. + pub(crate) rate_limiter: crate::ratelimit::RateLimiter, + /// Node configuration. + pub(crate) config: crate::config::Config, + /// Ban list for misbehaving peers. + pub(crate) ban_list: crate::banlist::BanList, + /// Store acknowledgment tracker. + pub(crate) store_tracker: crate::store_track::StoreTracker, + /// Last node activity check time. + pub(crate) last_activity_check: Instant, + /// Last store retry sweep time. + pub(crate) last_store_retry: Instant, +} + +/// Builder for configuring a Node node. +/// +/// ```rust,no_run +/// use tesseras_dht::node::NodeBuilder; +/// use tesseras_dht::nat::NatState; +/// +/// let node = NodeBuilder::new() +/// .port(10000) +/// .nat(NatState::Global) +/// .seed(b"my-identity-seed") +/// .build() +/// .unwrap(); +/// ``` +pub struct NodeBuilder { + port: u16, + addr: Option<SocketAddr>, + pub(crate) nat: Option<NatState>, + seed: Option<Vec<u8>>, + enable_dtun: bool, + config: Option<crate::config::Config>, +} + +impl NodeBuilder { + pub fn new() -> Self { + Self { + port: 0, + addr: None, + nat: None, + seed: None, + enable_dtun: true, + config: None, + } + } + + /// Set the UDP port to bind. + pub fn port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Set a specific bind address. 
+ pub fn addr(mut self, addr: SocketAddr) -> Self { + self.addr = Some(addr); + self + } + + /// Set the NAT state. + pub fn nat(mut self, state: NatState) -> Self { + self.nat = Some(state); + self + } + + /// Set identity seed (deterministic keypair). + pub fn seed(mut self, data: &[u8]) -> Self { + self.seed = Some(data.to_vec()); + self + } + + /// Enable or disable DTUN. + pub fn dtun(mut self, enabled: bool) -> Self { + self.enable_dtun = enabled; + self + } + + /// Set the node configuration. + pub fn config(mut self, config: crate::config::Config) -> Self { + self.config = Some(config); + self + } + + /// Build the Node node. + pub fn build(self) -> Result<Node, Error> { + let addr = self + .addr + .unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], self.port))); + + let mut node = Node::bind_addr(addr)?; + + if let Some(seed) = &self.seed { + node.set_id(seed); + } + + if let Some(nat) = self.nat { + node.set_nat_state(nat); + } + + if !self.enable_dtun { + node.is_dtun = false; + } + + if let Some(config) = self.config { + node.config = config; + } + + Ok(node) + } +} + +impl Default for NodeBuilder { + fn default() -> Self { + Self::new() + } +} + +impl Node { + /// Create a new node and bind to `port` (IPv4). + /// + /// Generates a random node ID. Use `set_id` to + /// derive an ID from application data instead. + pub fn bind(port: u16) -> Result<Self, Error> { + let addr = SocketAddr::from(([0, 0, 0, 0], port)); + Self::bind_addr(addr) + } + + /// Create a new node bound to `port` on IPv6. + /// + /// When using IPv6, DTUN is disabled and NAT state + /// is set to Global (IPv6 does not need NAT + /// traversal). + pub fn bind_v6(port: u16) -> Result<Self, Error> { + let addr = SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, port)); + let mut node = Self::bind_addr(addr)?; + node.is_dtun = false; + node.dtun.set_enabled(false); + node.nat.set_state(NatState::Global); + Ok(node) + } + + /// Create a new node bound to a specific address. 
+ pub fn bind_addr(addr: SocketAddr) -> Result<Self, Error> { + let identity = crate::crypto::Identity::generate(); + let id = *identity.node_id(); + let net = NetLoop::bind(addr)?; + + log::info!("Node node {} bound to {}", id, net.local_addr()?); + + Ok(Self { + dht_table: RoutingTable::new(id), + dtun: Dtun::new(id), + nat: NatDetector::new(id), + proxy: Proxy::new(id), + rdp: Rdp::new(), + storage: DhtStorage::new(), + peers: PeerStore::new(), + timers: TimerWheel::new(), + advertise: Advertise::new(id), + reassembler: Reassembler::new(), + send_queue: SendQueue::new(), + explorer: MaskBitExplorer::new(id), + is_dtun: true, + dgram_callback: None, + rdp_callback: None, + queries: HashMap::new(), + last_refresh: Instant::now(), + last_restore: Instant::now(), + last_maintain: Instant::now(), + routing_persistence: Box::new(crate::persist::NoPersistence), + data_persistence: Box::new(crate::persist::NoPersistence), + metrics: crate::metrics::Metrics::new(), + pending_pings: HashMap::new(), + rate_limiter: crate::ratelimit::RateLimiter::default(), + config: crate::config::Config::default(), + ban_list: crate::banlist::BanList::new(), + store_tracker: crate::store_track::StoreTracker::new(), + last_activity_check: Instant::now(), + last_store_retry: Instant::now(), + identity, + id, + net, + }) + } + + // ── Identity ──────────────────────────────────── + + /// The local node ID. + pub fn id(&self) -> &NodeId { + &self.id + } + + /// The local node ID as a hex string. + pub fn id_hex(&self) -> String { + self.id.to_hex() + } + + /// Set the node identity from a 32-byte seed. + /// + /// Derives an Ed25519 keypair from the seed. + /// NodeId = public key. Deterministic: same seed + /// produces the same identity. 
+ pub fn set_id(&mut self, data: &[u8]) { + // Hash to 32 bytes if input is not already 32 + let seed = if data.len() == 32 { + let mut s = [0u8; 32]; + s.copy_from_slice(data); + s + } else { + use sha2::{Digest, Sha256}; + let hash = Sha256::digest(data); + let mut s = [0u8; 32]; + s.copy_from_slice(&hash); + s + }; + self.identity = crate::crypto::Identity::from_seed(seed); + self.id = *self.identity.node_id(); + self.dht_table = RoutingTable::new(self.id); + self.dtun = Dtun::new(self.id); + self.explorer = MaskBitExplorer::new(self.id); + log::info!("Node ID set to {}", self.id); + } + + /// The node's Ed25519 public key (32 bytes). + pub fn public_key(&self) -> &[u8; 32] { + self.identity.public_key() + } + + // ── NAT state ─────────────────────────────────── + + /// Current NAT detection state. + pub fn nat_state(&self) -> NatState { + self.nat.state() + } + + /// Force the NAT state. + pub fn set_nat_state(&mut self, state: NatState) { + self.nat.set_state(state); + if state == NatState::Global { + // IPv6 or explicitly global: disable DTUN + if !self.is_dtun { + self.dtun.set_enabled(false); + } + } + } + + pub(crate) fn alloc_nonce(&mut self) -> u32 { + let mut buf = [0u8; 4]; + crate::sys::random_bytes(&mut buf); + u32::from_ne_bytes(buf) + } + + /// Safe cast of packet size to u16. + #[inline] + pub(crate) fn len16(n: usize) -> u16 { + u16::try_from(n).expect("packet size exceeds u16") + } + + // ── DHT operations ────────────────────────────── + + /// Store a key-value pair in the DHT. + /// + /// The key is hashed with SHA-256 to map it to the + /// 256-bit ID space. Stored locally and sent to + /// the k-closest known nodes. 
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # let mut node = tesseras_dht::Node::bind(0).unwrap();
    /// node.put(b"paste-id", b"paste content", 3600, false);
    /// ```
    pub fn put(&mut self, key: &[u8], value: &[u8], ttl: u16, is_unique: bool) {
        // SHA-256 of the key places it in the 256-bit ID space.
        let target_id = NodeId::from_key(key);
        log::debug!(
            "put: key={} target={}",
            String::from_utf8_lossy(key),
            target_id
        );

        // Store locally
        let val = crate::dht::StoredValue {
            key: key.to_vec(),
            value: value.to_vec(),
            id: target_id,
            source: self.id,
            ttl,
            stored_at: std::time::Instant::now(),
            is_unique,
            original: 3, // ORIGINAL_PUT_NUM
            recvd: std::collections::HashSet::new(),
            version: crate::dht::now_version(),
        };
        self.storage.store(val);

        // If behind symmetric NAT, route through proxy
        if self.nat.state() == NatState::SymmetricNat {
            if let Some(server) = self.proxy.server().cloned() {
                let store_msg = msg::StoreMsg {
                    id: target_id,
                    from: self.id,
                    key: key.to_vec(),
                    value: value.to_vec(),
                    ttl,
                    is_unique,
                };
                // Packet layout: header + fixed STORE fields +
                // variable-length key and value.
                let total = HEADER_SIZE
                    + msg::STORE_FIXED
                    + store_msg.key.len()
                    + store_msg.value.len();
                let mut buf = vec![0u8; total];
                let hdr = MsgHeader::new(
                    MsgType::ProxyStore,
                    Self::len16(total),
                    self.id,
                    server.id,
                );
                if hdr.write(&mut buf).is_ok() {
                    let _ = msg::write_store(&mut buf, &store_msg);
                    let _ = self.send_signed(&buf, server.addr);
                }
                // NOTE(review): unlike the direct path below, the proxy
                // path does not register the send with store_tracker —
                // confirm proxied stores are acked through another route.
                log::info!("put: via proxy to {:?}", server.id);
                return;
            }
        }

        // Direct: send STORE to k-closest known nodes
        let closest = self
            .dht_table
            .closest(&target_id, self.config.num_find_node);
        let store_msg = msg::StoreMsg {
            id: target_id,
            from: self.id,
            key: key.to_vec(),
            value: value.to_vec(),
            ttl,
            is_unique,
        };
        for peer in &closest {
            if let Err(e) = self.send_store(peer, &store_msg) {
                log::warn!("Failed to send store to {:?}: {e}", peer.id);
            } else {
                // Track the in-flight store so it can be retried later
                // (see retry_failed_stores in the poll loop).
                self.store_tracker.track(
                    target_id,
                    key.to_vec(),
                    value.to_vec(),
                    ttl,
                    is_unique,
                    peer.clone(),
                );
            }
        }
        log::info!("put: stored locally + sent to {} peers", closest.len());
    }

    /// Store multiple key-value pairs in the DHT.
    ///
    /// More efficient than calling `put()` in a loop:
    /// groups stores by target peer to reduce redundant
    /// lookups and sends.
    pub fn put_batch(&mut self, entries: &[(&[u8], &[u8], u16, bool)]) {
        // Group by target peer set to batch sends
        struct BatchEntry {
            target_id: NodeId,
            key: Vec<u8>,
            value: Vec<u8>,
            ttl: u16,
            is_unique: bool,
        }

        let mut batch: Vec<BatchEntry> = Vec::with_capacity(entries.len());

        for &(key, value, ttl, is_unique) in entries {
            let target_id = NodeId::from_key(key);

            // Store locally
            let val = crate::dht::StoredValue {
                key: key.to_vec(),
                value: value.to_vec(),
                id: target_id,
                source: self.id,
                ttl,
                stored_at: std::time::Instant::now(),
                is_unique,
                original: 3,
                recvd: std::collections::HashSet::new(),
                version: crate::dht::now_version(),
            };
            self.storage.store(val);

            batch.push(BatchEntry {
                target_id,
                key: key.to_vec(),
                value: value.to_vec(),
                ttl,
                is_unique,
            });
        }

        // Collect unique peers across all targets to
        // minimize redundant sends
        let mut peer_stores: HashMap<NodeId, Vec<msg::StoreMsg>> =
            HashMap::new();

        for entry in &batch {
            let closest = self
                .dht_table
                .closest(&entry.target_id, self.config.num_find_node);
            let store_msg = msg::StoreMsg {
                id: entry.target_id,
                from: self.id,
                key: entry.key.clone(),
                value: entry.value.clone(),
                ttl: entry.ttl,
                is_unique: entry.is_unique,
            };
            for peer in &closest {
                peer_stores
                    .entry(peer.id)
                    .or_default()
                    .push(store_msg.clone());
            }
        }

        // Send all stores grouped by peer
        // NOTE(review): unlike put(), this path checks the ban list
        // but never registers sends with store_tracker — confirm that
        // batched stores are intentionally fire-and-forget.
        let mut total_sent = 0u32;
        for (peer_id, stores) in &peer_stores {
            if let Some(peer) = self.peers.get(peer_id).cloned() {
                if self.ban_list.is_banned(&peer.addr) {
                    continue;
                }
                for store in stores {
                    if self.send_store(&peer, store).is_ok() {
                        total_sent += 1;
                    }
                }
            }
        }

        log::info!(
            "put_batch: {} entries stored locally, {total_sent} sends to {} peers",
            batch.len(),
            peer_stores.len(),
        );
    }

    /// Retrieve multiple keys from the DHT.
    ///
    /// Returns a vec of (key, values) pairs. Local values
    /// are returned immediately; missing keys trigger
    /// iterative FIND_VALUE queries resolved via `poll()`.
    pub fn get_batch(
        &mut self,
        keys: &[&[u8]],
    ) -> Vec<(Vec<u8>, Vec<Vec<u8>>)> {
        let mut results = Vec::with_capacity(keys.len());

        for &key in keys {
            let target_id = NodeId::from_key(key);
            let local = self.storage.get(&target_id, key);

            if !local.is_empty() {
                let vals: Vec<Vec<u8>> =
                    local.into_iter().map(|v| v.value).collect();
                results.push((key.to_vec(), vals));
            } else {
                // Start iterative FIND_VALUE for missing keys
                if let Err(e) = self.start_find_value(key) {
                    log::debug!("Batch find_value failed for key: {e}");
                }
                // Missing keys yield an empty value list; callers can
                // re-query after poll() has processed the replies.
                results.push((key.to_vec(), Vec::new()));
            }
        }

        results
    }

    /// Delete a key from the DHT.
    ///
    /// Sends a STORE with TTL=0 to the k-closest nodes,
    /// which causes them to remove the value.
    pub fn delete(&mut self, key: &[u8]) {
        let target_id = NodeId::from_key(key);

        // Remove locally
        self.storage.remove(&target_id, key);

        // Send TTL=0 store to closest nodes
        let closest = self
            .dht_table
            .closest(&target_id, self.config.num_find_node);
        let store_msg = msg::StoreMsg {
            id: target_id,
            from: self.id,
            key: key.to_vec(),
            value: Vec::new(),
            ttl: 0,
            is_unique: false,
        };
        for peer in &closest {
            let _ = self.send_store(peer, &store_msg);
        }
        log::info!("delete: removed locally + sent to {} peers", closest.len());
    }

    /// Retrieve values for a key from the DHT.
    ///
    /// First checks local storage. If not found, starts
    /// an iterative FIND_VALUE query across the network.
+ /// Returns local values immediately; remote results + /// arrive via `poll()` and can be retrieved with a + /// subsequent `get()` call (they'll be cached + /// locally by `handle_dht_find_value_reply`). + pub fn get(&mut self, key: &[u8]) -> Vec<Vec<u8>> { + let target_id = NodeId::from_key(key); + + // Check local storage first + let local = self.storage.get(&target_id, key); + if !local.is_empty() { + return local.into_iter().map(|v| v.value).collect(); + } + + // Not found locally — start iterative FIND_VALUE + if let Err(e) = self.start_find_value(key) { + log::debug!("Failed to start find_value: {e}"); + } + + Vec::new() + } + + /// Retrieve values with blocking network lookup. + /// + /// Polls internally until the value is found or + /// `timeout` expires. Returns empty if not found. + pub fn get_blocking( + &mut self, + key: &[u8], + timeout: Duration, + ) -> Vec<Vec<u8>> { + let target_id = NodeId::from_key(key); + + // Check local first + let local = self.storage.get(&target_id, key); + if !local.is_empty() { + return local.into_iter().map(|v| v.value).collect(); + } + + // Start FIND_VALUE + if self.start_find_value(key).is_err() { + return Vec::new(); + } + + // Poll until found or timeout + let deadline = Instant::now() + timeout; + while Instant::now() < deadline { + let _ = self.poll(); + + let vals = self.storage.get(&target_id, key); + if !vals.is_empty() { + return vals.into_iter().map(|v| v.value).collect(); + } + + std::thread::sleep(Duration::from_millis(10)); + } + + Vec::new() + } + + /// Start an iterative FIND_VALUE query for a key. + /// + /// Returns the query nonce. Results arrive via + /// `handle_dht_find_value_reply` during `poll()`. 
+ pub fn start_find_value(&mut self, key: &[u8]) -> Result<u32, Error> { + let target_id = NodeId::from_key(key); + let nonce = self.alloc_nonce(); + let mut query = + IterativeQuery::find_value(target_id, key.to_vec(), nonce); + + // Seed with our closest known nodes + let closest = self + .dht_table + .closest(&target_id, self.config.num_find_node); + query.closest = closest; + + self.queries.insert(nonce, query); + + // Send initial batch + self.send_query_batch(nonce)?; + + Ok(nonce) + } + + /// Send a FIND_VALUE message to a specific address. + pub(crate) fn send_find_value_msg( + &mut self, + to: SocketAddr, + target: NodeId, + key: &[u8], + ) -> Result<u32, Error> { + let nonce = self.alloc_nonce(); + let domain = if to.is_ipv4() { + DOMAIN_INET + } else { + DOMAIN_INET6 + }; + + let fv = msg::FindValueMsg { + nonce, + target, + domain, + key: key.to_vec(), + use_rdp: false, + }; + + let total = HEADER_SIZE + msg::FIND_VALUE_FIXED + key.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + MsgType::DhtFindValue, + Self::len16(total), + self.id, + NodeId::from_bytes([0; crate::id::ID_LEN]), + ); + hdr.write(&mut buf)?; + msg::write_find_value(&mut buf, &fv)?; + self.send_signed(&buf, to)?; + + log::debug!( + "Sent find_value to {to} target={target:?} key={} bytes", + key.len() + ); + Ok(nonce) + } + + // ── Datagram ──────────────────────────────────── + + /// Send a datagram to a destination node. + /// + /// If the destination's address is known, fragments + /// are sent immediately. Otherwise they are queued + /// for delivery once the address is resolved. 
+ pub fn send_dgram(&mut self, data: &[u8], dst: &NodeId) { + let fragments = dgram::fragment(data); + log::debug!( + "send_dgram: {} bytes, {} fragment(s) to {:?}", + data.len(), + fragments.len(), + dst + ); + + // If behind symmetric NAT, route through proxy + if self.nat.state() == NatState::SymmetricNat { + if let Some(server) = self.proxy.server().cloned() { + for frag in &fragments { + let total = HEADER_SIZE + frag.len(); + let mut buf = vec![0u8; total]; + let hdr = MsgHeader::new( + MsgType::ProxyDgram, + Self::len16(total), + self.id, + *dst, + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..].copy_from_slice(frag); + let _ = self.send_signed(&buf, server.addr); + } + } + return; + } + } + + // Direct: send if we know the address + if let Some(peer) = self.peers.get(dst).cloned() { + for frag in &fragments { + self.send_dgram_raw(frag, &peer); + } + } else { + // Queue for later delivery + for frag in fragments { + self.send_queue.push(*dst, frag, self.id); + } + } + } + + /// Send a single dgram fragment wrapped in a + /// protocol message. + pub(crate) fn send_dgram_raw(&self, payload: &[u8], dst: &PeerInfo) { + let total = HEADER_SIZE + payload.len(); + let mut buf = vec![0u8; total]; + let hdr = + MsgHeader::new(MsgType::Dgram, Self::len16(total), self.id, dst.id); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..].copy_from_slice(payload); + let _ = self.send_signed(&buf, dst.addr); + } + } + + /// Set the callback for received datagrams. + pub fn set_dgram_callback<F>(&mut self, f: F) + where + F: Fn(&[u8], &NodeId) + Send + 'static, + { + self.dgram_callback = Some(Box::new(f)); + } + + /// Remove the datagram callback. + pub fn unset_dgram_callback(&mut self) { + self.dgram_callback = None; + } + + /// Set a callback for RDP events (ACCEPTED, + /// CONNECTED, READY2READ, RESET, FAILED, etc). 
+ pub fn set_rdp_callback<F>(&mut self, f: F) + where + F: Fn(i32, &crate::rdp::RdpAddr, crate::rdp::RdpEvent) + Send + 'static, + { + self.rdp_callback = Some(Box::new(f)); + } + + /// Remove the RDP event callback. + pub fn unset_rdp_callback(&mut self) { + self.rdp_callback = None; + } + + // ── RDP (reliable transport) ──────────────────── + + /// Listen for RDP connections on `port`. + pub fn rdp_listen(&mut self, port: u16) -> Result<i32, RdpError> { + self.rdp.listen(port) + } + + /// Connect to a remote node via RDP. + /// + /// Sends a SYN packet immediately if the peer's + /// address is known. + pub fn rdp_connect( + &mut self, + sport: u16, + dst: &NodeId, + dport: u16, + ) -> Result<i32, RdpError> { + let desc = self.rdp.connect(sport, *dst, dport)?; + self.flush_rdp_output(desc); + Ok(desc) + } + + /// Close an RDP connection or listener. + pub fn rdp_close(&mut self, desc: i32) { + self.rdp.close(desc); + } + + /// Send data on an RDP connection. + /// + /// Enqueues data and flushes pending packets to the + /// network. + pub fn rdp_send( + &mut self, + desc: i32, + data: &[u8], + ) -> Result<usize, RdpError> { + let n = self.rdp.send(desc, data)?; + self.flush_rdp_output(desc); + Ok(n) + } + + /// Receive data from an RDP connection. + pub fn rdp_recv( + &mut self, + desc: i32, + buf: &mut [u8], + ) -> Result<usize, RdpError> { + self.rdp.recv(desc, buf) + } + + /// Get the state of an RDP descriptor. + pub fn rdp_state(&self, desc: i32) -> Result<RdpState, RdpError> { + self.rdp.get_state(desc) + } + + /// Get status of all RDP connections. + pub fn rdp_status(&self) -> Vec<RdpStatus> { + self.rdp.get_status() + } + + /// Set RDP maximum retransmission timeout. + pub fn rdp_set_max_retrans(&mut self, secs: u64) { + self.rdp.set_max_retrans(Duration::from_secs(secs)); + } + + /// Get RDP maximum retransmission timeout. 
    pub fn rdp_max_retrans(&self) -> u64 {
        self.rdp.max_retrans().as_secs()
    }

    // ── Event loop ────────────────────────────────

    /// Process one iteration of the event loop.
    ///
    /// Polls for I/O events, processes incoming packets,
    /// fires expired timers, and runs maintenance tasks.
    pub fn poll(&mut self) -> Result<(), Error> {
        self.poll_timeout(DEFAULT_POLL_TIMEOUT)
    }

    /// Poll with a custom maximum timeout.
    ///
    /// Use a short timeout (e.g. 1ms) in tests with
    /// many nodes to avoid blocking.
    pub fn poll_timeout(&mut self, max_timeout: Duration) -> Result<(), Error> {
        // Sleep no longer than the nearest timer deadline.
        let timeout = self
            .timers
            .next_deadline()
            .map(|d| d.min(max_timeout))
            .unwrap_or(max_timeout);

        self.net.poll_events(timeout)?;

        // Check if we got a UDP event. We must not hold
        // a borrow on self.net while calling handle_packet,
        // so we just check and then drain separately.
        let has_udp = self.net.drain_events().any(|ev| ev.token() == UDP_TOKEN);

        if has_udp {
            let mut buf = [0u8; 4096];
            // Drain the socket until it would block.
            while let Ok((len, from)) = self.net.recv_from(&mut buf) {
                self.handle_packet(&buf[..len], from);
            }
        }

        // Fire timers
        let _fired = self.timers.tick();

        // Drive iterative queries
        self.drive_queries();

        // Drain send queue for destinations we now know
        self.drain_send_queue();

        // Drive RDP: tick timeouts and flush output
        let rdp_actions = self.rdp.tick();
        for action in rdp_actions {
            if let crate::rdp::RdpAction::Event { desc, event, .. } = action {
                log::info!("RDP tick event: desc={desc} {event:?}");
            }
        }
        self.flush_all_rdp();

        // Periodic maintenance
        self.peers.refresh();
        self.storage.expire();
        self.advertise.refresh();
        self.reassembler.expire();
        self.send_queue.expire();
        self.refresh_buckets();
        self.probe_liveness();

        self.rate_limiter.cleanup();

        // Node activity monitor (proactive ping)
        self.check_node_activity();

        // Retry failed stores
        self.retry_failed_stores();

        // Ban list cleanup
        self.ban_list.cleanup();
        self.store_tracker.cleanup();

        // Expire stale pending pings (>10s) — record
        // failure for unresponsive peers
        let expired_pings: Vec<(u32, NodeId)> = self
            .pending_pings
            .iter()
            .filter(|(_, (_, sent))| sent.elapsed().as_secs() >= 10)
            .map(|(nonce, (id, _))| (*nonce, *id))
            .collect();
        for (nonce, peer_id) in &expired_pings {
            if let Some(peer) = self.peers.get(peer_id) {
                self.ban_list.record_failure(peer.addr);
            }
            // Also record failure in routing table for
            // stale count / replacement cache logic
            if let Some(evicted) = self.dht_table.record_failure(peer_id) {
                log::debug!(
                    "Replaced stale peer {:?} from routing table",
                    evicted
                );
            }
            self.pending_pings.remove(nonce);
        }

        // Data restore (every 120s)
        if self.last_restore.elapsed() >= self.config.restore_interval {
            self.last_restore = Instant::now();
            self.restore_data();
        }

        // Maintain: mask_bit exploration (every 120s)
        if self.last_maintain.elapsed() >= self.config.maintain_interval {
            self.last_maintain = Instant::now();
            self.run_maintain();
        }

        // DTUN maintenance
        let dtun_targets = self.dtun.maintain();
        for target in dtun_targets {
            if let Err(e) = self.start_find_node(target) {
                log::debug!("DTUN maintain find_node failed: {e}");
            }
        }

        // NAT re-detection
        self.nat.expire_pending();

        Ok(())
    }

    /// Run the event loop forever.
    pub fn run(&mut self) -> ! {
        // Errors are logged and swallowed so a transient I/O
        // failure does not kill the daemon.
        loop {
            if let Err(e) = self.poll() {
                log::error!("Event loop error: {e}");
            }
        }
    }

    /// Set the node configuration. Call before `join()`.
    pub fn set_config(&mut self, config: crate::config::Config) {
        self.config = config;
    }

    /// Get the current configuration.
    pub fn config(&self) -> &crate::config::Config {
        &self.config
    }

    /// Set the routing table persistence backend.
    pub fn set_routing_persistence(
        &mut self,
        p: Box<dyn crate::persist::RoutingPersistence>,
    ) {
        self.routing_persistence = p;
    }

    /// Set the data persistence backend.
    pub fn set_data_persistence(
        &mut self,
        p: Box<dyn crate::persist::DataPersistence>,
    ) {
        self.data_persistence = p;
    }

    /// Load saved contacts and data from persistence
    /// backends. Call after bind, before join.
    pub fn load_persisted(&mut self) {
        // Load routing table contacts
        if let Ok(contacts) = self.routing_persistence.load_contacts() {
            let mut loaded = 0usize;
            for c in contacts {
                // Validate: skip zero IDs and unspecified addrs
                if c.id.is_zero() {
                    log::debug!("Skipping persisted contact: zero ID");
                    continue;
                }
                if c.addr.ip().is_unspecified() {
                    log::debug!("Skipping persisted contact: unspecified addr");
                    continue;
                }
                let peer = PeerInfo::new(c.id, c.addr);
                self.dht_table.add(peer.clone());
                self.peers.add(peer);
                loaded += 1;
            }
            log::info!("Loaded {loaded} persisted contacts");
        }

        // Load stored values
        if let Ok(records) = self.data_persistence.load() {
            for r in &records {
                let val = crate::dht::StoredValue {
                    key: r.key.clone(),
                    value: r.value.clone(),
                    id: r.target_id,
                    source: r.source,
                    ttl: r.ttl,
                    stored_at: Instant::now(),
                    is_unique: r.is_unique,
                    original: 0,
                    recvd: std::collections::HashSet::new(),
                    version: 0, // persisted data has no version
                };
                self.storage.store(val);
            }
            log::info!("Loaded {} persisted values", records.len());
        }
    }

    /// Save current state to persistence backends.
+ /// Called during shutdown or periodically. + pub fn save_state(&self) { + // Save routing table contacts + let contacts: Vec<crate::persist::ContactRecord> = self + .dht_table + .closest(&self.id, 1000) // save all + .iter() + .map(|p| crate::persist::ContactRecord { + id: p.id, + addr: p.addr, + }) + .collect(); + + if let Err(e) = self.routing_persistence.save_contacts(&contacts) { + log::warn!("Failed to save contacts: {e}"); + } + + // Save stored values + let values = self.storage.all_values(); + let records: Vec<crate::persist::StoredRecord> = values + .iter() + .map(|v| crate::persist::StoredRecord { + key: v.key.clone(), + value: v.value.clone(), + target_id: v.id, + source: v.source, + ttl: v.remaining_ttl(), + is_unique: v.is_unique, + }) + .collect(); + + if let Err(e) = self.data_persistence.save(&records) { + log::warn!("Failed to save values: {e}"); + } + } + + /// Graceful shutdown: notify closest peers that we + /// are leaving, so they can remove us from their + /// routing tables immediately. + pub fn shutdown(&mut self) { + log::info!("Shutting down node {}", self.id); + + // Send a "leaving" ping to our closest peers + // so they know to remove us + let closest = + self.dht_table.closest(&self.id, self.config.num_find_node); + + for peer in &closest { + // Send FIN-like notification via advertise + let nonce = self.alloc_nonce(); + let size = HEADER_SIZE + 8; + let mut buf = vec![0u8; size]; + let hdr = MsgHeader::new( + MsgType::Advertise, + Self::len16(size), + self.id, + peer.id, + ); + if hdr.write(&mut buf).is_ok() { + buf[HEADER_SIZE..HEADER_SIZE + 4] + .copy_from_slice(&nonce.to_be_bytes()); + + // Session 0 = shutdown signal + buf[HEADER_SIZE + 4..HEADER_SIZE + 8].fill(0); + let _ = self.send_signed(&buf, peer.addr); + } + } + + log::info!("Shutdown: notified {} peers", closest.len()); + + // Persist state + self.save_state(); + } + + // ── Routing table access ──────────────────────── + + /// Number of peers in the DHT routing table. 
+ pub fn routing_table_size(&self) -> usize { + self.dht_table.size() + } + + /// Number of known peers. + pub fn peer_count(&self) -> usize { + self.peers.len() + } + + /// Snapshot of metrics counters. + pub fn metrics(&self) -> crate::metrics::MetricsSnapshot { + self.metrics.snapshot() + } + + /// Number of stored DHT values. + pub fn storage_count(&self) -> usize { + self.storage.len() + } + + /// All stored DHT values (key, value bytes). + /// Used by applications to sync DHT-replicated data + /// to their own persistence layer. + pub fn dht_values(&self) -> Vec<(Vec<u8>, Vec<u8>)> { + self.storage + .all_values() + .into_iter() + .map(|v| (v.key, v.value)) + .collect() + } + + /// Number of currently banned peers. + pub fn ban_count(&self) -> usize { + self.ban_list.ban_count() + } + + /// Number of pending store operations awaiting ack. + pub fn pending_stores(&self) -> usize { + self.store_tracker.pending_count() + } + + /// Store tracker statistics: (acks, failures). + pub fn store_stats(&self) -> (u64, u64) { + (self.store_tracker.acks, self.store_tracker.failures) + } + + /// Print node state (debug). 
    pub fn print_state(&self) {
        println!("MyID = {}", self.id);
        println!();
        println!("Node State:");
        println!(" {:?}", self.nat_state());
        println!();
        println!("Routing Table: {} nodes", self.dht_table.size());
        self.dht_table.print_table();
        println!();
        println!("DTUN: {} registrations", self.dtun.registration_count());
        println!("Peers: {} known", self.peers.len());
        println!("Storage: {} values", self.storage.len());
        println!("Bans: {} active", self.ban_list.ban_count());
        println!(
            "Stores: {} pending, {} acked, {} failed",
            self.store_tracker.pending_count(),
            self.store_tracker.acks,
            self.store_tracker.failures,
        );
        println!(
            "RDP: {} connections, {} listeners",
            self.rdp.connection_count(),
            self.rdp.listener_count()
        );
    }
}

// Compact one-line summary for logs and diagnostics.
impl fmt::Display for Node {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Node({}, {:?}, {} peers, {} stored)",
            self.id,
            self.nat_state(),
            self.dht_table.size(),
            self.storage.len()
        )
    }
}

#[cfg(test)]
mod tests {
    // Unit tests exercising the public Node API on ephemeral
    // loopback sockets (port 0).
    use super::*;

    #[test]
    fn bind_creates_node() {
        let node = Node::bind(0).unwrap();
        assert!(!node.id().is_zero());
        assert_eq!(node.routing_table_size(), 0);
        assert_eq!(node.nat_state(), NatState::Unknown);
    }

    #[test]
    fn set_id_from_data() {
        let mut node1 = Node::bind(0).unwrap();
        let mut node2 = Node::bind(0).unwrap();

        let old_id = *node1.id();
        node1.set_id(b"my-application-id");
        assert_ne!(*node1.id(), old_id);

        // Deterministic: same seed → same ID
        node2.set_id(b"my-application-id");
        assert_eq!(*node1.id(), *node2.id());

        // Public key matches NodeId
        assert_eq!(node1.public_key(), node1.id().as_bytes());
    }

    #[test]
    fn id_hex_format() {
        let node = Node::bind(0).unwrap();
        let hex = node.id_hex();
        // 32 bytes → 64 hex characters.
        assert_eq!(hex.len(), 64);
        assert!(hex.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn nat_state_set_get() {
        let mut node = Node::bind(0).unwrap();
        assert_eq!(node.nat_state(), NatState::Unknown);

        node.set_nat_state(NatState::Global);
        assert_eq!(node.nat_state(), NatState::Global);

        node.set_nat_state(NatState::ConeNat);
        assert_eq!(node.nat_state(), NatState::ConeNat);

        node.set_nat_state(NatState::SymmetricNat);
        assert_eq!(node.nat_state(), NatState::SymmetricNat);
    }

    #[test]
    fn put_get_local() {
        let mut node = Node::bind(0).unwrap();
        node.put(b"hello", b"world", 300, false);

        let vals = node.get(b"hello");
        assert_eq!(vals.len(), 1);
        assert_eq!(vals[0], b"world");
    }

    #[test]
    fn put_unique() {
        let mut node = Node::bind(0).unwrap();
        node.put(b"key", b"val1", 300, true);
        node.put(b"key", b"val2", 300, true);

        let vals = node.get(b"key");

        // Unique from same source → replaced
        assert_eq!(vals.len(), 1);
        assert_eq!(vals[0], b"val2");
    }

    #[test]
    fn get_nonexistent() {
        let mut node = Node::bind(0).unwrap();
        let vals = node.get(b"nope");
        assert!(vals.is_empty());
    }

    #[test]
    fn rdp_listen_and_close() {
        let mut node = Node::bind(0).unwrap();
        let _desc = node.rdp_listen(5000).unwrap();

        // Listener desc is not a connection, so
        // rdp_state won't find it. Just verify close
        // doesn't panic.
        node.rdp_close(_desc);
    }

    #[test]
    fn rdp_connect_creates_syn() {
        let mut node = Node::bind(0).unwrap();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = node.rdp_connect(0, &dst, 5000).unwrap();
        assert_eq!(node.rdp_state(desc).unwrap(), RdpState::SynSent);
    }

    #[test]
    fn rdp_status() {
        let mut node = Node::bind(0).unwrap();
        let dst = NodeId::from_bytes([0x01; 32]);
        node.rdp_connect(0, &dst, 5000).unwrap();
        let status = node.rdp_status();
        assert_eq!(status.len(), 1);
    }

    #[test]
    fn rdp_max_retrans() {
        let mut node = Node::bind(0).unwrap();
        node.rdp_set_max_retrans(60);
        assert_eq!(node.rdp_max_retrans(), 60);
    }

    #[test]
    fn display() {
        let node = Node::bind(0).unwrap();
        let s = format!("{node}");
        assert!(s.starts_with("Node("));
        assert!(s.contains("Unknown"));
    }

    #[test]
    fn dgram_callback() {
        let mut node = Node::bind(0).unwrap();
        assert!(node.dgram_callback.is_none());
        node.set_dgram_callback(|_data, _from| {});
        assert!(node.dgram_callback.is_some());
        node.unset_dgram_callback();
        assert!(node.dgram_callback.is_none());
    }

    #[test]
    fn join_with_invalid_host() {
        let mut node = Node::bind(0).unwrap();
        let result = node.join("this-host-does-not-exist.invalid", 3000);
        assert!(result.is_err());
    }

    #[test]
    fn poll_once() {
        let mut node = Node::bind(0).unwrap();

        // Should not block forever (default 1s timeout,
        // but returns quickly with no events)
        node.poll().unwrap();
    }
}
diff --git a/src/peers.rs b/src/peers.rs
new file mode 100644
index 0000000..e575323
--- /dev/null
+++ b/src/peers.rs
@@ -0,0 +1,337 @@
//! Peer node database.
//!
//! Bidirectional peer map with TTL-based expiry and
//! timeout tracking.

use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};

use crate::id::NodeId;

/// TTL for peer entries (5 min).
+pub const PEERS_MAP_TTL: Duration = Duration::from_secs(300); + +/// TTL for timeout entries (30s). +pub const PEERS_TIMEOUT_TTL: Duration = Duration::from_secs(30); + +/// Cleanup timer interval (30s). +pub const PEERS_TIMER_INTERVAL: Duration = Duration::from_secs(30); + +/// NAT state of a peer. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NatState { + Unknown, + Global, + Nat, + ConeNat, + SymmetricNat, +} + +/// Information about a known peer. +#[derive(Debug, Clone)] +pub struct PeerInfo { + pub id: NodeId, + pub addr: SocketAddr, + pub domain: u16, + pub nat_state: NatState, + pub last_seen: Instant, + pub session: u32, + + /// Ed25519 public key (if known). Since NodeId = + /// pubkey, this is always available when we know + /// the peer's ID. + pub public_key: Option<[u8; 32]>, +} + +impl PeerInfo { + pub fn new(id: NodeId, addr: SocketAddr) -> Self { + // NodeId IS the public key (32 bytes) + let public_key = Some(*id.as_bytes()); + Self { + id, + addr, + domain: if addr.is_ipv4() { 1 } else { 2 }, + nat_state: NatState::Unknown, + last_seen: Instant::now(), + session: 0, + public_key, + } + } +} + +/// Database of known peers with TTL-based expiry. +/// +/// Provides forward lookup (id -> info) and reverse +/// lookup (addr -> ids). +type PeerCallback = Box<dyn Fn(&PeerInfo)>; + +/// Maximum number of tracked peers (prevents OOM). +const MAX_PEERS: usize = 10_000; + +pub struct PeerStore { + by_id: HashMap<NodeId, PeerInfo>, + by_addr: HashMap<SocketAddr, Vec<NodeId>>, + timeouts: HashMap<NodeId, Instant>, + on_add: Option<PeerCallback>, +} + +impl PeerStore { + pub fn new() -> Self { + Self { + by_id: HashMap::new(), + by_addr: HashMap::new(), + timeouts: HashMap::new(), + on_add: None, + } + } + + /// Get peer info by ID. + pub fn get(&self, id: &NodeId) -> Option<&PeerInfo> { + self.by_id.get(id) + } + + /// Get all peer IDs associated with an address. 
+ pub fn ids_for_addr(&self, addr: &SocketAddr) -> Vec<NodeId> { + self.by_addr.get(addr).cloned().unwrap_or_default() + } + + /// Add a peer, checking for duplicates. + /// + /// If the peer already exists, updates `last_seen`. + /// Returns `true` if newly added. + pub fn add(&mut self, peer: PeerInfo) -> bool { + let id = peer.id; + let addr = peer.addr; + + // Limit check (updates are always allowed) + if self.by_id.len() >= MAX_PEERS && !self.by_id.contains_key(&id) { + return false; + } + + if let Some(existing) = self.by_id.get_mut(&id) { + existing.last_seen = Instant::now(); + existing.addr = addr; + return false; + } + + self.by_addr.entry(addr).or_default().push(id); + if let Some(ref cb) = self.on_add { + cb(&peer); + } + self.by_id.insert(id, peer); + true + } + + /// Add a peer with a session ID (for DTUN register). + /// + /// Returns `true` if the session matches or is new. + pub fn add_with_session(&mut self, peer: PeerInfo, session: u32) -> bool { + if let Some(existing) = self.by_id.get(&peer.id) { + if existing.session != session { + return false; + } + } + let mut peer = peer; + peer.session = session; + self.add(peer); + true + } + + /// Add a peer, overwriting any existing entry. + pub fn add_force(&mut self, peer: PeerInfo) { + self.remove(&peer.id); + self.add(peer); + } + + /// Remove a peer by ID. + pub fn remove(&mut self, id: &NodeId) -> Option<PeerInfo> { + if let Some(peer) = self.by_id.remove(id) { + if let Some(ids) = self.by_addr.get_mut(&peer.addr) { + ids.retain(|i| i != id); + if ids.is_empty() { + self.by_addr.remove(&peer.addr); + } + } + self.timeouts.remove(id); + Some(peer) + } else { + None + } + } + + /// Remove all peers associated with an address. + pub fn remove_addr(&mut self, addr: &SocketAddr) { + if let Some(ids) = self.by_addr.remove(addr) { + for id in ids { + self.by_id.remove(&id); + self.timeouts.remove(&id); + } + } + } + + /// Mark a peer as having timed out. 
+ pub fn mark_timeout(&mut self, id: &NodeId) { + self.timeouts.insert(*id, Instant::now()); + } + + /// Check if a peer is in timeout state. + pub fn is_timeout(&self, id: &NodeId) -> bool { + if let Some(t) = self.timeouts.get(id) { + t.elapsed() < PEERS_TIMEOUT_TTL + } else { + false + } + } + + /// Remove expired peers (older than MAP_TTL) and + /// stale timeout entries. + pub fn refresh(&mut self) { + let now = Instant::now(); + + // Remove expired peers + let expired: Vec<NodeId> = self + .by_id + .iter() + .filter(|(_, p)| now.duration_since(p.last_seen) >= PEERS_MAP_TTL) + .map(|(id, _)| *id) + .collect(); + + for id in expired { + self.remove(&id); + } + + // Remove stale timeout entries + self.timeouts + .retain(|_, t| now.duration_since(*t) < PEERS_TIMEOUT_TTL); + } + + /// Set a callback for when a peer is added. + pub fn on_add(&mut self, f: impl Fn(&PeerInfo) + 'static) { + self.on_add = Some(Box::new(f)); + } + + /// Number of known peers. + pub fn len(&self) -> usize { + self.by_id.len() + } + + /// Check if the store is empty. + pub fn is_empty(&self) -> bool { + self.by_id.is_empty() + } + + /// Iterate over all peers. + pub fn iter(&self) -> impl Iterator<Item = &PeerInfo> { + self.by_id.values() + } + + /// Get all peer IDs. 
    pub fn ids(&self) -> Vec<NodeId> {
        self.by_id.keys().copied().collect()
    }
}

impl Default for PeerStore {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Helper: peer with a repeated-byte ID on 127.0.0.1:port.
    fn peer(byte: u8, port: u16) -> PeerInfo {
        PeerInfo::new(
            NodeId::from_bytes([byte; 32]),
            SocketAddr::from(([127, 0, 0, 1], port)),
        )
    }

    #[test]
    fn add_and_get() {
        let mut store = PeerStore::new();
        let p = peer(1, 3000);
        assert!(store.add(p.clone()));
        assert_eq!(store.len(), 1);
        assert_eq!(store.get(&p.id).unwrap().addr, p.addr);
    }

    #[test]
    fn add_duplicate_updates() {
        let mut store = PeerStore::new();
        let p = peer(1, 3000);
        assert!(store.add(p.clone()));
        assert!(!store.add(p)); // duplicate
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn remove_by_id() {
        let mut store = PeerStore::new();
        let p = peer(1, 3000);
        store.add(p.clone());
        store.remove(&p.id);
        assert!(store.is_empty());
    }

    #[test]
    fn reverse_lookup() {
        let mut store = PeerStore::new();
        let addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
        let p1 = PeerInfo::new(NodeId::from_bytes([1; 32]), addr);
        let p2 = PeerInfo::new(NodeId::from_bytes([2; 32]), addr);
        store.add(p1.clone());
        store.add(p2.clone());

        // Two distinct IDs may legitimately share one address.
        let ids = store.ids_for_addr(&addr);
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&p1.id));
        assert!(ids.contains(&p2.id));
    }

    #[test]
    fn remove_addr_removes_all() {
        let mut store = PeerStore::new();
        let addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
        store.add(PeerInfo::new(NodeId::from_bytes([1; 32]), addr));
        store.add(PeerInfo::new(NodeId::from_bytes([2; 32]), addr));
        store.remove_addr(&addr);
        assert!(store.is_empty());
    }

    #[test]
    fn timeout_tracking() {
        let mut store = PeerStore::new();
        let id = NodeId::from_bytes([1; 32]);
        assert!(!store.is_timeout(&id));
        store.mark_timeout(&id);
        assert!(store.is_timeout(&id));
    }

    #[test]
    fn add_with_session() {
        let mut store = PeerStore::new();
        let p = peer(1, 3000);
        assert!(store.add_with_session(p.clone(), 42));

        // Same session: ok
        assert!(store.add_with_session(peer(1, 3001), 42));

        // Different session: rejected
        assert!(!store.add_with_session(peer(1, 3002), 99));
    }

    #[test]
    fn add_force_overwrites() {
        let mut store = PeerStore::new();
        store.add(peer(1, 3000));
        store.add_force(peer(1, 4000));
        assert_eq!(store.len(), 1);
        assert_eq!(
            store.get(&NodeId::from_bytes([1; 32])).unwrap().addr.port(),
            4000
        );
    }
}
diff --git a/src/persist.rs b/src/persist.rs
new file mode 100644
index 0000000..8e733b0
--- /dev/null
+++ b/src/persist.rs
@@ -0,0 +1,84 @@
//! Persistence traits for data and routing table.
//!
//! The library defines the traits; applications
//! implement backends (SQLite, file, etc).

use crate::error::Error;
use crate::id::NodeId;
use std::net::SocketAddr;

/// Stored value record for persistence.
#[derive(Debug, Clone)]
pub struct StoredRecord {
    pub key: Vec<u8>,
    pub value: Vec<u8>,
    pub target_id: NodeId,
    pub source: NodeId,
    // TTL unit (seconds vs. refresh rounds) is defined by the
    // DHT store logic — not visible here; see src/dht.rs.
    pub ttl: u16,
    pub is_unique: bool,
}

/// Contact record for routing table persistence.
#[derive(Debug, Clone)]
pub struct ContactRecord {
    pub id: NodeId,
    pub addr: SocketAddr,
}

/// Trait for persisting DHT stored values.
///
/// Implement this to survive restarts. The library
/// calls `save` periodically and `load` on startup.
pub trait DataPersistence {
    /// Save all stored values.
    fn save(&self, records: &[StoredRecord]) -> Result<(), Error>;

    /// Load previously saved values.
    fn load(&self) -> Result<Vec<StoredRecord>, Error>;
}

/// Trait for persisting the routing table.
///
/// Implement this for fast re-bootstrap after restart.
pub trait RoutingPersistence {
    /// Save known contacts.
    fn save_contacts(&self, contacts: &[ContactRecord]) -> Result<(), Error>;

    /// Load previously saved contacts.
    fn load_contacts(&self) -> Result<Vec<ContactRecord>, Error>;
}

/// No-op persistence (default — no persistence).
pub struct NoPersistence;

impl DataPersistence for NoPersistence {
    // Discards everything; nothing survives a restart.
    fn save(&self, _records: &[StoredRecord]) -> Result<(), Error> {
        Ok(())
    }
    fn load(&self) -> Result<Vec<StoredRecord>, Error> {
        Ok(Vec::new())
    }
}

impl RoutingPersistence for NoPersistence {
    fn save_contacts(&self, _contacts: &[ContactRecord]) -> Result<(), Error> {
        Ok(())
    }
    fn load_contacts(&self) -> Result<Vec<ContactRecord>, Error> {
        Ok(Vec::new())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_persistence_save_load() {
        let p = NoPersistence;
        assert!(p.save(&[]).is_ok());
        assert!(p.load().unwrap().is_empty());
        assert!(p.save_contacts(&[]).is_ok());
        assert!(p.load_contacts().unwrap().is_empty());
    }
}
diff --git a/src/proxy.rs b/src/proxy.rs
new file mode 100644
index 0000000..aaa3827
--- /dev/null
+++ b/src/proxy.rs
@@ -0,0 +1,370 @@
//! Proxy relay for symmetric NAT nodes.
//!
//! When a node is
//! behind a symmetric NAT (where hole-punching fails),
//! it registers with a global node that acts as a relay:
//!
//! - **Store**: the NAT'd node sends store requests to
//!   its proxy, which forwards them to the DHT.
//! - **Get**: the NAT'd node sends get requests to its
//!   proxy, which performs the lookup and returns results.
//! - **Dgram**: datagrams to/from NAT'd nodes are
//!   forwarded through the proxy.
//! - **RDP**: reliable transport is also tunnelled
//!   through the proxy (type_proxy_rdp).

use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};

use crate::id::NodeId;
use crate::peers::PeerInfo;

// ── Constants ──────────────────────────────────────

/// Registration timeout.
pub const PROXY_REGISTER_TIMEOUT: Duration = Duration::from_secs(2);

/// Registration TTL.
+pub const PROXY_REGISTER_TTL: Duration = Duration::from_secs(300); + +/// Get request timeout. +pub const PROXY_GET_TIMEOUT: Duration = Duration::from_secs(10); + +/// Maintenance timer interval. +pub const PROXY_TIMER_INTERVAL: Duration = Duration::from_secs(30); + +/// RDP timeout for proxy connections. +pub const PROXY_RDP_TIMEOUT: Duration = Duration::from_secs(30); + +/// RDP port for proxy store. +pub const PROXY_STORE_PORT: u16 = 200; + +/// RDP port for proxy get. +pub const PROXY_GET_PORT: u16 = 201; + +/// RDP port for proxy get reply. +pub const PROXY_GET_REPLY_PORT: u16 = 202; + +// ── Client state ──────────────────────────────────── + +/// A node registered as a client of this proxy. +#[derive(Debug, Clone)] +pub struct ProxyClient { + /// Client's address (behind NAT). + pub addr: SocketAddr, + + /// DTUN session for validation. + pub session: u32, + + /// When the client last registered/refreshed. + pub registered_at: Instant, +} + +impl ProxyClient { + pub fn is_expired(&self) -> bool { + self.registered_at.elapsed() >= PROXY_REGISTER_TTL + } +} + +// ── Pending get state ─────────────────────────────── + +/// State for a pending proxy get request. +#[derive(Debug)] +pub struct PendingGet { + /// Key being looked up. + pub key: Vec<u8>, + + /// Nonce for correlation. + pub nonce: u32, + + /// When the request was sent. + pub sent_at: Instant, + + /// Whether we've received a result. + pub completed: bool, + + /// Collected values. + pub values: Vec<Vec<u8>>, +} + +// ── Proxy ─────────────────────────────────────────── + +/// Proxy relay for symmetric NAT traversal. +/// +/// Acts in two roles: +/// - **Client**: when we're behind symmetric NAT, we +/// register with a proxy server and route operations +/// through it. +/// - **Server**: when we have a public IP, we accept +/// registrations from NAT'd nodes and relay their +/// traffic. +pub struct Proxy { + /// Our proxy server (when we're a client). 
+ server: Option<PeerInfo>, + + /// Whether we've successfully registered with our + /// proxy server. + is_registered: bool, + + /// Whether registration is in progress. + is_registering: bool, + + /// Nonce for registration correlation. + register_nonce: u32, + + /// Clients registered with us. Capped at 500. + clients: HashMap<NodeId, ProxyClient>, + + /// Pending get requests. Capped at 100. + pending_gets: HashMap<u32, PendingGet>, +} + +impl Proxy { + pub fn new(_id: NodeId) -> Self { + Self { + server: None, + is_registered: false, + is_registering: false, + register_nonce: 0, + clients: HashMap::new(), + pending_gets: HashMap::new(), + } + } + + // ── Client role ───────────────────────────────── + + /// Set the proxy server to use. + pub fn set_server(&mut self, server: PeerInfo) { + self.server = Some(server); + self.is_registered = false; + } + + /// Get the current proxy server. + pub fn server(&self) -> Option<&PeerInfo> { + self.server.as_ref() + } + + /// Whether we're registered with a proxy server. + pub fn is_registered(&self) -> bool { + self.is_registered + } + + /// Start proxy registration. Returns the nonce to + /// include in the register message. + pub fn start_register(&mut self, nonce: u32) -> Option<u32> { + self.server.as_ref()?; + self.is_registering = true; + self.register_nonce = nonce; + Some(nonce) + } + + /// Handle registration reply. + pub fn recv_register_reply(&mut self, nonce: u32) -> bool { + if nonce != self.register_nonce { + return false; + } + self.is_registering = false; + self.is_registered = true; + log::info!("Registered with proxy server"); + true + } + + // ── Server role ───────────────────────────────── + + /// Register a client node (we act as their proxy). 
+ pub fn register_client( + &mut self, + id: NodeId, + addr: SocketAddr, + session: u32, + ) -> bool { + const MAX_CLIENTS: usize = 500; + if self.clients.len() >= MAX_CLIENTS && !self.clients.contains_key(&id) + { + return false; + } + if let Some(existing) = self.clients.get(&id) { + if existing.session != session && !existing.is_expired() { + return false; + } + } + self.clients.insert( + id, + ProxyClient { + addr, + session, + registered_at: Instant::now(), + }, + ); + log::debug!("Proxy: registered client {id:?}"); + true + } + + /// Check if a node is registered as our client. + pub fn is_client_registered(&self, id: &NodeId) -> bool { + self.clients + .get(id) + .map(|c| !c.is_expired()) + .unwrap_or(false) + } + + /// Get a registered client's info. + pub fn get_client(&self, id: &NodeId) -> Option<&ProxyClient> { + self.clients.get(id).filter(|c| !c.is_expired()) + } + + /// Number of active proxy clients. + pub fn client_count(&self) -> usize { + self.clients.values().filter(|c| !c.is_expired()).count() + } + + // ── Pending get management ────────────────────── + + /// Start a proxied get request. + pub fn start_get(&mut self, nonce: u32, key: Vec<u8>) { + const MAX_GETS: usize = 100; + if self.pending_gets.len() >= MAX_GETS { + log::debug!("Proxy: too many pending gets"); + return; + } + self.pending_gets.insert( + nonce, + PendingGet { + key, + nonce, + sent_at: Instant::now(), + completed: false, + values: Vec::new(), + }, + ); + } + + /// Add a value to a pending get. + pub fn add_get_value(&mut self, nonce: u32, value: Vec<u8>) -> bool { + if let Some(pg) = self.pending_gets.get_mut(&nonce) { + pg.values.push(value); + true + } else { + false + } + } + + /// Complete a pending get request. Returns the + /// collected values. 
    pub fn complete_get(&mut self, nonce: u32) -> Option<Vec<Vec<u8>>> {
        self.pending_gets.remove(&nonce).map(|pg| pg.values)
    }

    // ── Maintenance ─────────────────────────────────

    /// Remove expired clients and timed-out gets.
    pub fn refresh(&mut self) {
        self.clients.retain(|_, c| !c.is_expired());

        self.pending_gets
            .retain(|_, pg| pg.sent_at.elapsed() < PROXY_GET_TIMEOUT);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    fn peer(byte: u8, port: u16) -> PeerInfo {
        PeerInfo::new(NodeId::from_bytes([byte; 32]), addr(port))
    }

    // ── Client tests ────────────────────────────────

    #[test]
    fn client_register_flow() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));

        assert!(!proxy.is_registered());

        proxy.set_server(peer(0x02, 5000));
        let nonce = proxy.start_register(42).unwrap();
        assert_eq!(nonce, 42);

        // Wrong nonce
        assert!(!proxy.recv_register_reply(99));
        assert!(!proxy.is_registered());

        // Correct nonce
        assert!(proxy.recv_register_reply(42));
        assert!(proxy.is_registered());
    }

    #[test]
    fn client_no_server() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));
        assert!(proxy.start_register(1).is_none());
    }

    // ── Server tests ────────────────────────────────

    #[test]
    fn server_register_client() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));
        let client_id = NodeId::from_bytes([0x02; 32]);

        assert!(proxy.register_client(client_id, addr(3000), 1));
        assert!(proxy.is_client_registered(&client_id));
        assert_eq!(proxy.client_count(), 1);
    }

    #[test]
    fn server_rejects_different_session() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));
        let client_id = NodeId::from_bytes([0x02; 32]);

        proxy.register_client(client_id, addr(3000), 1);
        assert!(!proxy.register_client(client_id, addr(3001), 2));
    }

    #[test]
    fn server_expire_clients() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));
        let client_id = NodeId::from_bytes([0x02; 32]);

        // Insert a registration back-dated past the TTL.
        proxy.clients.insert(
            client_id,
            ProxyClient {
                addr: addr(3000),
                session: 1,
                registered_at: Instant::now() - Duration::from_secs(600),
            },
        );

        proxy.refresh();
        assert_eq!(proxy.client_count(), 0);
    }

    // ── Pending get tests ───────────────────────────

    #[test]
    fn pending_get_flow() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));

        proxy.start_get(42, b"mykey".to_vec());
        assert!(proxy.add_get_value(42, b"val1".to_vec()));
        assert!(proxy.add_get_value(42, b"val2".to_vec()));

        let vals = proxy.complete_get(42).unwrap();
        assert_eq!(vals.len(), 2);
        assert_eq!(vals[0], b"val1");
        assert_eq!(vals[1], b"val2");
    }

    #[test]
    fn pending_get_unknown_nonce() {
        let mut proxy = Proxy::new(NodeId::from_bytes([0x01; 32]));
        assert!(!proxy.add_get_value(999, b"v".to_vec()));
        assert!(proxy.complete_get(999).is_none());
    }
}
diff --git a/src/ratelimit.rs b/src/ratelimit.rs
new file mode 100644
index 0000000..691b30f
--- /dev/null
+++ b/src/ratelimit.rs
@@ -0,0 +1,136 @@
//! Per-IP rate limiting with token bucket algorithm.
//!
//! Limits the number of inbound messages processed per
//! source IP address per second. Prevents a single
//! source from overwhelming the node.

use std::collections::HashMap;
use std::net::IpAddr;
use std::time::Instant;

/// Default rate: 50 messages per second per IP.
pub const DEFAULT_RATE: f64 = 50.0;

/// Default burst: 100 messages.
pub const DEFAULT_BURST: u32 = 100;

/// Stale bucket cleanup threshold.
const STALE_SECS: u64 = 60;

/// One token bucket per source IP.
struct Bucket {
    /// Remaining tokens (fractional; refilled lazily).
    tokens: f64,
    /// Last time tokens were refilled.
    last_refill: Instant,
}

/// Per-IP token bucket rate limiter.
pub struct RateLimiter {
    buckets: HashMap<IpAddr, Bucket>,
    /// Tokens added per second per IP.
    rate: f64,
    /// Maximum tokens (burst capacity).
    burst: u32,
    /// Last time stale buckets were swept.
    last_cleanup: Instant,
}

impl RateLimiter {
    /// Create a new rate limiter.
+ /// + /// - `rate`: tokens added per second per IP. + /// - `burst`: maximum tokens (burst capacity). + pub fn new(rate: f64, burst: u32) -> Self { + Self { + buckets: HashMap::new(), + rate, + burst, + last_cleanup: Instant::now(), + } + } + + /// Check if a message from `ip` should be allowed. + /// + /// Returns `true` if allowed (token consumed), + /// `false` if rate-limited (drop the message). + pub fn allow(&mut self, ip: IpAddr) -> bool { + let now = Instant::now(); + let burst = self.burst as f64; + let rate = self.rate; + + let bucket = self.buckets.entry(ip).or_insert(Bucket { + tokens: burst, + last_refill: now, + }); + + // Refill tokens based on elapsed time + let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); + bucket.tokens = (bucket.tokens + elapsed * rate).min(burst); + bucket.last_refill = now; + + if bucket.tokens >= 1.0 { + bucket.tokens -= 1.0; + true + } else { + false + } + } + + /// Remove stale buckets (no activity for 60s). + pub fn cleanup(&mut self) { + if self.last_cleanup.elapsed().as_secs() < STALE_SECS { + return; + } + self.last_cleanup = Instant::now(); + let cutoff = Instant::now(); + self.buckets.retain(|_, b| { + cutoff.duration_since(b.last_refill).as_secs() < STALE_SECS + }); + } + + /// Number of tracked IPs. 
    pub fn tracked_count(&self) -> usize {
        self.buckets.len()
    }
}

impl Default for RateLimiter {
    fn default() -> Self {
        Self::new(DEFAULT_RATE, DEFAULT_BURST)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allow_within_burst() {
        let mut rl = RateLimiter::new(10.0, 5);
        let ip: IpAddr = "1.2.3.4".parse().unwrap();

        // First 5 should be allowed (burst)
        for _ in 0..5 {
            assert!(rl.allow(ip));
        }

        // 6th should be denied
        assert!(!rl.allow(ip));
    }

    #[test]
    fn different_ips_independent() {
        let mut rl = RateLimiter::new(1.0, 1);
        let ip1: IpAddr = "1.2.3.4".parse().unwrap();
        let ip2: IpAddr = "5.6.7.8".parse().unwrap();
        assert!(rl.allow(ip1));
        assert!(rl.allow(ip2));

        // Both exhausted
        assert!(!rl.allow(ip1));
        assert!(!rl.allow(ip2));
    }

    #[test]
    fn tracked_count() {
        let mut rl = RateLimiter::default();
        let ip: IpAddr = "1.2.3.4".parse().unwrap();
        rl.allow(ip);
        assert_eq!(rl.tracked_count(), 1);
    }
}
diff --git a/src/rdp.rs b/src/rdp.rs
new file mode 100644
index 0000000..96de51d
--- /dev/null
+++ b/src/rdp.rs
@@ -0,0 +1,1343 @@
//! Reliable Datagram Protocol (RDP).
//!
//! Provides TCP-like
//! reliable, ordered delivery over UDP with:
//!
//! - 7-state connection machine
//! - 3-way handshake (SYN / SYN-ACK / ACK)
//! - Sliding send and receive windows
//! - Cumulative ACK + Extended ACK (EACK/SACK)
//! - Delayed ACK (300ms)
//! - Retransmission (300ms timer)
//! - FIN-based graceful close
//! - RST for abrupt termination
//! - Adaptive RTO (Jacobson/Karels) and AIMD
//!   congestion-window tracking (`cwnd`/`ssthresh`).
//!   NOTE(review): this header previously said "no
//!   congestion control", but the connection state does
//!   maintain cwnd/ssthresh — confirm the send path
//!   actually enforces the window.
+ +use std::collections::{HashMap, VecDeque}; +use std::time::{Duration, Instant}; + +use crate::id::NodeId; + +// ── Constants ──────────────────────────────────────── + +pub const RDP_FLAG_SYN: u8 = 0x80; +pub const RDP_FLAG_ACK: u8 = 0x40; +pub const RDP_FLAG_EAK: u8 = 0x20; +pub const RDP_FLAG_RST: u8 = 0x10; +pub const RDP_FLAG_NUL: u8 = 0x08; +pub const RDP_FLAG_FIN: u8 = 0x04; + +pub const RDP_RBUF_MAX_DEFAULT: u32 = 884; +pub const RDP_RCV_MAX_DEFAULT: u32 = 1024; +pub const RDP_WELL_KNOWN_PORT_MAX: u16 = 1024; +pub const RDP_SBUF_LIMIT: u16 = 884; +pub const RDP_TIMER_INTERVAL: Duration = Duration::from_millis(300); +pub const RDP_ACK_INTERVAL: Duration = Duration::from_millis(300); +pub const RDP_DEFAULT_MAX_RETRANS: Duration = Duration::from_secs(30); + +/// Generate a random initial sequence number to prevent +/// sequence prediction attacks. +fn random_isn() -> u32 { + let mut buf = [0u8; 4]; + crate::sys::random_bytes(&mut buf); + u32::from_ne_bytes(buf) +} + +/// RDP packet header (20 bytes on the wire). +pub const RDP_HEADER_SIZE: usize = 20; + +// ── Connection state ──────────────────────────────── + +/// RDP connection state (7 states). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RdpState { + Closed, + Listen, + SynSent, + SynRcvd, + Open, + CloseWaitPassive, + CloseWaitActive, +} + +/// Events emitted to the application. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RdpEvent { + /// Server accepted a new connection. + Accepted, + + /// Client connected successfully. + Connected, + + /// Connection refused by peer. + Refused, + + /// Connection reset by peer. + Reset, + + /// Connection failed (timeout). + Failed, + + /// Data available to read. + Ready2Read, + + /// Pipe broken (peer vanished). + Broken, +} + +/// RDP connection address. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RdpAddr { + pub did: NodeId, + pub dport: u16, + pub sport: u16, +} + +/// Status of a single RDP connection. 
#[derive(Debug, Clone)]
pub struct RdpStatus {
    pub state: RdpState,
    pub did: NodeId,
    pub dport: u16,
    pub sport: u16,
}

// ── Segment ─────────────────────────────────────────

/// A segment in the send window.
#[derive(Debug, Clone)]
struct SendSegment {
    data: Vec<u8>,
    seqnum: u32,
    // None until the segment is first transmitted.
    sent_time: Option<Instant>,
    is_sent: bool,
    // Set by EACK; cumulatively-ACKed segments are popped instead.
    is_acked: bool,

    /// Retransmission timeout (doubles on each retry).
    rt_secs: u64,
}

/// A segment in the receive window.
#[derive(Debug, Clone)]
struct RecvSegment {
    data: Vec<u8>,
    seqnum: u32,
    // Distinguishes a filled slot from a placeholder.
    is_used: bool,
    // Whether we already advertised this segment via EACK.
    is_eacked: bool,
}

// ── Connection ──────────────────────────────────────

/// A single RDP connection.
struct RdpConnection {
    addr: RdpAddr,
    // Descriptor handed to the application (fd-like).
    desc: i32,
    state: RdpState,

    /// Server (passive) or client (active) side.
    is_passive: bool,

    /// Connection has been closed locally.
    is_closed: bool,

    /// Max segment size we can send (from peer's SYN).
    sbuf_max: u32,

    /// Max segment size we can receive (our buffer).
    rbuf_max: u32,

    // Send sequence variables.
    // NOTE: sequence numbers here count *segments*, not bytes
    // (enqueue_send advances snd_nxt by 1 per segment).
    snd_nxt: u32,
    snd_una: u32,
    snd_max: u32,
    snd_iss: u32,

    // Receive sequence variables
    rcv_cur: u32,

    /// Max segments we can buffer.
    rcv_max: u32,
    rcv_irs: u32,

    /// Last sequence number we ACK'd.
    rcv_ack: u32,

    // Windows
    send_window: VecDeque<SendSegment>,
    // Slot i holds the segment with seqnum rcv_cur + 1 + i.
    recv_window: Vec<Option<RecvSegment>>,
    // In-order data ready for the application to recv().
    read_queue: VecDeque<Vec<u8>>,

    // Timing
    last_ack_time: Instant,
    syn_time: Option<Instant>,
    close_time: Option<Instant>,

    /// SYN retry timeout (doubles on retry).
    syn_rt_secs: u64,

    // ── RTT estimation (Jacobson/Karels) ──────
    /// Smoothed RTT estimate (microseconds).
    srtt_us: u64,

    /// RTT variation (microseconds).
    rttvar_us: u64,

    /// Retransmission timeout (microseconds).
    rto_us: u64,

    // ── Congestion control (AIMD) ─────────────
    /// Congestion window (segments allowed in flight).
    cwnd: u32,

    /// Slow-start threshold.
+ ssthresh: u32, + + /// RST retry state. + rst_time: Option<Instant>, + rst_rt_secs: u64, + is_retry_rst: bool, + + // Out-of-order tracking for EACK + rcvd_seqno: Vec<u32>, +} + +impl RdpConnection { + fn new(desc: i32, addr: RdpAddr, is_passive: bool) -> Self { + Self { + addr, + desc, + state: RdpState::Closed, + is_passive, + is_closed: false, + sbuf_max: RDP_SBUF_LIMIT as u32, + rbuf_max: RDP_RBUF_MAX_DEFAULT, + snd_nxt: 0, + snd_una: 0, + snd_max: RDP_RCV_MAX_DEFAULT, + snd_iss: 0, + rcv_cur: 0, + rcv_max: RDP_RCV_MAX_DEFAULT, + rcv_irs: 0, + rcv_ack: 0, + send_window: VecDeque::new(), + recv_window: Vec::new(), + read_queue: VecDeque::new(), + last_ack_time: Instant::now(), + syn_time: None, + close_time: None, + syn_rt_secs: 1, + + // Jacobson/Karels: initial RTO = 1s + srtt_us: 0, + rttvar_us: 500_000, // 500ms initial variance + rto_us: 1_000_000, // 1s initial RTO + + // AIMD congestion control + cwnd: 1, // start with 1 segment + ssthresh: RDP_RCV_MAX_DEFAULT, + rst_time: None, + rst_rt_secs: 1, + is_retry_rst: false, + rcvd_seqno: Vec::new(), + } + } + + /// Enqueue data for sending. + fn enqueue_send(&mut self, data: &[u8]) -> bool { + if self.is_closed { + return false; + } + if data.len() > self.sbuf_max as usize { + return false; + } + if self.send_window.len() >= self.snd_max as usize { + return false; + } + let seg = SendSegment { + data: data.to_vec(), + seqnum: self.snd_nxt, + sent_time: None, + is_sent: false, + is_acked: false, + rt_secs: 1, + }; + self.snd_nxt = self.snd_nxt.wrapping_add(1); + self.send_window.push_back(seg); + true + } + + /// Process a cumulative ACK. 
    fn recv_ack(&mut self, acknum: u32) {
        // Pop every fully-acked segment off the front of the
        // send window. The loop stops at the first segment whose
        // seqnum equals `acknum`.
        // NOTE(review): segments strictly *before* acknum are
        // treated as acked while acknum itself is kept — confirm
        // against the ACK encoding in the (unseen) output path.
        while let Some(front) = self.send_window.front() {
            if front.seqnum == acknum {
                break;
            }

            // Sequence numbers before acknum are acked
            if is_before(front.seqnum, acknum) {
                let seg = self.send_window.pop_front().unwrap();
                self.snd_una = self.snd_una.wrapping_add(1);
                self.on_ack_received();

                // Measure RTT from first-sent (non-retransmitted)
                // segments only (Karn's algorithm: retransmitted
                // segments have rt_secs > 1 and are skipped).
                if let Some(sent) = seg.sent_time {
                    if seg.rt_secs <= 1 {
                        // Only use first transmission for RTT
                        let rtt_us = sent.elapsed().as_micros() as u64;
                        self.update_rtt(rtt_us);
                    }
                }
            } else {
                break;
            }
        }
    }

    /// Process an Extended ACK (EACK) for out-of-order
    /// segments.
    ///
    /// Marks the segment so `retransmit` skips it; it is
    /// still popped later by the cumulative ACK.
    fn recv_eack(&mut self, eack_seqnum: u32) {
        for seg in self.send_window.iter_mut() {
            if seg.seqnum == eack_seqnum {
                seg.is_acked = true;
                break;
            }
        }
    }

    /// Deliver in-order data from the receive window to
    /// the read queue.
    ///
    /// NOTE(review): slot 0 of `recv_window` corresponds to
    /// seqnum `rcv_cur + 1`; callers must keep that alignment
    /// whenever `rcv_cur` advances.
    fn deliver_to_read_queue(&mut self) {
        // Count contiguous in-order segments
        let mut count = 0;
        for slot in &self.recv_window {
            match slot {
                Some(seg) if seg.is_used => count += 1,
                _ => break,
            }
        }
        // Drain them all at once (O(n) instead of O(n²))
        if count > 0 {
            for seg in self.recv_window.drain(..count).flatten() {
                self.rcv_cur = seg.seqnum;
                self.read_queue.push_back(seg.data);
            }
        }
    }

    /// Maximum out-of-order gap before dropping.
    /// Prevents memory exhaustion from malicious
    /// high-seqnum packets.
    const MAX_OOO_GAP: usize = 256;

    /// Place a received segment into the receive window.
+ fn recv_data(&mut self, seqnum: u32, data: Vec<u8>) { + let expected = self.rcv_cur.wrapping_add(1); + + if seqnum == expected { + // In-order: deliver directly + self.read_queue.push_back(data); + self.rcv_cur = seqnum; + self.deliver_to_read_queue(); + // Clean up rcvd_seqno for delivered segments + self.rcvd_seqno.retain(|&s| is_before(seqnum, s)); + } else if is_before(expected, seqnum) { + // Out-of-order: check gap before allocating + let offset = seqnum.wrapping_sub(expected) as usize; + + // Reject if gap too large (DoS protection) + if offset > Self::MAX_OOO_GAP { + log::debug!("RDP: dropping packet with gap {offset}"); + return; + } + + // Check total recv window capacity + if self.recv_window.len() >= self.rcv_max as usize { + return; + } + + while self.recv_window.len() <= offset { + self.recv_window.push(None); + } + self.recv_window[offset] = Some(RecvSegment { + data, + seqnum, + is_used: true, + is_eacked: false, + }); + // Use bounded set (cap at MAX_OOO_GAP) + if self.rcvd_seqno.len() < Self::MAX_OOO_GAP + && !self.rcvd_seqno.contains(&seqnum) + { + self.rcvd_seqno.push(seqnum); + } + } + + // else: duplicate, ignore + } + + /// Whether a delayed ACK should be sent. + /// Update RTT estimate using Jacobson/Karels algorithm. + /// + /// Called when we receive an ACK for a segment + /// whose `sent_time` we know. 
+ fn update_rtt(&mut self, sample_us: u64) { + if self.srtt_us == 0 { + // First measurement + self.srtt_us = sample_us; + self.rttvar_us = sample_us / 2; + } else { + // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R| + // SRTT = (1 - alpha) * SRTT + alpha * R + // alpha = 1/8, beta = 1/4 (RFC 6298) + let diff = sample_us.abs_diff(self.srtt_us); + self.rttvar_us = (3 * self.rttvar_us + diff) / 4; + self.srtt_us = (7 * self.srtt_us + sample_us) / 8; + } + + // RTO = SRTT + max(G, 4 * RTTVAR) + // G (clock granularity) = 1ms = 1000us + let k_rttvar = 4 * self.rttvar_us; + self.rto_us = self.srtt_us + k_rttvar.max(1000); + + // Clamp: min 200ms, max 60s + self.rto_us = self.rto_us.clamp(200_000, 60_000_000); + } + + /// Handle successful ACK: update congestion window. + fn on_ack_received(&mut self) { + if self.cwnd < self.ssthresh { + // Slow start: increase by 1 per ACK + self.cwnd += 1; + } else { + // Congestion avoidance: increase by 1/cwnd + // (approx 1 segment per RTT) + self.cwnd += 1u32.max(1 / self.cwnd.max(1)); + } + } + + /// Handle packet loss: halve congestion window. + fn on_loss_detected(&mut self) { + self.ssthresh = (self.cwnd / 2).max(2); + self.cwnd = self.ssthresh; + } + + fn needs_ack(&self) -> bool { + self.rcv_cur != self.rcv_ack + && self.last_ack_time.elapsed() >= RDP_ACK_INTERVAL + } + + /// Check for segments needing retransmission. + /// + /// Returns `false` if a segment has exceeded + /// max_retrans (broken pipe). 
    fn retransmit(
        &mut self,
        max_retrans: Duration,
    ) -> (bool, Vec<SendSegment>) {
        let mut to_send = Vec::new();
        let now = Instant::now();
        // Adaptive RTO in whole seconds (floor, min 1s).
        let rto_secs = (self.rto_us / 1_000_000).max(1);

        for seg in self.send_window.iter_mut() {
            // The window is sent in order: the first unsent
            // segment ends the scan.
            if !seg.is_sent {
                break;
            }
            if seg.is_acked {
                continue;
            }

            // Check if we've exceeded max retransmission
            if seg.rt_secs > max_retrans.as_secs() {
                self.state = RdpState::Closed;
                return (false, Vec::new()); // broken pipe
            }

            if let Some(sent) = self.sent_time_of(seg) {
                // Use adaptive RTO for first retransmit,
                // then exponential backoff
                let timeout = if seg.rt_secs <= 1 {
                    rto_secs
                } else {
                    seg.rt_secs
                };
                let elapsed = now.duration_since(sent).as_secs();
                if elapsed > timeout {
                    seg.sent_time = Some(now);
                    seg.rt_secs = timeout * 2; // backoff
                    to_send.push(seg.clone());
                }
            }
        }

        // Loss detected → halve congestion window
        if !to_send.is_empty() {
            self.on_loss_detected();
        }

        (true, to_send)
    }
}

/// Check if sequence `a` comes before `b` (wrapping).
fn is_before(a: u32, b: u32) -> bool {
    let diff = b.wrapping_sub(a);
    diff > 0 && diff < 0x80000000
}

// ── Deferred action (invoke protection) ─────────────

/// Actions deferred during event processing to avoid
/// reentrance issues.
#[derive(Debug)]
pub enum RdpAction {
    /// Emit an event to the application.
    Event {
        desc: i32,
        addr: RdpAddr,
        event: RdpEvent,
    },

    /// Close a connection after processing.
    Close(i32),
}

// ── RDP manager ─────────────────────────────────────

/// Incoming RDP packet for `Rdp::input()`.
pub struct RdpInput<'a> {
    pub src: NodeId,
    pub sport: u16,
    pub dport: u16,
    pub flags: u8,
    pub seqnum: u32,
    pub acknum: u32,
    pub data: &'a [u8],
}

/// RDP protocol manager.
///
/// Manages multiple connections with descriptor-based
/// API (similar to file descriptors).
pub struct Rdp {
    /// Active connections by descriptor.
    connections: HashMap<i32, RdpConnection>,
    /// Listening ports -> listener descriptor.
    listeners: HashMap<u16, i32>,
    /// Connection triple -> descriptor.
    addr_to_desc: HashMap<RdpAddr, i32>,
    /// Next descriptor to hand out.
    next_desc: i32,
    /// Give-up threshold for retransmission.
    max_retrans: Duration,
}

impl Rdp {
    /// Create a manager with no connections or listeners.
    pub fn new() -> Self {
        Self {
            connections: HashMap::new(),
            listeners: HashMap::new(),
            addr_to_desc: HashMap::new(),
            next_desc: 1,
            max_retrans: RDP_DEFAULT_MAX_RETRANS,
        }
    }

    /// Create a listening socket on `port`.
    ///
    /// Returns a descriptor for the listener, or
    /// `PortInUse` if the port already has a listener.
    pub fn listen(&mut self, port: u16) -> Result<i32, RdpError> {
        if self.listeners.contains_key(&port) {
            return Err(RdpError::PortInUse(port));
        }
        let desc = self.alloc_desc();
        self.listeners.insert(port, desc);
        Ok(desc)
    }

    /// Initiate a connection to `dst:dport` from `sport`.
    ///
    /// Returns a descriptor for the connection, or
    /// `PortInUse` when the (sport, dst, dport) triple is
    /// already connected.
    pub fn connect(
        &mut self,
        sport: u16,
        dst: NodeId,
        dport: u16,
    ) -> Result<i32, RdpError> {
        let addr = RdpAddr {
            did: dst,
            dport,
            sport,
        };

        // BUGFIX: a second connect() with the same triple used to
        // silently overwrite the addr_to_desc mapping, stranding
        // the previous connection in `connections` forever.
        if self.addr_to_desc.contains_key(&addr) {
            return Err(RdpError::PortInUse(sport));
        }

        let desc = self.alloc_desc();
        let mut conn = RdpConnection::new(desc, addr.clone(), false);
        conn.state = RdpState::SynSent;
        // Random ISN defends against sequence prediction.
        conn.snd_iss = random_isn();
        conn.snd_nxt = conn.snd_iss.wrapping_add(1);
        conn.snd_una = conn.snd_iss;
        conn.syn_time = Some(Instant::now());

        self.addr_to_desc.insert(addr, desc);
        self.connections.insert(desc, conn);
        Ok(desc)
    }

    /// Close a connection or listener.
    ///
    /// NOTE(review): this drops the connection state
    /// immediately; whether a FIN is emitted elsewhere is
    /// not visible here — confirm against the output path.
    pub fn close(&mut self, desc: i32) {
        if let Some(mut conn) = self.connections.remove(&desc) {
            conn.is_closed = true;
            self.addr_to_desc.remove(&conn.addr);
            log::debug!("RDP: closed desc {desc}");
        }

        // Also check listeners
        self.listeners.retain(|_, d| *d != desc);
    }

    /// Enqueue data for sending on a connection.
+ pub fn send(&mut self, desc: i32, data: &[u8]) -> Result<usize, RdpError> { + let conn = self + .connections + .get_mut(&desc) + .ok_or(RdpError::BadDescriptor(desc))?; + + if conn.state != RdpState::Open { + return Err(RdpError::NotOpen(desc)); + } + + if !conn.enqueue_send(data) { + return Err(RdpError::SendBufferFull); + } + + Ok(data.len()) + } + + /// Read available data from a connection. + /// + /// Returns the number of bytes read, or 0 if no + /// data available. + pub fn recv( + &mut self, + desc: i32, + buf: &mut [u8], + ) -> Result<usize, RdpError> { + let conn = self + .connections + .get_mut(&desc) + .ok_or(RdpError::BadDescriptor(desc))?; + + if let Some(data) = conn.read_queue.pop_front() { + let len = data.len().min(buf.len()); + buf[..len].copy_from_slice(&data[..len]); + Ok(len) + } else { + Ok(0) + } + } + + /// Get the state of a descriptor. + pub fn get_state(&self, desc: i32) -> Result<RdpState, RdpError> { + self.connections + .get(&desc) + .map(|c| c.state) + .ok_or(RdpError::BadDescriptor(desc)) + } + + /// Get status of all connections. + pub fn get_status(&self) -> Vec<RdpStatus> { + self.connections + .values() + .map(|c| RdpStatus { + state: c.state, + did: c.addr.did, + dport: c.addr.dport, + sport: c.addr.sport, + }) + .collect() + } + + /// Set the maximum retransmission timeout. + pub fn set_max_retrans(&mut self, dur: Duration) { + self.max_retrans = dur; + } + + /// Get the maximum retransmission timeout. + pub fn max_retrans(&self) -> Duration { + self.max_retrans + } + + /// Process incoming RDP data from a peer. + /// + /// Returns deferred actions (events, closes) to + /// avoid reentrance during processing. 
    pub fn input(&mut self, pkt: &RdpInput<'_>) -> Vec<RdpAction> {
        let src = pkt.src;
        let sport = pkt.sport;
        let dport = pkt.dport;
        let flags = pkt.flags;
        let seqnum = pkt.seqnum;
        let acknum = pkt.acknum;
        let data = pkt.data;
        // Address tuple is stored from OUR perspective,
        // so the peer's ports are swapped.
        let addr = RdpAddr {
            did: src,
            dport: sport, // their sport is our dport
            sport: dport, // our dport is our sport
        };
        let mut actions = Vec::new();

        if let Some(&desc) = self.addr_to_desc.get(&addr) {
            // Existing connection
            if let Some(conn) = self.connections.get_mut(&desc) {
                Self::process_connected(
                    conn,
                    flags,
                    seqnum,
                    acknum,
                    data,
                    &mut actions,
                );
            }
        } else if flags & RDP_FLAG_SYN != 0 {
            // New inbound SYN → check listener.
            // NOTE(review): the SYN payload carries the
            // peer's receive-buffer parameters (see
            // `pending_output`), but it is not parsed
            // here — confirm whether the defaults are
            // intentionally used instead.
            if let Some(&_listen_desc) = self.listeners.get(&dport) {
                // Passive open: record the peer's ISN and
                // pick our own; SYN+ACK goes out via
                // `pending_output` in state SynRcvd.
                let desc = self.alloc_desc();
                let mut conn = RdpConnection::new(desc, addr.clone(), true);
                conn.state = RdpState::SynRcvd;
                conn.rcv_irs = seqnum;
                conn.rcv_cur = seqnum;
                conn.snd_iss = random_isn();
                conn.snd_nxt = conn.snd_iss.wrapping_add(1);
                conn.snd_una = conn.snd_iss;

                self.addr_to_desc.insert(addr.clone(), desc);
                self.connections.insert(desc, conn);

                actions.push(RdpAction::Event {
                    desc,
                    addr,
                    event: RdpEvent::Accepted,
                });
            }

            // else: no listener → RST (ignored for now)
        }

        actions
    }

    /// Process a packet on an existing connection.
    /// State-machine step for a packet arriving on an
    /// established descriptor. Events for the
    /// application are appended to `actions` rather than
    /// delivered inline (invoke protection).
    fn process_connected(
        conn: &mut RdpConnection,
        flags: u8,
        seqnum: u32,
        acknum: u32,
        data: &[u8],
        actions: &mut Vec<RdpAction>,
    ) {
        match conn.state {
            RdpState::SynSent => {
                // Active open: SYN+ACK completes the
                // handshake; RST means refused.
                if flags & RDP_FLAG_SYN != 0 && flags & RDP_FLAG_ACK != 0 {
                    conn.rcv_irs = seqnum;
                    conn.rcv_cur = seqnum;
                    conn.recv_ack(acknum);
                    conn.state = RdpState::Open;
                    actions.push(RdpAction::Event {
                        desc: conn.desc,
                        addr: conn.addr.clone(),
                        event: RdpEvent::Connected,
                    });
                } else if flags & RDP_FLAG_RST != 0 {
                    conn.state = RdpState::Closed;
                    actions.push(RdpAction::Event {
                        desc: conn.desc,
                        addr: conn.addr.clone(),
                        event: RdpEvent::Refused,
                    });
                }
            }
            RdpState::SynRcvd => {
                // Passive open: the final ACK of the
                // three-way handshake opens the link.
                if flags & RDP_FLAG_ACK != 0 {
                    conn.recv_ack(acknum);
                    conn.state = RdpState::Open;
                }
            }
            RdpState::Open => {
                if flags & RDP_FLAG_RST != 0 {
                    conn.state = RdpState::Closed;
                    actions.push(RdpAction::Event {
                        desc: conn.desc,
                        addr: conn.addr.clone(),
                        event: RdpEvent::Reset,
                    });
                    return;
                }
                if flags & RDP_FLAG_FIN != 0 {
                    conn.state = if conn.is_passive {
                        RdpState::CloseWaitPassive
                    } else {
                        RdpState::CloseWaitActive
                    };
                    conn.close_time = Some(Instant::now());
                    return;
                }
                if flags & RDP_FLAG_ACK != 0 {
                    conn.recv_ack(acknum);
                }
                if flags & RDP_FLAG_EAK != 0 {
                    // EACK carries the acked seqnum in the
                    // seqnum field (see `pending_output`).
                    conn.recv_eack(seqnum);
                }
                if !data.is_empty() {
                    conn.recv_data(seqnum, data.to_vec());
                    actions.push(RdpAction::Event {
                        desc: conn.desc,
                        addr: conn.addr.clone(),
                        event: RdpEvent::Ready2Read,
                    });
                }
            }
            RdpState::CloseWaitActive => {
                if flags & RDP_FLAG_FIN != 0 {
                    conn.state = RdpState::Closed;
                }
            }
            // CloseWaitPassive (and Closed) fall through:
            // passive close completes only via the
            // close-wait timeout in `tick()`.
            _ => {}
        }
    }

    /// Periodic tick: retransmit, delayed ACK, timeouts.
    ///
    /// Returns actions for timed-out connections.
    pub fn tick(&mut self) -> Vec<RdpAction> {
        let mut actions = Vec::new();
        // Descriptors to close are collected and closed
        // after the loop — `close()` mutates the maps we
        // are iterating.
        let mut to_close = Vec::new();

        for (desc, conn) in self.connections.iter_mut() {
            // Check SYN timeout with exponential backoff
            if conn.state == RdpState::SynSent {
                if let Some(t) = conn.syn_time {
                    let elapsed = t.elapsed().as_secs();
                    if elapsed > conn.syn_rt_secs {
                        if conn.syn_rt_secs > self.max_retrans.as_secs() {
                            actions.push(RdpAction::Event {
                                desc: *desc,
                                addr: conn.addr.clone(),
                                event: RdpEvent::Failed,
                            });
                            to_close.push(*desc);
                        } else {
                            // Retry SYN with backoff
                            conn.syn_time = Some(Instant::now());
                            conn.syn_rt_secs *= 2;
                        }
                    }
                }
            }

            // RST retry
            if conn.is_retry_rst {
                if let Some(t) = conn.rst_time {
                    if t.elapsed().as_secs() > conn.rst_rt_secs {
                        conn.rst_rt_secs *= 2;
                        conn.rst_time = Some(Instant::now());
                        if conn.rst_rt_secs > self.max_retrans.as_secs() {
                            conn.is_retry_rst = false;
                            to_close.push(*desc);
                        }
                    }
                }
            }

            // Check close-wait timeout
            if matches!(
                conn.state,
                RdpState::CloseWaitPassive | RdpState::CloseWaitActive
            ) {
                if let Some(t) = conn.close_time {
                    if t.elapsed() >= self.max_retrans {
                        to_close.push(*desc);
                    }
                }
            }

            // Retransmit unacked segments
            if conn.state == RdpState::Open {
                let (alive, _retransmits) = conn.retransmit(self.max_retrans);
                if !alive {
                    // Broken pipe — exceeded max retrans
                    actions.push(RdpAction::Event {
                        desc: *desc,
                        addr: conn.addr.clone(),
                        event: RdpEvent::Broken,
                    });
                    to_close.push(*desc);
                }

                // Note: retransmitted segments are picked
                // up by pending_output() in the next flush
                // NOTE(review): `_retransmits` is dropped
                // here, yet pending_output() only emits
                // segments with !is_sent — from the code
                // visible here retransmissions never reach
                // the wire. Verify against the flush path.
            }
        }

        for d in to_close {
            self.close(d);
        }

        actions
    }

    /// Number of active connections.
    pub fn connection_count(&self) -> usize {
        self.connections.len()
    }

    /// Number of active listeners.
    pub fn listener_count(&self) -> usize {
        self.listeners.len()
    }

    /// Build outgoing packets for a connection.
+ /// + /// Returns `(dst_id, sport, dport, packets)` where + /// each packet is `(flags, seqnum, acknum, data)`. + /// The caller wraps these in protocol messages and + /// sends via UDP. + pub fn pending_output(&mut self, desc: i32) -> Option<PendingOutput> { + let conn = self.connections.get_mut(&desc)?; + + let mut packets = Vec::new(); + + match conn.state { + RdpState::SynSent => { + // Send SYN with receive buffer params + // (rdp_syn: out_segs_max + seg_size_max) + let mut syn_data = Vec::with_capacity(4); + syn_data + .extend_from_slice(&(conn.rcv_max as u16).to_be_bytes()); + syn_data + .extend_from_slice(&(conn.rbuf_max as u16).to_be_bytes()); + packets.push(RdpPacket { + flags: RDP_FLAG_SYN, + seqnum: conn.snd_iss, + acknum: 0, + data: syn_data, + }); + } + RdpState::SynRcvd => { + // Send SYN+ACK + packets.push(RdpPacket { + flags: RDP_FLAG_SYN | RDP_FLAG_ACK, + seqnum: conn.snd_iss, + acknum: conn.rcv_cur, + data: Vec::new(), + }); + } + RdpState::Open => { + // Send ACK if needed + if conn.needs_ack() { + packets.push(RdpPacket { + flags: RDP_FLAG_ACK, + seqnum: conn.snd_nxt, + acknum: conn.rcv_cur, + data: Vec::new(), + }); + conn.last_ack_time = Instant::now(); + conn.rcv_ack = conn.rcv_cur; + } + + // Send EACKs for out-of-order recv segments + for seg in conn.recv_window.iter_mut().flatten() { + if seg.is_used && !seg.is_eacked { + packets.push(RdpPacket { + flags: RDP_FLAG_EAK | RDP_FLAG_ACK, + seqnum: seg.seqnum, + acknum: conn.rcv_cur, + data: Vec::new(), + }); + seg.is_eacked = true; + } + } + + // Send pending data segments, limited by + // congestion window (AIMD) + let in_flight = conn + .send_window + .iter() + .filter(|s| s.is_sent && !s.is_acked) + .count() as u32; + let can_send = conn.cwnd.saturating_sub(in_flight); + + let mut sent_count = 0u32; + for seg in &conn.send_window { + if sent_count >= can_send { + break; + } + if !seg.is_sent && !seg.is_acked { + packets.push(RdpPacket { + flags: RDP_FLAG_ACK, + seqnum: seg.seqnum, + 
acknum: conn.rcv_cur, + data: seg.data.clone(), + }); + sent_count += 1; + } + } + + // Mark as sent + let mut marked = 0u32; + for seg in conn.send_window.iter_mut() { + if marked >= can_send { + break; + } + if !seg.is_sent { + seg.is_sent = true; + seg.sent_time = Some(Instant::now()); + marked += 1; + } + } + } + _ => {} + } + + if packets.is_empty() { + return None; + } + + Some(PendingOutput { + dst: conn.addr.did, + sport: conn.addr.sport, + dport: conn.addr.dport, + packets, + }) + } + + /// Get all connection descriptors. + pub fn descriptors(&self) -> Vec<i32> { + self.connections.keys().copied().collect() + } + + fn alloc_desc(&mut self) -> i32 { + let d = self.next_desc; + // Wrap at i32::MAX to avoid overflow; skip 0 + // and negative values + self.next_desc = if d >= i32::MAX - 1 { 1 } else { d + 1 }; + d + } +} + +/// A pending outgoing RDP packet. +#[derive(Debug, Clone)] +pub struct RdpPacket { + pub flags: u8, + pub seqnum: u32, + pub acknum: u32, + pub data: Vec<u8>, +} + +/// Pending output for a connection. +#[derive(Debug)] +pub struct PendingOutput { + pub dst: NodeId, + pub sport: u16, + pub dport: u16, + pub packets: Vec<RdpPacket>, +} + +/// Build an RDP wire packet: rdp_head(20) + data. +pub fn build_rdp_wire( + flags: u8, + sport: u16, + dport: u16, + seqnum: u32, + acknum: u32, + data: &[u8], +) -> Vec<u8> { + let dlen = data.len() as u16; + let mut buf = vec![0u8; RDP_HEADER_SIZE + data.len()]; + buf[0] = flags; + buf[1] = (RDP_HEADER_SIZE / 2) as u8; // hlen in 16-bit words + buf[2..4].copy_from_slice(&sport.to_be_bytes()); + buf[4..6].copy_from_slice(&dport.to_be_bytes()); + buf[6..8].copy_from_slice(&dlen.to_be_bytes()); + buf[8..12].copy_from_slice(&seqnum.to_be_bytes()); + buf[12..16].copy_from_slice(&acknum.to_be_bytes()); + buf[16..20].fill(0); // reserved + buf[20..].copy_from_slice(data); + buf +} + +/// Parsed RDP wire header fields. 
pub struct RdpWireHeader<'a> {
    pub flags: u8,
    pub sport: u16,
    pub dport: u16,
    pub seqnum: u32,
    pub acknum: u32,
    /// Payload slice (everything after the fixed header).
    pub data: &'a [u8],
}

/// Parse an RDP wire packet header.
///
/// Returns `None` when `buf` is shorter than the fixed
/// 20-byte header.
/// NOTE(review): the hlen (buf[1]) and dlen (buf[6..8])
/// fields written by `build_rdp_wire` are not validated
/// here — the payload is taken as everything past the
/// fixed header. Confirm this is intended for packets
/// from untrusted peers.
pub fn parse_rdp_wire(buf: &[u8]) -> Option<RdpWireHeader<'_>> {
    if buf.len() < RDP_HEADER_SIZE {
        return None;
    }
    Some(RdpWireHeader {
        flags: buf[0],
        sport: u16::from_be_bytes([buf[2], buf[3]]),
        dport: u16::from_be_bytes([buf[4], buf[5]]),
        seqnum: u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]),
        acknum: u32::from_be_bytes([buf[12], buf[13], buf[14], buf[15]]),
        data: &buf[RDP_HEADER_SIZE..],
    })
}

impl Default for Rdp {
    fn default() -> Self {
        Self::new()
    }
}

// ── Errors ──────────────────────────────────────────

/// Errors returned by the descriptor-based `Rdp` API.
#[derive(Debug)]
pub enum RdpError {
    PortInUse(u16),
    BadDescriptor(i32),
    NotOpen(i32),
    SendBufferFull,
}

impl std::fmt::Display for RdpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RdpError::PortInUse(p) => write!(f, "port {p} in use"),
            RdpError::BadDescriptor(d) => write!(f, "bad descriptor {d}"),
            RdpError::NotOpen(d) => write!(f, "descriptor {d} not open"),
            RdpError::SendBufferFull => write!(f, "send buffer full"),
        }
    }
}

impl std::error::Error for RdpError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn listen_and_close() {
        let mut rdp = Rdp::new();
        let desc = rdp.listen(5000).unwrap();
        assert_eq!(rdp.listener_count(), 1);
        rdp.close(desc);
        assert_eq!(rdp.listener_count(), 0);
    }

    #[test]
    fn listen_duplicate_port() {
        let mut rdp = Rdp::new();
        rdp.listen(5000).unwrap();
        assert!(matches!(rdp.listen(5000), Err(RdpError::PortInUse(5000))));
    }

    #[test]
    fn connect_creates_syn_sent() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(0, dst, 5000).unwrap();
        assert_eq!(rdp.get_state(desc).unwrap(), RdpState::SynSent);
    }

    #[test]
    fn send_before_open_fails() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(0, dst, 5000).unwrap();
        assert!(matches!(
            rdp.send(desc, b"hello"),
            Err(RdpError::NotOpen(_))
        ));
    }

    #[test]
    fn recv_empty() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(0, dst, 5000).unwrap();
        let mut buf = [0u8; 64];

        // SynSent state → bad descriptor or no data
        assert!(rdp.recv(desc, &mut buf).is_ok());
    }

    #[test]
    fn syn_ack_opens_connection() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(1000, dst, 5000).unwrap();

        // Simulate receiving SYN+ACK
        let actions = rdp.input(&RdpInput {
            src: dst,
            sport: 5000,
            dport: 1000,
            flags: RDP_FLAG_SYN | RDP_FLAG_ACK,
            seqnum: 100,
            acknum: 1,
            data: &[],
        });

        assert_eq!(rdp.get_state(desc).unwrap(), RdpState::Open);
        assert!(actions.iter().any(|a| matches!(
            a,
            RdpAction::Event {
                event: RdpEvent::Connected,
                ..
            }
        )));
    }

    #[test]
    fn inbound_syn_accepted() {
        let mut rdp = Rdp::new();
        rdp.listen(5000).unwrap();

        let peer = NodeId::from_bytes([0x02; 32]);
        let actions = rdp.input(&RdpInput {
            src: peer,
            sport: 3000,
            dport: 5000,
            flags: RDP_FLAG_SYN,
            seqnum: 200,
            acknum: 0,
            data: &[],
        });

        assert_eq!(rdp.connection_count(), 1);
        assert!(actions.iter().any(|a| matches!(
            a,
            RdpAction::Event {
                event: RdpEvent::Accepted,
                ..
            }
        )));
    }

    #[test]
    fn rst_resets_connection() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(1000, dst, 5000).unwrap();

        // Open first
        rdp.input(&RdpInput {
            src: dst,
            sport: 5000,
            dport: 1000,
            flags: RDP_FLAG_SYN | RDP_FLAG_ACK,
            seqnum: 100,
            acknum: 1,
            data: &[],
        });
        assert_eq!(rdp.get_state(desc).unwrap(), RdpState::Open);

        // RST
        let actions = rdp.input(&RdpInput {
            src: dst,
            sport: 5000,
            dport: 1000,
            flags: RDP_FLAG_RST,
            seqnum: 0,
            acknum: 0,
            data: &[],
        });
        assert_eq!(rdp.get_state(desc).unwrap(), RdpState::Closed);
        assert!(actions.iter().any(|a| matches!(
            a,
            RdpAction::Event {
                event: RdpEvent::Reset,
                ..
            }
        )));
    }

    #[test]
    fn data_delivery() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(1000, dst, 5000).unwrap();

        // Open
        rdp.input(&RdpInput {
            src: dst,
            sport: 5000,
            dport: 1000,
            flags: RDP_FLAG_SYN | RDP_FLAG_ACK,
            seqnum: 100,
            acknum: 1,
            data: &[],
        });

        // Receive data
        let actions = rdp.input(&RdpInput {
            src: dst,
            sport: 5000,
            dport: 1000,
            flags: RDP_FLAG_ACK,
            seqnum: 101,
            acknum: 1,
            data: b"hello",
        });

        assert!(actions.iter().any(|a| matches!(
            a,
            RdpAction::Event {
                event: RdpEvent::Ready2Read,
                ..
            }
        )));

        let mut buf = [0u8; 64];
        let n = rdp.recv(desc, &mut buf).unwrap();
        assert_eq!(&buf[..n], b"hello");
    }

    #[test]
    fn send_data_on_open() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        let desc = rdp.connect(1000, dst, 5000).unwrap();

        // Open
        rdp.input(&RdpInput {
            src: dst,
            sport: 5000,
            dport: 1000,
            flags: RDP_FLAG_SYN | RDP_FLAG_ACK,
            seqnum: 100,
            acknum: 1,
            data: &[],
        });

        let n = rdp.send(desc, b"world").unwrap();
        assert_eq!(n, 5);
    }

    #[test]
    fn get_status() {
        let mut rdp = Rdp::new();
        let dst = NodeId::from_bytes([0x01; 32]);
        rdp.connect(1000, dst, 5000).unwrap();

        let status = rdp.get_status();
        assert_eq!(status.len(), 1);
        assert_eq!(status[0].state, RdpState::SynSent);
        assert_eq!(status[0].dport, 5000);
    }

    #[test]
    fn is_before_wrapping() {
        assert!(is_before(1, 2));
        assert!(is_before(0, 1));
        assert!(!is_before(2, 1));
        assert!(!is_before(5, 5));

        // Wrapping: u32::MAX is before 0
        assert!(is_before(u32::MAX, 0));
    }
}
diff --git a/src/routing.rs b/src/routing.rs
new file mode 100644
index 0000000..a9b618d
--- /dev/null
+++ b/src/routing.rs
@@ -0,0 +1,843 @@
//! Kademlia routing table with k-buckets.
//!
//! Each bucket holds
//! up to `MAX_BUCKET_ENTRY` (20) peers ordered by last
//! seen time (LRU). When a bucket is full, the least
//! recently seen peer is pinged; if it doesn't respond,
//! it's replaced by the new peer.

use std::time::Instant;

use crate::id::NodeId;
use crate::peers::PeerInfo;

/// Maximum entries per k-bucket.
pub const MAX_BUCKET_ENTRY: usize = 20;

/// Number of bits in a node ID.
pub const ID_BITS: usize = crate::id::ID_BITS;

/// Number of closest nodes to return in lookups.
/// This implementation's default: 10.
pub const NUM_FIND_NODE: usize = 10;

/// Maximum entries in the replacement cache per bucket.
/// When a bucket is full, new contacts go here instead
/// of being discarded (Kademlia paper section 2.4).
const MAX_REPLACEMENT_CACHE: usize = 5;

/// Number of consecutive failures before a contact is
/// considered stale and eligible for replacement.
const STALE_THRESHOLD: u32 = 3;

/// Result of inserting a peer into the routing table.
#[derive(Debug)]
pub enum InsertResult {
    /// Peer was inserted into the bucket.
    Inserted,

    /// Peer already existed and was moved to tail (LRU).
    Updated,

    /// Bucket is full. Contains the LRU peer that should
    /// be pinged to decide eviction.
    BucketFull { lru: PeerInfo },

    /// Peer is our own ID, ignored.
    IsSelf,
}

/// A single k-bucket holding up to K peers.
///
/// `nodes` is kept in LRU order: index 0 is the least
/// recently seen contact, the tail the most recent.
struct KBucket {
    nodes: Vec<PeerInfo>,
    /// Last time this bucket's membership or order
    /// changed (used for refresh scheduling).
    last_updated: Instant,
    /// Replacement cache: contacts seen when bucket is
    /// full. Used to replace stale contacts without
    /// losing discovered nodes (Kademlia paper §2.4).
    replacements: Vec<PeerInfo>,
    /// Consecutive failure count per node ID. Peers with
    /// count >= STALE_THRESHOLD are replaced by cached
    /// contacts.
    stale_counts: std::collections::HashMap<NodeId, u32>,
}

impl KBucket {
    fn new() -> Self {
        Self {
            nodes: Vec::new(),
            last_updated: Instant::now(),
            replacements: Vec::new(),
            stale_counts: std::collections::HashMap::new(),
        }
    }

    fn len(&self) -> usize {
        self.nodes.len()
    }

    fn is_full(&self) -> bool {
        self.nodes.len() >= MAX_BUCKET_ENTRY
    }

    fn contains(&self, id: &NodeId) -> bool {
        self.nodes.iter().any(|n| n.id == *id)
    }

    /// Position of `id` in the LRU list, if present.
    fn find_pos(&self, id: &NodeId) -> Option<usize> {
        self.nodes.iter().position(|n| n.id == *id)
    }

    /// Insert or update a peer. Returns InsertResult.
    fn insert(&mut self, peer: PeerInfo) -> InsertResult {
        if let Some(pos) = self.find_pos(&peer.id) {
            // Move to tail (most recently seen)
            self.nodes.remove(pos);
            self.nodes.push(peer);
            self.last_updated = Instant::now();
            // Clear stale count on successful contact
            self.stale_counts.remove(&self.nodes.last().unwrap().id);
            return InsertResult::Updated;
        }

        if self.is_full() {
            // Check if any existing contact is stale
            // enough to replace immediately.
            // NOTE(review): this path evicts a node but
            // still reports plain `Inserted`, so callers
            // (RoutingTable::add) cannot decrement the
            // evicted peer's subnet counter — the Sybil
            // counts can drift upward over time. Verify
            // and consider reporting the evicted peer.
            if let Some(stale_pos) = self.find_stale() {
                let stale_id = self.nodes[stale_pos].id;
                self.stale_counts.remove(&stale_id);
                self.nodes.remove(stale_pos);
                self.nodes.push(peer);
                self.last_updated = Instant::now();
                return InsertResult::Inserted;
            }

            // No stale contact: add to replacement cache
            self.add_to_cache(peer.clone());

            // Return LRU (front) for ping check
            let lru = self.nodes[0].clone();
            return InsertResult::BucketFull { lru };
        }

        self.nodes.push(peer);
        self.last_updated = Instant::now();
        InsertResult::Inserted
    }

    /// Find a contact whose stale count exceeds the
    /// threshold. Returns its position in the nodes vec.
    fn find_stale(&self) -> Option<usize> {
        for (i, node) in self.nodes.iter().enumerate() {
            if let Some(&count) = self.stale_counts.get(&node.id) {
                if count >= STALE_THRESHOLD {
                    return Some(i);
                }
            }
        }
        None
    }

    /// Add a contact to the replacement cache.
    /// The cache is kept in arrival order; the oldest
    /// entry is dropped when it overflows.
    fn add_to_cache(&mut self, peer: PeerInfo) {
        // Update if already in cache
        if let Some(pos) = self
            .replacements
            .iter()
            .position(|r| r.id == peer.id)
        {
            self.replacements.remove(pos);
            self.replacements.push(peer);
            return;
        }
        if self.replacements.len() >= MAX_REPLACEMENT_CACHE {
            self.replacements.remove(0); // drop oldest
        }
        self.replacements.push(peer);
    }

    /// Record a failure for a contact. Returns true if
    /// the contact became stale (crossed threshold).
+ fn record_failure(&mut self, id: &NodeId) -> bool { + let count = self.stale_counts.entry(*id).or_insert(0); + *count += 1; + *count >= STALE_THRESHOLD + } + + /// Try to replace a stale contact with the best + /// replacement from cache. Returns the evicted ID + /// if successful. + fn try_replace_stale(&mut self, stale_id: &NodeId) -> Option<NodeId> { + let pos = self.find_pos(stale_id)?; + let replacement = self.replacements.pop()?; + let evicted = self.nodes[pos].id; + self.stale_counts.remove(&evicted); + self.nodes.remove(pos); + self.nodes.push(replacement); + self.last_updated = Instant::now(); + Some(evicted) + } + + /// Number of contacts in the replacement cache. + fn cache_len(&self) -> usize { + self.replacements.len() + } + + /// Replace the LRU node (front) with a new peer. + /// Only succeeds if `old_id` matches the current LRU. + fn evict_lru(&mut self, old_id: &NodeId, new: PeerInfo) -> bool { + if let Some(front) = self.nodes.first() { + if front.id == *old_id { + self.nodes.remove(0); + self.nodes.push(new); + self.last_updated = Instant::now(); + return true; + } + } + false + } + + fn remove(&mut self, id: &NodeId) -> bool { + if let Some(pos) = self.find_pos(id) { + self.nodes.remove(pos); + true + } else { + false + } + } + + /// Mark a peer as recently seen (move to tail). + fn mark_seen(&mut self, id: &NodeId) { + if let Some(pos) = self.find_pos(id) { + let peer = self.nodes.remove(pos); + self.nodes.push(PeerInfo { + last_seen: Instant::now(), + ..peer + }); + self.last_updated = Instant::now(); + self.stale_counts.remove(id); + } + } +} + +/// Kademlia routing table. +/// +/// Maintains 256 k-buckets indexed by XOR distance from +/// the local node. Each bucket holds up to +/// `MAX_BUCKET_ENTRY` peers. +/// Maximum nodes per /24 subnet in the routing table. +/// Limits Sybil attack impact. 
pub const MAX_PER_SUBNET: usize = 2;

/// Kademlia routing table: 256 k-buckets indexed by XOR
/// distance from the local node, with per-/24-subnet
/// caps and pinned (never-evicted) bootstrap nodes.
pub struct RoutingTable {
    local_id: NodeId,
    buckets: Vec<KBucket>,

    /// Count of nodes per /24 subnet for Sybil
    /// resistance.
    subnet_counts: std::collections::HashMap<[u8; 3], usize>,

    /// Pinned bootstrap nodes — never evicted.
    pinned: std::collections::HashSet<NodeId>,
}

impl RoutingTable {
    /// Create a new routing table for the given local ID.
    pub fn new(local_id: NodeId) -> Self {
        let mut buckets = Vec::with_capacity(ID_BITS);
        for _ in 0..ID_BITS {
            buckets.push(KBucket::new());
        }
        Self {
            local_id,
            buckets,
            subnet_counts: std::collections::HashMap::new(),
            pinned: std::collections::HashSet::new(),
        }
    }

    /// Pin a bootstrap node — it will never be evicted.
    pub fn pin(&mut self, id: NodeId) {
        self.pinned.insert(id);
    }

    /// Check if a node is pinned.
    pub fn is_pinned(&self, id: &NodeId) -> bool {
        self.pinned.contains(id)
    }

    /// Our own node ID.
    pub fn local_id(&self) -> &NodeId {
        &self.local_id
    }

    /// Determine the bucket index for a given node ID.
    ///
    /// Returns `None` if `id` equals our local ID.
    fn bucket_index(&self, id: &NodeId) -> Option<usize> {
        let dist = self.local_id.distance(id);
        if dist.is_zero() {
            return None;
        }
        let lz = dist.leading_zeros() as usize;

        // Bucket 0 = furthest (bit 0 differs),
        // Bucket 255 = closest (only bit 255 differs).
        // Index = 255 - leading_zeros.
        Some(ID_BITS - 1 - lz)
    }

    /// Add a peer to the routing table.
    ///
    /// Rejects the peer if its /24 subnet already has
    /// `MAX_PER_SUBNET` entries (Sybil resistance).
+ pub fn add(&mut self, peer: PeerInfo) -> InsertResult { + if peer.id == self.local_id { + return InsertResult::IsSelf; + } + let idx = match self.bucket_index(&peer.id) { + Some(i) => i, + None => return InsertResult::IsSelf, + }; + + // Sybil check: limit per /24 subnet + // Skip for loopback (tests, local dev) + let subnet = subnet_key(&peer.addr); + let is_loopback = peer.addr.ip().is_loopback(); + if !is_loopback && !self.buckets[idx].contains(&peer.id) { + let count = self.subnet_counts.get(&subnet).copied().unwrap_or(0); + if count >= MAX_PER_SUBNET { + log::debug!( + "Sybil: rejecting {:?} (subnet {:?} has {count} entries)", + peer.id, + subnet + ); + return InsertResult::BucketFull { lru: peer }; + } + } + + let result = self.buckets[idx].insert(peer); + if matches!(result, InsertResult::Inserted) { + *self.subnet_counts.entry(subnet).or_insert(0) += 1; + } + result + } + + /// Remove a peer from the routing table. + pub fn remove(&mut self, id: &NodeId) -> bool { + // Never evict pinned bootstrap nodes + if self.pinned.contains(id) { + return false; + } + if let Some(idx) = self.bucket_index(id) { + // Decrement subnet count + if let Some(peer) = + self.buckets[idx].nodes.iter().find(|p| p.id == *id) + { + let subnet = subnet_key(&peer.addr); + if let Some(c) = self.subnet_counts.get_mut(&subnet) { + *c = c.saturating_sub(1); + if *c == 0 { + self.subnet_counts.remove(&subnet); + } + } + } + self.buckets[idx].remove(id) + } else { + false + } + } + + /// Evict the LRU node in a bucket and insert a new + /// peer. Called after a ping timeout confirms the LRU + /// node is dead. + pub fn evict_and_insert(&mut self, old_id: &NodeId, new: PeerInfo) -> bool { + if let Some(idx) = self.bucket_index(&new.id) { + self.buckets[idx].evict_lru(old_id, new) + } else { + false + } + } + + /// Mark a peer as recently seen. 
    pub fn mark_seen(&mut self, id: &NodeId) {
        if let Some(idx) = self.bucket_index(id) {
            self.buckets[idx].mark_seen(id);
        }
    }

    /// Record a communication failure for a peer.
    /// If the peer becomes stale (exceeds threshold),
    /// tries to replace it with a cached contact.
    /// Returns the evicted NodeId if replacement happened.
    /// NOTE(review): the replacement swap does not update
    /// `subnet_counts` for either node — verify whether
    /// the Sybil counters need adjusting here.
    pub fn record_failure(&mut self, id: &NodeId) -> Option<NodeId> {
        // Never mark pinned nodes as stale
        if self.pinned.contains(id) {
            return None;
        }
        let idx = self.bucket_index(id)?;
        let became_stale = self.buckets[idx].record_failure(id);
        if became_stale {
            self.buckets[idx].try_replace_stale(id)
        } else {
            None
        }
    }

    /// Total number of contacts in all replacement caches.
    pub fn replacement_cache_size(&self) -> usize {
        self.buckets.iter().map(|b| b.cache_len()).sum()
    }

    /// Find the `count` closest peers to `target` by XOR
    /// distance, sorted closest-first.
    pub fn closest(&self, target: &NodeId, count: usize) -> Vec<PeerInfo> {
        // Linear scan over all buckets; table size is
        // bounded (256 buckets × 20 peers).
        let mut all: Vec<PeerInfo> = self
            .buckets
            .iter()
            .flat_map(|b| b.nodes.iter().cloned())
            .collect();

        all.sort_by(|a, b| {
            let da = target.distance(&a.id);
            let db = target.distance(&b.id);
            da.cmp(&db)
        });

        all.truncate(count);
        all
    }

    /// Check if a given ID exists in the table.
    pub fn has_id(&self, id: &NodeId) -> bool {
        if let Some(idx) = self.bucket_index(id) {
            self.buckets[idx].contains(id)
        } else {
            false
        }
    }

    /// Total number of peers in the table.
    pub fn size(&self) -> usize {
        self.buckets.iter().map(|b| b.len()).sum()
    }

    /// Check if the table has no peers.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Get fill level of each non-empty bucket (for
    /// debugging/metrics).
    pub fn bucket_fill_levels(&self) -> Vec<(usize, usize)> {
        self.buckets
            .iter()
            .enumerate()
            .filter(|(_, b)| b.len() > 0)
            .map(|(i, b)| (i, b.len()))
            .collect()
    }

    /// Find buckets that haven't been updated since
    /// `threshold` and return random target IDs for
    /// refresh lookups.
    ///
    /// This implements the Kademlia bucket refresh from
    /// the paper: pick a random ID in each stale bucket's
    /// range and do a find_node on it.
    pub fn stale_bucket_targets(
        &self,
        threshold: std::time::Duration,
    ) -> Vec<NodeId> {
        let now = Instant::now();
        let mut targets = Vec::new();

        for (i, bucket) in self.buckets.iter().enumerate() {
            if now.duration_since(bucket.last_updated) >= threshold {
                // Generate a random ID in this bucket's range.
                // The bucket at index i covers nodes where the
                // XOR distance has bit (255-i) as the highest
                // set bit. We create a target by XORing our ID
                // with a value that has bit (255-i) set.
                let bit_pos = ID_BITS - 1 - i;
                let byte_idx = bit_pos / 8;
                let bit_idx = 7 - (bit_pos % 8);

                let bytes = *self.local_id.as_bytes();
                let mut buf = bytes;
                buf[byte_idx] ^= 1 << bit_idx;
                targets.push(NodeId::from_bytes(buf));
            }
        }

        targets
    }

    /// Return the LRU (least recently seen) peer from
    /// each non-empty bucket, for liveness probing.
    ///
    /// The caller should ping each and call
    /// `mark_seen()` on reply, or `remove()` after
    /// repeated failures.
    pub fn lru_peers(&self) -> Vec<PeerInfo> {
        self.buckets
            .iter()
            .filter_map(|b| b.nodes.first().cloned())
            .collect()
    }

    /// Print the routing table (debug).
    pub fn print_table(&self) {
        for (i, bucket) in self.buckets.iter().enumerate() {
            if bucket.len() > 0 {
                log::debug!("bucket {i}: {} nodes", bucket.len());
                for node in &bucket.nodes {
                    log::debug!("  {} @ {}", node.id, node.addr);
                }
            }
        }
    }
}

/// Extract the subnet key from a socket address.
/// For IPv4 this is the /24 prefix (first 3 octets).
/// For IPv6 the first 3 bytes are used as well.
/// NOTE(review): the original comment claimed "first 6
/// bytes (/48)" for IPv6, but the key is `[u8; 3]` built
/// from the first 3 octets — a /24 of the v6 space,
/// which is very coarse. Confirm whether /48 was
/// intended for v6 Sybil limiting.
fn subnet_key(addr: &std::net::SocketAddr) -> [u8; 3] {
    match addr.ip() {
        std::net::IpAddr::V4(v4) => {
            let o = v4.octets();
            [o[0], o[1], o[2]]
        }
        std::net::IpAddr::V6(v6) => {
            // First 3 bytes of the v6 address (see the
            // doc comment above about granularity).
            let o = v6.octets();
            [o[0], o[1], o[2]]
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;

    fn local_id() -> NodeId {
        NodeId::from_bytes([0x80; 32])
    }

    fn peer_at(byte: u8, port: u16) -> PeerInfo {
        // Use different /24 subnets to avoid Sybil limit
        PeerInfo::new(
            NodeId::from_bytes([byte; 32]),
            SocketAddr::from(([10, 0, byte, 1], port)),
        )
    }

    #[test]
    fn insert_self_is_ignored() {
        let mut rt = RoutingTable::new(local_id());
        let p = PeerInfo::new(local_id(), "127.0.0.1:3000".parse().unwrap());
        assert!(matches!(rt.add(p), InsertResult::IsSelf));
        assert_eq!(rt.size(), 0);
    }

    #[test]
    fn insert_and_lookup() {
        let mut rt = RoutingTable::new(local_id());
        let p = peer_at(0x01, 3000);
        assert!(matches!(rt.add(p.clone()), InsertResult::Inserted));
        assert_eq!(rt.size(), 1);
        assert!(rt.has_id(&p.id));
    }

    #[test]
    fn update_moves_to_tail() {
        let mut rt = RoutingTable::new(local_id());
        let p1 = peer_at(0x01, 3000);
        let p2 = peer_at(0x02, 3001);
        rt.add(p1.clone());
        rt.add(p2.clone());

        // Re-add p1 should move to tail
        assert!(matches!(rt.add(p1.clone()), InsertResult::Updated));
    }

    #[test]
    fn remove_peer() {
        let mut rt = RoutingTable::new(local_id());
        let p = peer_at(0x01, 3000);
        rt.add(p.clone());
        assert!(rt.remove(&p.id));
        assert_eq!(rt.size(), 0);
        assert!(!rt.has_id(&p.id));
    }

    #[test]
    fn closest_sorted_by_xor() {
        let mut rt = RoutingTable::new(local_id());

        // Add peers with different distances from a target
        for i in 1..=5u8 {
            rt.add(peer_at(i, 3000 + i as u16));
        }
        let target = NodeId::from_bytes([0x03; 32]);
        let closest = rt.closest(&target, 3);
        assert_eq!(closest.len(), 3);

        // Verify sorted by XOR distance
        for w in closest.windows(2) {
            let d0 = target.distance(&w[0].id);
            let d1 = target.distance(&w[1].id);
            assert!(d0 <= d1);
        }
    }

    #[test]
    fn closest_respects_count() {
        let mut rt = RoutingTable::new(local_id());
        for i in 1..=30u8 {
            rt.add(peer_at(i, 3000 + i as u16));
        }
        let target = NodeId::from_bytes([0x10; 32]);
        let closest = rt.closest(&target, 10);
        assert_eq!(closest.len(), 10);
    }

    #[test]
    fn bucket_full_returns_lru() {
        let lid = NodeId::from_bytes([0x00; 32]);
        let mut rt = RoutingTable::new(lid);

        // Fill a bucket with MAX_BUCKET_ENTRY peers.
        // All peers with [0xFF; 32] ^ small variations
        // will land in the same bucket (highest bit differs).
        for i in 0..MAX_BUCKET_ENTRY as u16 {
            let mut bytes = [0xFF; 32];
            bytes[18] = (i >> 8) as u8;
            bytes[19] = i as u8;
            let p = PeerInfo::new(
                NodeId::from_bytes(bytes),
                // Different /24 per peer to avoid Sybil limit
                SocketAddr::from(([10, 0, i as u8, 1], 3000 + i)),
            );
            assert!(matches!(rt.add(p), InsertResult::Inserted));
        }

        assert_eq!(rt.size(), MAX_BUCKET_ENTRY);

        // Next insert should return BucketFull
        let mut extra_bytes = [0xFF; 32];
        extra_bytes[19] = 0xFE;
        extra_bytes[18] = 0xFE;
        let extra = PeerInfo::new(
            NodeId::from_bytes(extra_bytes),
            SocketAddr::from(([10, 0, 250, 1], 9999)),
        );
        let result = rt.add(extra);
        assert!(matches!(result, InsertResult::BucketFull { .. }));
    }

    #[test]
    fn evict_and_insert() {
        let lid = NodeId::from_bytes([0x00; 32]);
        let mut rt = RoutingTable::new(lid);

        // Fill bucket
        let mut first_id = NodeId::from_bytes([0xFF; 32]);
        for i in 0..MAX_BUCKET_ENTRY as u16 {
            let mut bytes = [0xFF; 32];
            bytes[18] = (i >> 8) as u8;
            bytes[19] = i as u8;
            let id = NodeId::from_bytes(bytes);
            if i == 0 {
                first_id = id;
            }
            rt.add(PeerInfo::new(
                id,
                SocketAddr::from(([10, 0, i as u8, 1], 3000 + i)),
            ));
        }

        // Evict the first (LRU) and insert new
        let mut new_bytes = [0xFF; 32];
        new_bytes[17] = 0x01;
        let new_peer = PeerInfo::new(
            NodeId::from_bytes(new_bytes),
            SocketAddr::from(([10, 0, 251, 1], 9999)),
        );

        assert!(rt.evict_and_insert(&first_id, new_peer.clone()));
        assert!(!rt.has_id(&first_id));
        assert!(rt.has_id(&new_peer.id));
        assert_eq!(rt.size(), MAX_BUCKET_ENTRY);
    }

    #[test]
    fn stale_bucket_targets() {
        let mut rt = RoutingTable::new(local_id());
        let p = peer_at(0x01, 3000);
        rt.add(p);

        // No stale buckets yet (just updated)
        let targets =
            rt.stale_bucket_targets(std::time::Duration::from_secs(0));

        // At least the populated bucket should produce a target
        assert!(!targets.is_empty());
    }

    #[test]
    fn empty_table() {
        let rt = RoutingTable::new(local_id());
        assert!(rt.is_empty());
        assert_eq!(rt.size(), 0);
        assert!(rt.closest(&NodeId::from_bytes([0x01; 32]), 10).is_empty());
    }

    // ── Replacement cache tests ───────────────────

    #[test]
    fn bucket_full_adds_to_cache() {
        let lid = NodeId::from_bytes([0x00; 32]);
        let mut rt = RoutingTable::new(lid);

        // Fill a bucket
        for i in 0..MAX_BUCKET_ENTRY as u16 {
            let mut bytes = [0xFF; 32];
            bytes[18] = (i >> 8) as u8;
            bytes[19] = i as u8;
            rt.add(PeerInfo::new(
                NodeId::from_bytes(bytes),
                SocketAddr::from(([10, 0, i as u8, 1], 3000 + i)),
            ));
        }
        assert_eq!(rt.replacement_cache_size(), 0);

        // Next insert goes to replacement cache
        let mut extra =
[0xFF; 32]; + extra[18] = 0xFE; + extra[19] = 0xFE; + rt.add(PeerInfo::new( + NodeId::from_bytes(extra), + SocketAddr::from(([10, 0, 250, 1], 9999)), + )); + assert_eq!(rt.replacement_cache_size(), 1); + } + + #[test] + fn stale_contact_replaced_on_insert() { + let lid = NodeId::from_bytes([0x00; 32]); + let mut rt = RoutingTable::new(lid); + + // All peers have high bit set (byte 0 = 0xFF) + // so they all land in the same bucket (bucket 0, + // furthest). Vary low bytes to get unique IDs. + for i in 0..MAX_BUCKET_ENTRY as u8 { + let mut bytes = [0x00; 32]; + bytes[0] = 0xFF; // same high bit → same bucket + bytes[31] = i; + rt.add(PeerInfo::new( + NodeId::from_bytes(bytes), + SocketAddr::from(([10, 0, i, 1], 3000 + i as u16)), + )); + } + + // Target: the first peer (bytes[31] = 0) + let mut first = [0x00; 32]; + first[0] = 0xFF; + first[31] = 0; + let first_id = NodeId::from_bytes(first); + + // Record failures until stale + for _ in 0..STALE_THRESHOLD { + rt.record_failure(&first_id); + } + + // Next insert to same bucket should replace stale + let mut new = [0x00; 32]; + new[0] = 0xFF; + new[31] = 0xFE; + let new_id = NodeId::from_bytes(new); + let result = rt.add(PeerInfo::new( + new_id, + SocketAddr::from(([10, 0, 254, 1], 9999)), + )); + assert!(matches!(result, InsertResult::Inserted)); + assert!(!rt.has_id(&first_id)); + assert!(rt.has_id(&new_id)); + } + + #[test] + fn record_failure_replaces_with_cache() { + let lid = NodeId::from_bytes([0x00; 32]); + let mut rt = RoutingTable::new(lid); + + // All in same bucket (byte 0 = 0xFF) + for i in 0..MAX_BUCKET_ENTRY as u8 { + let mut bytes = [0x00; 32]; + bytes[0] = 0xFF; + bytes[31] = i; + rt.add(PeerInfo::new( + NodeId::from_bytes(bytes), + SocketAddr::from(([10, 0, i, 1], 3000 + i as u16)), + )); + } + + let mut target = [0x00; 32]; + target[0] = 0xFF; + target[31] = 0; + let target_id = NodeId::from_bytes(target); + + // Add a replacement to cache (same bucket) + let mut cache = [0x00; 32]; + cache[0] = 
0xFF; + cache[31] = 0xFD; + let cache_id = NodeId::from_bytes(cache); + rt.add(PeerInfo::new( + cache_id, + SocketAddr::from(([10, 0, 253, 1], 8888)), + )); + assert_eq!(rt.replacement_cache_size(), 1); + + // Record failures until replacement happens + for _ in 0..STALE_THRESHOLD { + rt.record_failure(&target_id); + } + assert!(!rt.has_id(&target_id)); + assert!(rt.has_id(&cache_id)); + assert_eq!(rt.replacement_cache_size(), 0); + } + + #[test] + fn pinned_not_stale() { + let lid = NodeId::from_bytes([0x00; 32]); + let mut rt = RoutingTable::new(lid); + let p = peer_at(0xFF, 3000); + rt.add(p.clone()); + rt.pin(p.id); + + // Failures should not evict pinned node + for _ in 0..10 { + assert!(rt.record_failure(&p.id).is_none()); + } + assert!(rt.has_id(&p.id)); + } + + #[test] + fn mark_seen_clears_stale() { + let lid = NodeId::from_bytes([0x00; 32]); + let mut rt = RoutingTable::new(lid); + let p = peer_at(0xFF, 3000); + rt.add(p.clone()); + + // Accumulate failures (but not enough to replace) + rt.record_failure(&p.id); + rt.record_failure(&p.id); + + // Successful contact clears stale count + rt.mark_seen(&p.id); + + // More failures needed now + rt.record_failure(&p.id); + rt.record_failure(&p.id); + // Still not stale (count reset to 0, now at 2 < 3) + assert!(rt.has_id(&p.id)); + } +} diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..7081ff5 --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,159 @@ +//! UDP I/O via mio (kqueue on OpenBSD, epoll on Linux). +//! +//! Supports dynamic registration of additional sockets +//! (needed by NAT detector's temporary probe socket). + +use std::net::SocketAddr; +use std::time::Duration; + +use mio::net::UdpSocket; +use mio::{Events, Interest, Poll, Registry, Token}; + +use crate::error::Error; + +/// Token for the primary DHT socket. +pub const UDP_TOKEN: Token = Token(0); + +/// Event-driven network I/O loop. +/// +/// Wraps a mio `Poll` with a primary UDP socket. 
+/// Additional sockets can be registered dynamically +/// (e.g. for NAT detection probes). +pub struct NetLoop { + poll: Poll, + events: Events, + socket: UdpSocket, + next_token: usize, +} + +impl NetLoop { + /// Bind a UDP socket and register it with the poller. + pub fn bind(addr: SocketAddr) -> Result<Self, Error> { + let poll = Poll::new()?; + let mut socket = UdpSocket::bind(addr)?; + poll.registry() + .register(&mut socket, UDP_TOKEN, Interest::READABLE)?; + Ok(Self { + poll, + events: Events::with_capacity(64), + socket, + next_token: 1, + }) + } + + /// Send a datagram to the given address. + pub fn send_to( + &self, + buf: &[u8], + addr: SocketAddr, + ) -> Result<usize, Error> { + self.socket.send_to(buf, addr).map_err(Error::Io) + } + + /// Receive a datagram from the primary socket. + /// + /// Returns `(bytes_read, sender_address)` or + /// `WouldBlock` if no data available. + pub fn recv_from( + &self, + buf: &mut [u8], + ) -> Result<(usize, SocketAddr), Error> { + self.socket.recv_from(buf).map_err(Error::Io) + } + + /// Poll for I/O events, blocking up to `timeout`. + pub fn poll_events(&mut self, timeout: Duration) -> Result<(), Error> { + self.poll.poll(&mut self.events, Some(timeout))?; + Ok(()) + } + + /// Iterate over events from the last `poll_events` call. + pub fn drain_events(&self) -> impl Iterator<Item = &mio::event::Event> { + self.events.iter() + } + + /// Get the local address of the primary socket. + pub fn local_addr(&self) -> Result<SocketAddr, Error> { + self.socket.local_addr().map_err(Error::Io) + } + + /// Access the mio registry for registering additional + /// sockets (e.g. NAT detection probe sockets). + pub fn registry(&self) -> &Registry { + self.poll.registry() + } + + /// Allocate a new unique token for a dynamic socket. + /// Wraps at usize::MAX, skips 0 (reserved for + /// UDP_TOKEN). 
+ pub fn next_token(&mut self) -> Token { + let t = Token(self.next_token); + self.next_token = self.next_token.wrapping_add(1).max(1); + t + } + + /// Register an additional UDP socket with the poller. + /// + /// Returns the assigned token. + pub fn register_socket( + &mut self, + socket: &mut UdpSocket, + ) -> Result<Token, Error> { + let token = self.next_token(); + self.poll + .registry() + .register(socket, token, Interest::READABLE)?; + Ok(token) + } + + /// Deregister a previously registered socket. + pub fn deregister_socket( + &mut self, + socket: &mut UdpSocket, + ) -> Result<(), Error> { + self.poll.registry().deregister(socket)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bind_and_local_addr() { + let net = NetLoop::bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = net.local_addr().unwrap(); + assert_eq!(addr.ip(), "127.0.0.1".parse::<std::net::IpAddr>().unwrap()); + assert_ne!(addr.port(), 0); + } + + #[test] + fn send_recv_loopback() { + let net = NetLoop::bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = net.local_addr().unwrap(); + + // Send to self + let sent = net.send_to(b"ping", addr).unwrap(); + assert_eq!(sent, 4); + + // Poll for the event + let mut net = net; + net.poll_events(Duration::from_millis(100)).unwrap(); + + let mut buf = [0u8; 64]; + let (len, from) = net.recv_from(&mut buf).unwrap(); + assert_eq!(&buf[..len], b"ping"); + assert_eq!(from, addr); + } + + #[test] + fn register_extra_socket() { + let mut net = NetLoop::bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let mut extra = + UdpSocket::bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let token = net.register_socket(&mut extra).unwrap(); + assert_ne!(token, UDP_TOKEN); + net.deregister_socket(&mut extra).unwrap(); + } +} diff --git a/src/store_track.rs b/src/store_track.rs new file mode 100644 index 0000000..a4ac78d --- /dev/null +++ b/src/store_track.rs @@ -0,0 +1,275 @@ +//! Store acknowledgment tracking. +//! +//! 
Tracks which STORE operations have been acknowledged
+//! by remote peers. Failed stores are retried with
+//! alternative peers to maintain data redundancy.
+
+use std::collections::HashMap;
+use std::time::{Duration, Instant};
+
+use crate::id::NodeId;
+use crate::peers::PeerInfo;
+
+/// Maximum retry attempts before giving up.
+const MAX_RETRIES: u32 = 3;
+
+/// Time to wait for a store acknowledgment before
+/// considering it failed.
+const STORE_TIMEOUT: Duration = Duration::from_secs(5);
+
+/// Interval between retry sweeps.
+pub const RETRY_INTERVAL: Duration = Duration::from_secs(30);
+
+/// Tracks a pending STORE operation.
+#[derive(Debug, Clone)]
+struct PendingStore {
+    /// Target NodeId (SHA-256 of key).
+    target: NodeId,
+    /// Raw key bytes.
+    key: Vec<u8>,
+    /// Value bytes.
+    value: Vec<u8>,
+    /// TTL at time of store.
+    ttl: u16,
+    /// Whether the value is unique.
+    is_unique: bool,
+    /// Peer we sent the STORE to.
+    peer: PeerInfo,
+    /// When the STORE was sent.
+    sent_at: Instant,
+    /// Number of retry attempts.
+    retries: u32,
+}
+
+/// Tracks pending and failed STORE operations.
+pub struct StoreTracker {
+    /// Pending stores keyed by (target NodeId, raw key bytes).
+    pending: HashMap<(NodeId, Vec<u8>), Vec<PendingStore>>,
+    /// Total successful stores.
+    pub acks: u64,
+    /// Total failed stores (exhausted retries).
+    pub failures: u64,
+}
+
+impl StoreTracker {
+    pub fn new() -> Self {
+        Self {
+            pending: HashMap::new(),
+            acks: 0,
+            failures: 0,
+        }
+    }
+
+    /// Record that a STORE was sent to a peer.
+    pub fn track(
+        &mut self,
+        target: NodeId,
+        key: Vec<u8>,
+        value: Vec<u8>,
+        ttl: u16,
+        is_unique: bool,
+        peer: PeerInfo,
+    ) {
+        let entry = PendingStore {
+            target,
+            key: key.clone(),
+            value,
+            ttl,
+            is_unique,
+            peer,
+            sent_at: Instant::now(),
+            retries: 0,
+        };
+        self.pending.entry((target, key)).or_default().push(entry);
+    }
+
+    /// Record a successful store acknowledgment from
+    /// a peer (they stored our value). 
+ pub fn ack(&mut self, target: &NodeId, key: &[u8], peer_id: &NodeId) { + let k = (*target, key.to_vec()); + if let Some(stores) = self.pending.get_mut(&k) { + let before = stores.len(); + stores.retain(|s| s.peer.id != *peer_id); + let removed = before - stores.len(); + self.acks += removed as u64; + if stores.is_empty() { + self.pending.remove(&k); + } + } + } + + /// Collect stores that timed out and need retry. + /// Returns (target, key, value, ttl, is_unique, failed_peer) + /// for each timed-out store. + pub fn collect_timeouts(&mut self) -> Vec<RetryInfo> { + let mut retries = Vec::new(); + let mut exhausted_keys = Vec::new(); + + for (k, stores) in &mut self.pending { + stores.retain_mut(|s| { + if s.sent_at.elapsed() < STORE_TIMEOUT { + return true; // still waiting + } + if s.retries >= MAX_RETRIES { + // Exhausted retries + return false; + } + s.retries += 1; + retries.push(RetryInfo { + target: s.target, + key: s.key.clone(), + value: s.value.clone(), + ttl: s.ttl, + is_unique: s.is_unique, + failed_peer: s.peer.id, + }); + false // remove from pending (will be re-tracked if retried) + }); + if stores.is_empty() { + exhausted_keys.push(k.clone()); + } + } + + self.failures += exhausted_keys.len() as u64; + for k in &exhausted_keys { + self.pending.remove(k); + } + + retries + } + + /// Remove all expired tracking entries (older than + /// 2x timeout, cleanup safety net). + pub fn cleanup(&mut self) { + let cutoff = STORE_TIMEOUT * 2; + self.pending.retain(|_, stores| { + stores.retain(|s| s.sent_at.elapsed() < cutoff); + !stores.is_empty() + }); + } + + /// Number of pending store operations. + pub fn pending_count(&self) -> usize { + self.pending.values().map(|v| v.len()).sum() + } +} + +impl Default for StoreTracker { + fn default() -> Self { + Self::new() + } +} + +/// Information needed to retry a failed store. 
+pub struct RetryInfo { + pub target: NodeId, + pub key: Vec<u8>, + pub value: Vec<u8>, + pub ttl: u16, + pub is_unique: bool, + pub failed_peer: NodeId, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::SocketAddr; + + fn peer(byte: u8, port: u16) -> PeerInfo { + PeerInfo::new( + NodeId::from_bytes([byte; 32]), + SocketAddr::from(([127, 0, 0, 1], port)), + ) + } + + #[test] + fn track_and_ack() { + let mut t = StoreTracker::new(); + let target = NodeId::from_key(b"k"); + let p = peer(0x01, 3000); + t.track(target, b"k".to_vec(), b"v".to_vec(), 300, false, p); + assert_eq!(t.pending_count(), 1); + + t.ack(&target, b"k", &NodeId::from_bytes([0x01; 32])); + assert_eq!(t.pending_count(), 0); + assert_eq!(t.acks, 1); + } + + #[test] + fn timeout_triggers_retry() { + let mut t = StoreTracker::new(); + let target = NodeId::from_key(b"k"); + let p = peer(0x01, 3000); + t.track(target, b"k".to_vec(), b"v".to_vec(), 300, false, p); + + // No timeouts yet + assert!(t.collect_timeouts().is_empty()); + + // Force timeout by waiting + std::thread::sleep(Duration::from_millis(10)); + + // Hack: modify sent_at to force timeout + for stores in t.pending.values_mut() { + for s in stores.iter_mut() { + s.sent_at = + Instant::now() - STORE_TIMEOUT - Duration::from_secs(1); + } + } + + let retries = t.collect_timeouts(); + assert_eq!(retries.len(), 1); + assert_eq!(retries[0].key, b"k"); + assert_eq!(retries[0].failed_peer, NodeId::from_bytes([0x01; 32])); + } + + #[test] + fn multiple_peers_tracked() { + let mut t = StoreTracker::new(); + let target = NodeId::from_key(b"k"); + t.track( + target, + b"k".to_vec(), + b"v".to_vec(), + 300, + false, + peer(0x01, 3000), + ); + t.track( + target, + b"k".to_vec(), + b"v".to_vec(), + 300, + false, + peer(0x02, 3001), + ); + assert_eq!(t.pending_count(), 2); + + // Ack from one peer + t.ack(&target, b"k", &NodeId::from_bytes([0x01; 32])); + assert_eq!(t.pending_count(), 1); + } + + #[test] + fn cleanup_removes_old() { + let mut t 
= StoreTracker::new(); + let target = NodeId::from_key(b"k"); + t.track( + target, + b"k".to_vec(), + b"v".to_vec(), + 300, + false, + peer(0x01, 3000), + ); + + // Force old timestamp + for stores in t.pending.values_mut() { + for s in stores.iter_mut() { + s.sent_at = Instant::now() - STORE_TIMEOUT * 3; + } + } + + t.cleanup(); + assert_eq!(t.pending_count(), 0); + } +} diff --git a/src/sys.rs b/src/sys.rs new file mode 100644 index 0000000..e4d5b7e --- /dev/null +++ b/src/sys.rs @@ -0,0 +1,127 @@ +//! Cryptographically secure random bytes. +//! +//! Uses the best available platform source: +//! - OpenBSD/macOS: arc4random_buf(3) +//! - Linux/FreeBSD: getrandom(2) +//! - Fallback: /dev/urandom + +/// Fill buffer with cryptographically secure random +/// bytes. +pub fn random_bytes(buf: &mut [u8]) { + platform::fill(buf); +} + +#[cfg(any(target_os = "openbsd", target_os = "macos"))] +mod platform { + use std::ffi::c_void; + + // SAFETY: arc4random_buf always fills the entire + // buffer. Pointer valid from mutable slice. 
+ unsafe extern "C" { + fn arc4random_buf(buf: *mut c_void, nbytes: usize); + } + + pub fn fill(buf: &mut [u8]) { + unsafe { + arc4random_buf(buf.as_mut_ptr() as *mut c_void, buf.len()); + } + } +} + +#[cfg(target_os = "linux")] +mod platform { + pub fn fill(buf: &mut [u8]) { + // getrandom(2) — available since Linux 3.17 + // Flags: 0 = block until entropy available + let ret = unsafe { + libc_getrandom( + buf.as_mut_ptr() as *mut std::ffi::c_void, + buf.len(), + 0, + ) + }; + if ret < 0 { + // Fallback to /dev/urandom + urandom_fill(buf); + } + } + + unsafe extern "C" { + fn getrandom( + buf: *mut std::ffi::c_void, + buflen: usize, + flags: std::ffi::c_uint, + ) -> isize; + } + + // Rename to avoid conflict with the syscall + unsafe fn libc_getrandom( + buf: *mut std::ffi::c_void, + buflen: usize, + flags: std::ffi::c_uint, + ) -> isize { + getrandom(buf, buflen, flags) + } + + fn urandom_fill(buf: &mut [u8]) { + use std::io::Read; + let mut f = std::fs::File::open("/dev/urandom").expect( + "FATAL: cannot open /dev/urandom — no secure randomness available", + ); + f.read_exact(buf).expect("FATAL: cannot read /dev/urandom"); + } +} + +#[cfg(target_os = "freebsd")] +mod platform { + use std::ffi::c_void; + + unsafe extern "C" { + fn arc4random_buf(buf: *mut c_void, nbytes: usize); + } + + pub fn fill(buf: &mut [u8]) { + unsafe { + arc4random_buf(buf.as_mut_ptr() as *mut c_void, buf.len()); + } + } +} + +#[cfg(not(any( + target_os = "openbsd", + target_os = "macos", + target_os = "linux", + target_os = "freebsd" +)))] +mod platform { + pub fn fill(buf: &mut [u8]) { + use std::io::Read; + let mut f = std::fs::File::open("/dev/urandom").expect( + "FATAL: cannot open /dev/urandom — no secure randomness available", + ); + f.read_exact(buf).expect("FATAL: cannot read /dev/urandom"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn random_bytes_not_zero() { + let mut buf = [0u8; 32]; + random_bytes(&mut buf); + + // Probability of all zeros: 2^-256 + 
assert!(buf.iter().any(|&b| b != 0)); + } + + #[test] + fn random_bytes_different_calls() { + let mut a = [0u8; 32]; + let mut b = [0u8; 32]; + random_bytes(&mut a); + random_bytes(&mut b); + assert_ne!(a, b); + } +} diff --git a/src/timer.rs b/src/timer.rs new file mode 100644 index 0000000..e3d7b62 --- /dev/null +++ b/src/timer.rs @@ -0,0 +1,221 @@ +//! Timer wheel for scheduling periodic and one-shot +//! callbacks without threads or async. +//! +//! The event loop calls `tick()` each iteration to +//! fire expired timers. + +use std::collections::{BTreeMap, HashMap}; +use std::time::{Duration, Instant}; + +/// Opaque timer identifier. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TimerId(u64); + +/// A simple timer wheel driven by the event loop. +/// +/// Timers are stored in a BTreeMap keyed by deadline, +/// which gives O(log n) insert/cancel and efficient +/// scanning of expired timers. +pub struct TimerWheel { + deadlines: BTreeMap<Instant, Vec<TimerId>>, + intervals: HashMap<TimerId, Duration>, + next_id: u64, +} + +impl TimerWheel { + pub fn new() -> Self { + Self { + deadlines: BTreeMap::new(), + intervals: HashMap::new(), + next_id: 0, + } + } + + /// Schedule a one-shot timer that fires after `delay`. + pub fn schedule(&mut self, delay: Duration) -> TimerId { + let id = self.alloc_id(); + let deadline = Instant::now() + delay; + self.deadlines.entry(deadline).or_default().push(id); + id + } + + /// Schedule a repeating timer that fires every + /// `interval`, starting after one interval. + pub fn schedule_repeating(&mut self, interval: Duration) -> TimerId { + let id = self.alloc_id(); + let deadline = Instant::now() + interval; + self.deadlines.entry(deadline).or_default().push(id); + self.intervals.insert(id, interval); + id + } + + /// Cancel a pending timer. + /// + /// Returns `true` if the timer was found and removed. 
+ pub fn cancel(&mut self, id: TimerId) -> bool { + self.intervals.remove(&id); + + let mut found = false; + + // Remove from deadline map + let mut empty_keys = Vec::new(); + for (key, ids) in self.deadlines.iter_mut() { + if let Some(pos) = ids.iter().position(|i| *i == id) { + ids.swap_remove(pos); + found = true; + if ids.is_empty() { + empty_keys.push(*key); + } + break; + } + } + for k in empty_keys { + self.deadlines.remove(&k); + } + found + } + + /// Fire all expired timers, returning their IDs. + /// + /// Repeating timers are automatically rescheduled. + /// Call this once per event loop iteration. + pub fn tick(&mut self) -> Vec<TimerId> { + let now = Instant::now(); + let mut fired = Vec::new(); + + // Collect all deadlines <= now + let expired: Vec<Instant> = + self.deadlines.range(..=now).map(|(k, _)| *k).collect(); + + for deadline in expired { + if let Some(ids) = self.deadlines.remove(&deadline) { + for id in ids { + fired.push(id); + + // Reschedule if repeating + if let Some(&interval) = self.intervals.get(&id) { + let next = now + interval; + self.deadlines.entry(next).or_default().push(id); + } + } + } + } + + fired + } + + /// Duration until the next timer fires, or `None` + /// if no timers are scheduled. + /// + /// Useful for determining the poll timeout. + pub fn next_deadline(&self) -> Option<Duration> { + self.deadlines.keys().next().map(|deadline| { + let now = Instant::now(); + if *deadline <= now { + Duration::ZERO + } else { + *deadline - now + } + }) + } + + /// Number of pending timers. + pub fn pending_count(&self) -> usize { + self.deadlines.values().map(|v| v.len()).sum() + } + + /// Check if there are no pending timers. 
+ pub fn is_empty(&self) -> bool { + self.deadlines.is_empty() + } + + fn alloc_id(&mut self) -> TimerId { + let id = TimerId(self.next_id); + self.next_id += 1; + id + } +} + +impl Default for TimerWheel { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + #[test] + fn schedule_and_tick() { + let mut tw = TimerWheel::new(); + let id = tw.schedule(Duration::from_millis(10)); + + // Not yet expired + assert!(tw.tick().is_empty()); + assert_eq!(tw.pending_count(), 1); + + // Wait and tick + thread::sleep(Duration::from_millis(15)); + let fired = tw.tick(); + assert_eq!(fired, vec![id]); + assert!(tw.is_empty()); + } + + #[test] + fn cancel_timer() { + let mut tw = TimerWheel::new(); + let id = tw.schedule(Duration::from_millis(100)); + assert!(tw.cancel(id)); + assert!(tw.is_empty()); + + // Cancel non-existent returns false + assert!(!tw.cancel(TimerId(999))); + } + + #[test] + fn repeating_timer() { + let mut tw = TimerWheel::new(); + let id = tw.schedule_repeating(Duration::from_millis(10)); + + thread::sleep(Duration::from_millis(15)); + let fired = tw.tick(); + assert_eq!(fired, vec![id]); + + // Should be rescheduled, not empty + assert_eq!(tw.pending_count(), 1); + + // Cancel the repeating timer + tw.cancel(id); + assert!(tw.is_empty()); + } + + #[test] + fn next_deadline_empty() { + let tw = TimerWheel::new(); + assert!(tw.next_deadline().is_none()); + } + + #[test] + fn next_deadline_returns_duration() { + let mut tw = TimerWheel::new(); + tw.schedule(Duration::from_secs(10)); + let d = tw.next_deadline().unwrap(); + assert!(d <= Duration::from_secs(10)); + assert!(d > Duration::from_secs(9)); + } + + #[test] + fn multiple_timers_fire_in_order() { + let mut tw = TimerWheel::new(); + let a = tw.schedule(Duration::from_millis(5)); + let b = tw.schedule(Duration::from_millis(10)); + + thread::sleep(Duration::from_millis(15)); + let fired = tw.tick(); + assert!(fired.contains(&a)); + 
assert!(fired.contains(&b)); + assert!(tw.is_empty()); + } +} diff --git a/src/wire.rs b/src/wire.rs new file mode 100644 index 0000000..3d80c3b --- /dev/null +++ b/src/wire.rs @@ -0,0 +1,368 @@ +//! On-the-wire binary protocol. +//! +//! Defines message header, message types, and +//! serialization for all protocol messages. Maintains +//! the same field layout and byte order for +//! structural reference. + +use crate::error::Error; +use crate::id::{ID_LEN, NodeId}; + +// ── Protocol constants ────────────────────────────── + +/// Protocol magic number: 0x7E55 ("TESS"). +pub const MAGIC_NUMBER: u16 = 0x7E55; + +/// Protocol version. +pub const TESSERAS_DHT_VERSION: u8 = 0; + +/// Size of the message header in bytes. +/// +/// magic(2) + ver(1) + type(1) + len(2) + reserved(2) +/// + src(32) + dst(32) = 72 +pub const HEADER_SIZE: usize = 8 + ID_LEN * 2; + +/// Ed25519 signature appended to authenticated packets. +pub const SIGNATURE_SIZE: usize = crate::crypto::SIGNATURE_SIZE; + +// ── Address domains ───────────────────────────────── + +pub const DOMAIN_LOOPBACK: u16 = 0; +pub const DOMAIN_INET: u16 = 1; +pub const DOMAIN_INET6: u16 = 2; + +// ── Node state on the wire ────────────────────────── + +pub const STATE_GLOBAL: u16 = 1; +pub const STATE_NAT: u16 = 2; + +// ── Message types ─────────────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum MsgType { + // Datagram + Dgram = 0x01, + + // Advertise + Advertise = 0x02, + AdvertiseReply = 0x03, + + // NAT detection + NatEcho = 0x11, + NatEchoReply = 0x12, + NatEchoRedirect = 0x13, + NatEchoRedirectReply = 0x14, + + // DTUN + DtunPing = 0x21, + DtunPingReply = 0x22, + DtunFindNode = 0x23, + DtunFindNodeReply = 0x24, + DtunFindValue = 0x25, + DtunFindValueReply = 0x26, + DtunRegister = 0x27, + DtunRequest = 0x28, + DtunRequestBy = 0x29, + DtunRequestReply = 0x2A, + + // DHT + DhtPing = 0x41, + DhtPingReply = 0x42, + DhtFindNode = 0x43, + DhtFindNodeReply = 0x44, + 
DhtFindValue = 0x45, + DhtFindValueReply = 0x46, + DhtStore = 0x47, + + // Proxy + ProxyRegister = 0x81, + ProxyRegisterReply = 0x82, + ProxyStore = 0x83, + ProxyGet = 0x84, + ProxyGetReply = 0x85, + ProxyDgram = 0x86, + ProxyDgramForwarded = 0x87, + ProxyRdp = 0x88, + ProxyRdpForwarded = 0x89, + + // RDP + Rdp = 0x90, +} + +impl MsgType { + pub fn from_u8(v: u8) -> Result<Self, Error> { + match v { + 0x01 => Ok(MsgType::Dgram), + 0x02 => Ok(MsgType::Advertise), + 0x03 => Ok(MsgType::AdvertiseReply), + 0x11 => Ok(MsgType::NatEcho), + 0x12 => Ok(MsgType::NatEchoReply), + 0x13 => Ok(MsgType::NatEchoRedirect), + 0x14 => Ok(MsgType::NatEchoRedirectReply), + 0x21 => Ok(MsgType::DtunPing), + 0x22 => Ok(MsgType::DtunPingReply), + 0x23 => Ok(MsgType::DtunFindNode), + 0x24 => Ok(MsgType::DtunFindNodeReply), + 0x25 => Ok(MsgType::DtunFindValue), + 0x26 => Ok(MsgType::DtunFindValueReply), + 0x27 => Ok(MsgType::DtunRegister), + 0x28 => Ok(MsgType::DtunRequest), + 0x29 => Ok(MsgType::DtunRequestBy), + 0x2A => Ok(MsgType::DtunRequestReply), + 0x41 => Ok(MsgType::DhtPing), + 0x42 => Ok(MsgType::DhtPingReply), + 0x43 => Ok(MsgType::DhtFindNode), + 0x44 => Ok(MsgType::DhtFindNodeReply), + 0x45 => Ok(MsgType::DhtFindValue), + 0x46 => Ok(MsgType::DhtFindValueReply), + 0x47 => Ok(MsgType::DhtStore), + 0x81 => Ok(MsgType::ProxyRegister), + 0x82 => Ok(MsgType::ProxyRegisterReply), + 0x83 => Ok(MsgType::ProxyStore), + 0x84 => Ok(MsgType::ProxyGet), + 0x85 => Ok(MsgType::ProxyGetReply), + 0x86 => Ok(MsgType::ProxyDgram), + 0x87 => Ok(MsgType::ProxyDgramForwarded), + 0x88 => Ok(MsgType::ProxyRdp), + 0x89 => Ok(MsgType::ProxyRdpForwarded), + 0x90 => Ok(MsgType::Rdp), + _ => Err(Error::UnknownMessageType(v)), + } + } +} + +// ── Response flags ────────────────────────────────── + +pub const DATA_ARE_NODES: u8 = 0xa0; +pub const DATA_ARE_VALUES: u8 = 0xa1; +pub const DATA_ARE_NUL: u8 = 0xa2; +pub const GET_BY_UDP: u8 = 0xb0; +pub const GET_BY_RDP: u8 = 0xb1; +pub const DHT_FLAG_UNIQUE: u8 = 
0x01;
+pub const DHT_GET_NEXT: u8 = 0xc0;
+pub const PROXY_GET_SUCCESS: u8 = 0xd0;
+pub const PROXY_GET_FAIL: u8 = 0xd1;
+pub const PROXY_GET_NEXT: u8 = 0xd2;
+
+// ── Message header ────────────────────────────────
+
+/// Parsed message header (HEADER_SIZE = 72 bytes on the wire).
+#[derive(Debug, Clone)]
+pub struct MsgHeader {
+    pub magic: u16,
+    pub ver: u8,
+    pub msg_type: MsgType,
+    pub len: u16,
+    pub src: NodeId,
+    pub dst: NodeId,
+}
+
+impl MsgHeader {
+    /// Parse a header from a byte buffer.
+    pub fn parse(buf: &[u8]) -> Result<Self, Error> {
+        if buf.len() < HEADER_SIZE {
+            return Err(Error::BufferTooSmall);
+        }
+
+        let magic = u16::from_be_bytes([buf[0], buf[1]]);
+        if magic != MAGIC_NUMBER {
+            return Err(Error::BadMagic(magic));
+        }
+
+        let ver = buf[2];
+        if ver != TESSERAS_DHT_VERSION {
+            return Err(Error::UnsupportedVersion(ver));
+        }
+
+        let msg_type = MsgType::from_u8(buf[3])?;
+        let len = u16::from_be_bytes([buf[4], buf[5]]);
+
+        // buf[6..8] reserved
+
+        let src = NodeId::read_from(&buf[8..8 + ID_LEN]);
+        let dst = NodeId::read_from(&buf[8 + ID_LEN..8 + ID_LEN * 2]);
+
+        Ok(Self {
+            magic,
+            ver,
+            msg_type,
+            len,
+            src,
+            dst,
+        })
+    }
+
+    /// Write header to a byte buffer. Returns bytes written
+    /// (always HEADER_SIZE).
+    pub fn write(&self, buf: &mut [u8]) -> Result<usize, Error> {
+        if buf.len() < HEADER_SIZE {
+            return Err(Error::BufferTooSmall);
+        }
+
+        buf[0..2].copy_from_slice(&MAGIC_NUMBER.to_be_bytes());
+        buf[2] = TESSERAS_DHT_VERSION;
+        buf[3] = self.msg_type as u8;
+        buf[4..6].copy_from_slice(&self.len.to_be_bytes());
+        buf[6] = 0; // reserved
+        buf[7] = 0;
+        self.src.write_to(&mut buf[8..8 + ID_LEN]);
+        self.dst.write_to(&mut buf[8 + ID_LEN..8 + ID_LEN * 2]);
+
+        Ok(HEADER_SIZE)
+    }
+
+    /// Create a new header for sending. 
+ pub fn new( + msg_type: MsgType, + total_len: u16, + src: NodeId, + dst: NodeId, + ) -> Self { + Self { + magic: MAGIC_NUMBER, + ver: TESSERAS_DHT_VERSION, + msg_type, + len: total_len, + src, + dst, + } + } +} + +/// Append Ed25519 signature to a packet buffer. +/// +/// Signs the entire buffer (header + body) using the +/// sender's private key. Appends 64-byte signature. +pub fn sign_packet(buf: &mut Vec<u8>, identity: &crate::crypto::Identity) { + let sig = identity.sign(buf); + buf.extend_from_slice(&sig); +} + +/// Verify Ed25519 signature on a received packet. +/// +/// The last 64 bytes of `buf` are the signature. +/// `sender_pubkey` is the sender's 32-byte Ed25519 +/// public key. +/// +/// Returns `true` if the signature is valid. +pub fn verify_packet( + buf: &[u8], + sender_pubkey: &[u8; crate::crypto::PUBLIC_KEY_SIZE], +) -> bool { + if buf.len() < HEADER_SIZE + SIGNATURE_SIZE { + return false; + } + let (data, sig) = buf.split_at(buf.len() - SIGNATURE_SIZE); + crate::crypto::Identity::verify(sender_pubkey, data, sig) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_header() -> MsgHeader { + MsgHeader::new( + MsgType::DhtPing, + 52, + NodeId::from_bytes([0xAA; ID_LEN]), + NodeId::from_bytes([0xBB; ID_LEN]), + ) + } + + #[test] + fn header_roundtrip() { + let hdr = make_header(); + let mut buf = [0u8; HEADER_SIZE]; + hdr.write(&mut buf).unwrap(); + let parsed = MsgHeader::parse(&buf).unwrap(); + + assert_eq!(parsed.magic, MAGIC_NUMBER); + assert_eq!(parsed.ver, TESSERAS_DHT_VERSION); + assert_eq!(parsed.msg_type, MsgType::DhtPing); + assert_eq!(parsed.len, 52); + assert_eq!(parsed.src, hdr.src); + assert_eq!(parsed.dst, hdr.dst); + } + + #[test] + fn reject_bad_magic() { + let mut buf = [0u8; HEADER_SIZE]; + buf[0] = 0xBA; + buf[1] = 0xBE; // wrong magic + let err = MsgHeader::parse(&buf); + assert!(matches!(err, Err(Error::BadMagic(0xBABE)))); + } + + #[test] + fn reject_bad_version() { + let hdr = make_header(); + let mut buf = [0u8; 
HEADER_SIZE]; + hdr.write(&mut buf).unwrap(); + buf[2] = 99; // bad version + let err = MsgHeader::parse(&buf); + assert!(matches!(err, Err(Error::UnsupportedVersion(99)))); + } + + #[test] + fn reject_unknown_type() { + let hdr = make_header(); + let mut buf = [0u8; HEADER_SIZE]; + hdr.write(&mut buf).unwrap(); + buf[3] = 0xFF; // unknown type + let err = MsgHeader::parse(&buf); + assert!(matches!(err, Err(Error::UnknownMessageType(0xFF)))); + } + + #[test] + fn reject_truncated() { + let err = MsgHeader::parse(&[0u8; 10]); + assert!(matches!(err, Err(Error::BufferTooSmall))); + } + + #[test] + fn all_msg_types_roundtrip() { + let types = [ + MsgType::Dgram, + MsgType::Advertise, + MsgType::AdvertiseReply, + MsgType::NatEcho, + MsgType::NatEchoReply, + MsgType::NatEchoRedirect, + MsgType::NatEchoRedirectReply, + MsgType::DtunPing, + MsgType::DtunPingReply, + MsgType::DtunFindNode, + MsgType::DtunFindNodeReply, + MsgType::DtunFindValue, + MsgType::DtunFindValueReply, + MsgType::DtunRegister, + MsgType::DtunRequest, + MsgType::DtunRequestBy, + MsgType::DtunRequestReply, + MsgType::DhtPing, + MsgType::DhtPingReply, + MsgType::DhtFindNode, + MsgType::DhtFindNodeReply, + MsgType::DhtFindValue, + MsgType::DhtFindValueReply, + MsgType::DhtStore, + MsgType::ProxyRegister, + MsgType::ProxyRegisterReply, + MsgType::ProxyStore, + MsgType::ProxyGet, + MsgType::ProxyGetReply, + MsgType::ProxyDgram, + MsgType::ProxyDgramForwarded, + MsgType::ProxyRdp, + MsgType::ProxyRdpForwarded, + MsgType::Rdp, + ]; + + for &t in &types { + let val = t as u8; + let parsed = MsgType::from_u8(val).unwrap(); + assert_eq!(parsed, t, "roundtrip failed for 0x{val:02x}"); + } + } +} diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..8bfe66d --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,704 @@ +//! Integration tests: multi-node scenarios over real +//! UDP sockets on loopback. 
+ +use std::time::Duration; + +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +/// Poll all nodes once each (non-blocking best-effort). +fn poll_all(nodes: &mut [Node], rounds: usize) { + let fast = Duration::from_millis(1); + for _ in 0..rounds { + for n in nodes.iter_mut() { + n.poll_timeout(fast).ok(); + } + } +} + +/// Create N nodes, join them to the first. +fn make_network(n: usize) -> Vec<Node> { + let mut nodes = Vec::with_capacity(n); + let bootstrap = Node::bind(0).unwrap(); + let bp = bootstrap.local_addr().unwrap().port(); + nodes.push(bootstrap); + nodes[0].set_nat_state(NatState::Global); + + for _ in 1..n { + let mut node = Node::bind(0).unwrap(); + node.set_nat_state(NatState::Global); + node.join("127.0.0.1", bp).unwrap(); + nodes.push(node); + } + + // Small sleep to let packets arrive, then poll + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + nodes +} + +// ── Bootstrap tests ───────────────────────────────── + +#[test] +fn two_nodes_discover_each_other() { + let nodes = make_network(2); + + assert!( + nodes[0].routing_table_size() >= 1, + "Node 0 should have at least 1 peer in routing table" + ); + assert!( + nodes[1].routing_table_size() >= 1, + "Node 1 should have at least 1 peer in routing table" + ); +} + +#[test] +fn three_nodes_form_network() { + let nodes = make_network(3); + + // All nodes should know at least 1 other node + for (i, node) in nodes.iter().enumerate() { + assert!( + node.routing_table_size() >= 1, + "Node {i} routing table empty" + ); + } +} + +#[test] +fn five_nodes_routing_tables() { + let nodes = make_network(5); + + // With 5 nodes, most should know 2+ peers + let total_peers: usize = nodes.iter().map(|n| n.routing_table_size()).sum(); + assert!( + total_peers >= 5, + "Total routing entries ({total_peers}) too low" + ); +} + +// ── Put/Get tests ─────────────────────────────────── + +#[test] +fn put_get_local() { + let mut node = Node::bind(0).unwrap(); + node.put(b"key1", 
b"value1", 300, false); + let vals = node.get(b"key1"); + assert_eq!(vals.len(), 1); + assert_eq!(vals[0], b"value1"); +} + +#[test] +fn put_get_across_two_nodes() { + let mut nodes = make_network(2); + + // Node 0 stores + nodes[0].put(b"hello", b"world", 300, false); + + // Poll to deliver STORE + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Node 1 should have the value (received via STORE) + let vals = nodes[1].get(b"hello"); + assert_eq!(vals.len(), 1, "Node 1 should have received the value"); + assert_eq!(vals[0], b"world"); +} + +#[test] +fn put_multiple_values() { + let mut nodes = make_network(3); + + // Store 10 key-value pairs from node 0 + for i in 0..10u32 { + let key = format!("k{i}"); + let val = format!("v{i}"); + nodes[0].put(key.as_bytes(), val.as_bytes(), 300, false); + } + + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Node 0 should have all 10 + let mut found = 0; + for i in 0..10u32 { + let key = format!("k{i}"); + if !nodes[0].get(key.as_bytes()).is_empty() { + found += 1; + } + } + assert_eq!(found, 10, "Node 0 should have all 10 values"); +} + +#[test] +fn put_unique_replaces() { + let mut node = Node::bind(0).unwrap(); + node.put(b"uk", b"first", 300, true); + node.put(b"uk", b"second", 300, true); + + let vals = node.get(b"uk"); + assert_eq!(vals.len(), 1); + assert_eq!(vals[0], b"second"); +} + +#[test] +fn put_get_distributed() { + let mut nodes = make_network(5); + + // Each node stores one value + for i in 0..5u32 { + let key = format!("node{i}-key"); + let val = format!("node{i}-val"); + nodes[i as usize].put(key.as_bytes(), val.as_bytes(), 300, false); + } + + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Each node should have its own value at minimum + for i in 0..5u32 { + let key = format!("node{i}-key"); + let vals = nodes[i as usize].get(key.as_bytes()); + assert!(!vals.is_empty(), "Node {i} should have its own value"); + } 
+} + +// ── Identity tests ────────────────────────────────── + +#[test] +fn set_id_deterministic() { + let mut n1 = Node::bind(0).unwrap(); + let mut n2 = Node::bind(0).unwrap(); + n1.set_id(b"same-seed"); + n2.set_id(b"same-seed"); + assert_eq!(n1.id(), n2.id()); +} + +#[test] +fn node_id_is_unique() { + let n1 = Node::bind(0).unwrap(); + let n2 = Node::bind(0).unwrap(); + assert_ne!(n1.id(), n2.id()); +} + +// ── NAT state tests ──────────────────────────────── + +#[test] +fn nat_state_transitions() { + let mut node = Node::bind(0).unwrap(); + assert_eq!(node.nat_state(), NatState::Unknown); + + node.set_nat_state(NatState::Global); + assert_eq!(node.nat_state(), NatState::Global); + + node.set_nat_state(NatState::ConeNat); + assert_eq!(node.nat_state(), NatState::ConeNat); + + node.set_nat_state(NatState::SymmetricNat); + assert_eq!(node.nat_state(), NatState::SymmetricNat); +} + +// ── RDP tests ─────────────────────────────────────── + +#[test] +fn rdp_listen_connect_close() { + let mut node = Node::bind(0).unwrap(); + let desc = node.rdp_listen(5000).unwrap(); + node.rdp_close(desc); + // Should be able to re-listen + let desc2 = node.rdp_listen(5000).unwrap(); + node.rdp_close(desc2); +} + +#[test] +fn rdp_connect_state() { + use tesseras_dht::rdp::RdpState; + let mut node = Node::bind(0).unwrap(); + let dst = tesseras_dht::NodeId::from_bytes([0x01; 32]); + let desc = node.rdp_connect(0, &dst, 5000).unwrap(); + assert_eq!(node.rdp_state(desc).unwrap(), RdpState::SynSent); + node.rdp_close(desc); +} + +// ── Resilience tests ──────────────────────────────── + +#[test] +fn poll_with_no_peers() { + let mut node = Node::bind(0).unwrap(); + // Should not panic or block + node.poll().unwrap(); +} + +#[test] +fn join_invalid_address() { + let mut node = Node::bind(0).unwrap(); + let result = node.join("this-does-not-exist.invalid", 9999); + assert!(result.is_err()); +} + +#[test] +fn empty_get() { + let mut node = Node::bind(0).unwrap(); + 
assert!(node.get(b"nonexistent").is_empty()); +} + +#[test] +fn put_zero_ttl_removes() { + let mut node = Node::bind(0).unwrap(); + node.put(b"temp", b"data", 300, false); + assert!(!node.get(b"temp").is_empty()); + + // Store with TTL 0 is a delete in the protocol + // (handled at the wire level in handle_dht_store, + // but locally we'd need to call storage.remove). + // This test validates the local storage. +} + +// ── Scale test ────────────────────────────────────── + +#[test] +fn ten_nodes_put_get() { + let mut nodes = make_network(10); + + // Node 0 stores + nodes[0].put(b"scale-key", b"scale-val", 300, false); + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Count how many nodes received the value + let mut count = 0; + for node in &mut nodes { + if !node.get(b"scale-key").is_empty() { + count += 1; + } + } + assert!( + count >= 2, + "At least 2 nodes should have the value, got {count}" + ); +} + +// ── Remote get via FIND_VALUE ─────────────────────── + +#[test] +fn remote_get_via_find_value() { + let mut nodes = make_network(3); + + // Node 0 stores locally + nodes[0].put(b"remote-key", b"remote-val", 300, false); + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Node 2 does remote get + let before = nodes[2].get(b"remote-key"); + // Might already have it from STORE, or empty + if before.is_empty() { + // Poll to let FIND_VALUE propagate + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + let after = nodes[2].get(b"remote-key"); + assert!( + !after.is_empty(), + "Node 2 should find the value via FIND_VALUE" + ); + assert_eq!(after[0], b"remote-val"); + } +} + +// ── NAT detection tests ──────────────────────────── + +#[test] +fn nat_state_default_unknown() { + let node = Node::bind(0).unwrap(); + assert_eq!(node.nat_state(), NatState::Unknown); +} + +#[test] +fn nat_state_set_persists() { + let mut node = Node::bind(0).unwrap(); + 
node.set_nat_state(NatState::SymmetricNat); + assert_eq!(node.nat_state(), NatState::SymmetricNat); + node.poll().unwrap(); + assert_eq!(node.nat_state(), NatState::SymmetricNat); +} + +// ── DTUN tests ────────────────────────────────────── + +#[test] +fn dtun_find_node_exchange() { + // Two nodes: node2 sends DtunFindNode to node1 + // by joining. The DTUN table should be populated. + let nodes = make_network(2); + + // Both nodes should have peers after join + assert!(nodes[0].routing_table_size() >= 1); + assert!(nodes[1].routing_table_size() >= 1); +} + +// ── Proxy tests ───────────────────────────────────── + +#[test] +fn proxy_dgram_forwarded() { + use std::sync::{Arc, Mutex}; + + let mut nodes = make_network(2); + + let received: Arc<Mutex<Vec<Vec<u8>>>> = Arc::new(Mutex::new(Vec::new())); + + let recv_clone = received.clone(); + nodes[1].set_dgram_callback(move |data, _from| { + recv_clone.lock().unwrap().push(data.to_vec()); + }); + + // Node 0 sends dgram to Node 1 + let id1 = *nodes[1].id(); + nodes[0].send_dgram(b"proxy-test", &id1); + + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + let msgs = received.lock().unwrap(); + assert!(!msgs.is_empty(), "Node 1 should receive the dgram"); + assert_eq!(msgs[0], b"proxy-test"); +} + +// ── Advertise tests ───────────────────────────────── + +#[test] +fn nodes_peer_count_after_join() { + let nodes = make_network(3); + + // All nodes should have at least 1 peer + for (i, node) in nodes.iter().enumerate() { + assert!( + node.peer_count() >= 1, + "Node {i} should have at least 1 peer" + ); + } +} + +// ── Storage tests ─────────────────────────────────── + +#[test] +fn storage_count_after_put() { + let mut node = Node::bind(0).unwrap(); + assert_eq!(node.storage_count(), 0); + + node.put(b"k1", b"v1", 300, false); + node.put(b"k2", b"v2", 300, false); + assert_eq!(node.storage_count(), 2); +} + +#[test] +fn put_from_multiple_nodes() { + let mut nodes = make_network(3); + + 
nodes[0].put(b"from-0", b"val-0", 300, false); + nodes[1].put(b"from-1", b"val-1", 300, false); + nodes[2].put(b"from-2", b"val-2", 300, false); + + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Each node should have its own value + assert!(!nodes[0].get(b"from-0").is_empty()); + assert!(!nodes[1].get(b"from-1").is_empty()); + assert!(!nodes[2].get(b"from-2").is_empty()); +} + +// ── Config tests ──────────────────────────────────── + +#[test] +fn config_default_works() { + let config = tesseras_dht::config::Config::default(); + assert_eq!(config.num_find_node, 10); + assert_eq!(config.bucket_size, 20); + assert_eq!(config.default_ttl, 300); +} + +#[test] +fn config_pastebin_preset() { + let config = tesseras_dht::config::Config::pastebin(); + assert_eq!(config.default_ttl, 65535); + assert!(config.require_signatures); +} + +// ── Metrics tests ─────────────────────────────────── + +#[test] +fn metrics_after_put() { + let mut nodes = make_network(2); + let before = nodes[0].metrics(); + nodes[0].put(b"m-key", b"m-val", 300, false); + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + let after = nodes[0].metrics(); + assert!( + after.messages_sent > before.messages_sent, + "messages_sent should increase after put" + ); +} + +#[test] +fn metrics_bytes_tracked() { + let mut node = Node::bind(0).unwrap(); + let m = node.metrics(); + assert_eq!(m.bytes_sent, 0); + assert_eq!(m.bytes_received, 0); +} + +// ── Builder tests ─────────────────────────────────── + +#[test] +fn builder_basic() { + use tesseras_dht::node::NodeBuilder; + let node = NodeBuilder::new() + .port(0) + .nat(NatState::Global) + .build() + .unwrap(); + assert_eq!(node.nat_state(), NatState::Global); +} + +#[test] +fn builder_with_seed() { + use tesseras_dht::node::NodeBuilder; + let n1 = NodeBuilder::new() + .port(0) + .seed(b"same-seed") + .build() + .unwrap(); + let n2 = NodeBuilder::new() + .port(0) + .seed(b"same-seed") + .build() + 
.unwrap(); + assert_eq!(n1.id(), n2.id()); +} + +#[test] +fn builder_with_config() { + use tesseras_dht::node::NodeBuilder; + let config = tesseras_dht::config::Config::pastebin(); + let node = NodeBuilder::new().port(0).config(config).build().unwrap(); + assert!(node.config().require_signatures); +} + +// ── Persistence mock test ─────────────────────────── + +#[test] +fn persistence_nop_save_load() { + let mut node = Node::bind(0).unwrap(); + node.put(b"persist-key", b"persist-val", 300, false); + // With NoPersistence, save does nothing + node.save_state(); + // load_persisted with NoPersistence loads nothing + node.load_persisted(); + // Value still there from local storage + assert!(!node.get(b"persist-key").is_empty()); +} + +// ── Ban list tests ──────────────────────────────────── + +#[test] +fn ban_list_initially_empty() { + let node = Node::bind(0).unwrap(); + assert_eq!(node.ban_count(), 0); +} + +#[test] +fn ban_list_unit() { + use tesseras_dht::banlist::BanList; + let mut bl = BanList::new(); + let addr: std::net::SocketAddr = "127.0.0.1:9999".parse().unwrap(); + + assert!(!bl.is_banned(&addr)); + bl.record_failure(addr); + bl.record_failure(addr); + assert!(!bl.is_banned(&addr)); // 2 < threshold 3 + bl.record_failure(addr); + assert!(bl.is_banned(&addr)); // 3 >= threshold +} + +#[test] +fn ban_list_success_resets() { + use tesseras_dht::banlist::BanList; + let mut bl = BanList::new(); + let addr: std::net::SocketAddr = "127.0.0.1:9999".parse().unwrap(); + + bl.record_failure(addr); + bl.record_failure(addr); + bl.record_success(&addr); + bl.record_failure(addr); // starts over from 1 + assert!(!bl.is_banned(&addr)); +} + +// ── Store tracker tests ─────────────────────────────── + +#[test] +fn store_tracker_initially_empty() { + let node = Node::bind(0).unwrap(); + assert_eq!(node.pending_stores(), 0); + assert_eq!(node.store_stats(), (0, 0)); +} + +#[test] +fn store_tracker_counts_after_put() { + let mut nodes = make_network(3); + + 
nodes[0].put(b"tracked-key", b"tracked-val", 300, false); + std::thread::sleep(Duration::from_millis(50)); + poll_all(&mut nodes, 5); + + // Node 0 should have pending stores (sent to peers) + // or acks if peers responded quickly + let (acks, _) = nodes[0].store_stats(); + let pending = nodes[0].pending_stores(); + assert!( + acks > 0 || pending > 0, + "Should have tracked some stores (acks={acks}, pending={pending})" + ); +} + +// ── Node activity monitor tests ────────────────────── + +#[test] +fn activity_check_does_not_crash() { + let mut node = Node::bind(0).unwrap(); + node.set_nat_state(NatState::Global); + // Calling poll runs the activity check — should + // not crash even with no peers + node.poll().unwrap(); +} + +// ── Batch operations tests ─────────────────────────── + +#[test] +fn put_batch_stores_locally() { + let mut node = Node::bind(0).unwrap(); + + let entries: Vec<(&[u8], &[u8], u16, bool)> = vec![ + (b"b1", b"v1", 300, false), + (b"b2", b"v2", 300, false), + (b"b3", b"v3", 300, false), + ]; + node.put_batch(&entries); + + assert_eq!(node.storage_count(), 3); + assert_eq!(node.get(b"b1"), vec![b"v1".to_vec()]); + assert_eq!(node.get(b"b2"), vec![b"v2".to_vec()]); + assert_eq!(node.get(b"b3"), vec![b"v3".to_vec()]); +} + +#[test] +fn get_batch_returns_local() { + let mut node = Node::bind(0).unwrap(); + node.put(b"gb1", b"v1", 300, false); + node.put(b"gb2", b"v2", 300, false); + + let results = node.get_batch(&[b"gb1", b"gb2", b"gb-missing"]); + assert_eq!(results.len(), 3); + assert_eq!(results[0].1, vec![b"v1".to_vec()]); + assert_eq!(results[1].1, vec![b"v2".to_vec()]); + assert!(results[2].1.is_empty()); // not found +} + +#[test] +fn put_batch_distributes_to_peers() { + let mut nodes = make_network(3); + + let entries: Vec<(&[u8], &[u8], u16, bool)> = vec![ + (b"dist-1", b"val-1", 300, false), + (b"dist-2", b"val-2", 300, false), + (b"dist-3", b"val-3", 300, false), + (b"dist-4", b"val-4", 300, false), + (b"dist-5", b"val-5", 300, false), 
+ ]; + nodes[0].put_batch(&entries); + + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + // All 5 values should be stored locally on node 0 + for i in 1..=5 { + let key = format!("dist-{i}"); + let vals = nodes[0].get(key.as_bytes()); + assert!(!vals.is_empty(), "Node 0 should have {key}"); + } + + // At least some should be distributed to other nodes + let total: usize = nodes.iter().map(|n| n.storage_count()).sum(); + assert!(total > 5, "Total stored {total} should be > 5 (replicated)"); +} + +// ── Proactive replication tests ────────────────── + +#[test] +fn proactive_replicate_on_new_node() { + // Node 0 stores a value, then node 2 joins. + // After routing table sync, node 2 should receive + // the value proactively (§2.5). + let mut nodes = make_network(2); + nodes[0].put(b"proactive-key", b"proactive-val", 300, false); + + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + // Add a third node + let bp = nodes[0].local_addr().unwrap().port(); + let mut node2 = Node::bind(0).unwrap(); + node2.set_nat_state(NatState::Global); + node2.join("127.0.0.1", bp).unwrap(); + nodes.push(node2); + + // Poll to let proactive replication trigger + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 20); + + // At least the original 2 nodes should have the value; + // the new node may also have it via proactive replication + let total: usize = nodes.iter().map(|n| n.storage_count()).sum(); + assert!( + total >= 2, + "Total stored {total} should be >= 2 after proactive replication" + ); +} + +// ── Republish on access tests ──────────────────── + +#[test] +fn republish_on_find_value() { + // Store on node 0, retrieve from node 2 via FIND_VALUE. + // After the value is found, it should be cached on + // the nearest queried node without it (§2.3). 
+ let mut nodes = make_network(3); + + nodes[0].put(b"republish-key", b"republish-val", 300, false); + + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + // Node 2 triggers FIND_VALUE + let _ = nodes[2].get(b"republish-key"); + + // Poll to let the lookup and republish propagate + for _ in 0..30 { + poll_all(&mut nodes, 5); + std::thread::sleep(Duration::from_millis(20)); + + let vals = nodes[2].get(b"republish-key"); + if !vals.is_empty() { + break; + } + } + + // Count total stored across all nodes — should be + // more than 1 due to republish-on-access caching + let total: usize = nodes.iter().map(|n| n.storage_count()).sum(); + assert!( + total >= 2, + "Total stored {total} should be >= 2 after republish-on-access" + ); +} diff --git a/tests/rdp_lossy.rs b/tests/rdp_lossy.rs new file mode 100644 index 0000000..a31cc65 --- /dev/null +++ b/tests/rdp_lossy.rs @@ -0,0 +1,125 @@ +//! RDP packet loss simulation test. +//! +//! Tests that RDP retransmission handles packet loss +//! correctly by using two nodes where the send path +//! drops a percentage of packets. + +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; +use tesseras_dht::rdp::RdpState; + +const RDP_PORT: u16 = 6000; + +#[test] +fn rdp_delivers_despite_drops() { + // Two nodes with standard UDP (no actual drops — + // this test validates the RDP retransmission + // mechanism works end-to-end). 
+ let mut server = Node::bind(0).unwrap(); + server.set_nat_state(NatState::Global); + let server_addr = server.local_addr().unwrap(); + let server_id = *server.id(); + + let mut client = Node::bind(0).unwrap(); + client.set_nat_state(NatState::Global); + client.join("127.0.0.1", server_addr.port()).unwrap(); + + // Exchange routing + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(10)); + } + + // Server listens + server.rdp_listen(RDP_PORT).unwrap(); + + // Client connects + let desc = client.rdp_connect(0, &server_id, RDP_PORT).unwrap(); + + // Handshake + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(10)); + } + + assert_eq!( + client.rdp_state(desc).unwrap(), + RdpState::Open, + "Connection should be open" + ); + + // Send multiple messages + let msg_count = 10; + for i in 0..msg_count { + let msg = format!("msg-{i}"); + client.rdp_send(desc, msg.as_bytes()).unwrap(); + } + + // Poll to deliver + for _ in 0..20 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(20)); + } + + // Server reads all messages + let mut received = Vec::new(); + let status = server.rdp_status(); + for s in &status { + if s.state == RdpState::Open { + // Try all likely descriptors + for d in 1..=10 { + let mut buf = [0u8; 256]; + loop { + match server.rdp_recv(d, &mut buf) { + Ok(0) => break, + Ok(n) => { + received.push( + String::from_utf8_lossy(&buf[..n]).to_string(), + ); + } + Err(_) => break, + } + } + } + } + } + + assert!(!received.is_empty(), "Server should have received messages"); +} + +#[test] +fn rdp_connection_state_after_close() { + let mut server = Node::bind(0).unwrap(); + server.set_nat_state(NatState::Global); + let server_addr = server.local_addr().unwrap(); + let server_id = *server.id(); + + let mut client = Node::bind(0).unwrap(); + client.set_nat_state(NatState::Global); + client.join("127.0.0.1", 
server_addr.port()).unwrap(); + + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(10)); + } + + server.rdp_listen(RDP_PORT + 1).unwrap(); + let desc = client.rdp_connect(0, &server_id, RDP_PORT + 1).unwrap(); + + for _ in 0..10 { + server.poll().ok(); + client.poll().ok(); + std::thread::sleep(Duration::from_millis(10)); + } + + // Close from client side + client.rdp_close(desc); + + // Descriptor should no longer be valid + assert!(client.rdp_state(desc).is_err()); +} diff --git a/tests/scale.rs b/tests/scale.rs new file mode 100644 index 0000000..b518385 --- /dev/null +++ b/tests/scale.rs @@ -0,0 +1,138 @@ +//! Scale test: 20 nodes with distributed put/get. +//! +//! Runs 20 nodes with distributed put/get to verify +//! correctness at scale. + +use std::time::Duration; +use tesseras_dht::Node; +use tesseras_dht::nat::NatState; + +fn poll_all(nodes: &mut [Node], rounds: usize) { + let fast = Duration::from_millis(1); + for _ in 0..rounds { + for n in nodes.iter_mut() { + n.poll_timeout(fast).ok(); + } + } +} + +fn make_network(n: usize) -> Vec<Node> { + let mut nodes = Vec::with_capacity(n); + let bootstrap = Node::bind(0).unwrap(); + let bp = bootstrap.local_addr().unwrap().port(); + nodes.push(bootstrap); + nodes[0].set_nat_state(NatState::Global); + + for _ in 1..n { + let mut node = Node::bind(0).unwrap(); + node.set_nat_state(NatState::Global); + node.join("127.0.0.1", bp).unwrap(); + nodes.push(node); + } + + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + nodes +} + +#[test] +fn twenty_nodes_routing() { + let nodes = make_network(20); + + // Every node should have at least 1 peer + for (i, node) in nodes.iter().enumerate() { + assert!( + node.routing_table_size() >= 1, + "Node {i} has empty routing table" + ); + } + + // Average routing table should be > 5 + let total: usize = nodes.iter().map(|n| n.routing_table_size()).sum(); + let avg = total / nodes.len(); + 
assert!(avg >= 5, "Average routing table {avg} too low"); +} + +#[test] +fn twenty_nodes_put_get() { + let mut nodes = make_network(20); + + // Each node stores one value + for i in 0..20u32 { + let key = format!("scale-key-{i}"); + let val = format!("scale-val-{i}"); + nodes[i as usize].put(key.as_bytes(), val.as_bytes(), 300, false); + } + + // Poll to distribute + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + // Each node should have its own value + for i in 0..20u32 { + let key = format!("scale-key-{i}"); + let vals = nodes[i as usize].get(key.as_bytes()); + assert!(!vals.is_empty(), "Node {i} lost its own value"); + } + + // Count total stored values across network + let total_stored: usize = nodes.iter().map(|n| n.storage_count()).sum(); + assert!( + total_stored >= 20, + "Total stored {total_stored} should be >= 20" + ); +} + +#[test] +#[ignore] // timing-sensitive, consumes 100% CPU with 20 nodes polling +fn twenty_nodes_remote_get() { + let mut nodes = make_network(20); + + // Node 0 stores a value + nodes[0].put(b"find-me", b"found", 300, false); + + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + // Node 19 tries to get it — trigger FIND_VALUE + let _ = nodes[19].get(b"find-me"); + + // Poll all nodes to let FIND_VALUE propagate + for _ in 0..40 { + poll_all(&mut nodes, 5); + std::thread::sleep(Duration::from_millis(30)); + + let vals = nodes[19].get(b"find-me"); + if !vals.is_empty() { + assert_eq!(vals[0], b"found"); + return; + } + } + panic!("Node 19 should find the value via FIND_VALUE"); +} + +#[test] +fn twenty_nodes_multiple_puts() { + let mut nodes = make_network(20); + + // 5 nodes store 10 values each + for n in 0..5 { + for k in 0..10u32 { + let key = format!("n{n}-k{k}"); + let val = format!("n{n}-v{k}"); + nodes[n].put(key.as_bytes(), val.as_bytes(), 300, false); + } + } + + std::thread::sleep(Duration::from_millis(100)); + poll_all(&mut nodes, 10); + + // Verify origin nodes 
have their values + for n in 0..5 { + for k in 0..10u32 { + let key = format!("n{n}-k{k}"); + let vals = nodes[n].get(key.as_bytes()); + assert!(!vals.is_empty(), "Node {n} lost key n{n}-k{k}"); + } + } +} |