diff --git a/network/src/backend.rs b/network/src/backend.rs index 38a49f4..5367696 100644 --- a/network/src/backend.rs +++ b/network/src/backend.rs @@ -24,6 +24,7 @@ pub struct NetworkBackend { enum NetworkStackSelect<'a> { Receive(&'a [u8]), Send(Option>), + Reclaim, } struct NetworkStack<'a> { @@ -40,6 +41,7 @@ impl NetworkStack<'_> { let what = select! { x = self.tx.recv() => NetworkStackSelect::Send(x), x = self.kdev.read(receive_buffer) => NetworkStackSelect::Receive(&receive_buffer[0..x?]), + _ = self.router.process_reclaim() => NetworkStackSelect::Reclaim, }; match what { @@ -59,6 +61,8 @@ impl NetworkStack<'_> { self.interface .poll(timestamp, &mut self.udev, &mut self.sockets); } + + NetworkStackSelect::Reclaim => {} } Ok(()) diff --git a/network/src/nat.rs b/network/src/nat.rs index b3185a8..0bcdd88 100644 --- a/network/src/nat.rs +++ b/network/src/nat.rs @@ -1,23 +1,24 @@ -// Referenced https://github.com/vi/wgslirpy/blob/master/crates/libwgslirpy/src/router.rs as a very interesting way to implement NAT. -// hypha will heavily change how the original code functions however. NatKey was a very useful example of what we need to store in a NAT map. - use anyhow::Result; use async_trait::async_trait; use etherparse::Ethernet2Slice; use etherparse::IpNumber; use etherparse::IpPayloadSlice; use etherparse::Ipv4Slice; +use etherparse::Ipv6Slice; use etherparse::LinkSlice; use etherparse::NetSlice; use etherparse::SlicedPacket; use etherparse::TcpHeaderSlice; use etherparse::UdpHeaderSlice; +use log::debug; use smoltcp::wire::EthernetAddress; use smoltcp::wire::IpAddress; use smoltcp::wire::IpEndpoint; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::fmt::Display; +use tokio::sync::mpsc::channel; +use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] @@ -52,7 +53,12 @@ pub trait NatHandler: Send { #[async_trait] pub trait NatHandlerFactory: Send { - async fn nat(&self, key: NatKey, sender: Sender>) -> Option>; + async fn nat( + &self, + key: NatKey, + tx_sender: Sender>, + reclaim_sender: Sender, + ) -> Option>; } pub struct NatTable { @@ -72,6 +78,8 @@ pub struct NatRouter { factory: Box, table: NatTable, tx_sender: Sender>, + reclaim_sender: Sender, + reclaim_receiver: Receiver, } impl NatRouter { @@ -80,14 +88,27 @@ impl NatRouter { mac: EthernetAddress, tx_sender: Sender>, ) -> Self { + let (reclaim_sender, reclaim_receiver) = channel(4); Self { _local_mac: mac, factory, table: NatTable::new(), tx_sender, + reclaim_sender, + reclaim_receiver, } } + pub async fn process_reclaim(&mut self) -> Result> { + Ok(if let Some(key) = self.reclaim_receiver.recv().await { + self.table.inner.remove(&key); + debug!("reclaimed nat key: {}", key); + Some(key) + } else { + None + }) + } + pub async fn process(&mut self, data: &[u8]) -> Result<()> { let packet = SlicedPacket::from_ethernet(data)?; let Some(ref link) = packet.link else { @@ -105,13 +126,8 @@ impl NatRouter { }; match net { - NetSlice::Ipv4(ipv4) => { - self.process_ipv4(data, ether, ipv4).await?; - } - - _ => { - return Ok(()); - } + NetSlice::Ipv4(ipv4) => self.process_ipv4(data, ether, ipv4).await?, + NetSlice::Ipv6(ipv6) => self.process_ipv6(data, ether, ipv6).await?, } Ok(()) @@ -142,6 +158,31 @@ impl NatRouter { Ok(()) } + pub async fn process_ipv6<'a>( + &mut self, + data: &[u8], + ether: &Ethernet2Slice<'a>, + ipv6: &Ipv6Slice<'a>, + ) -> Result<()> { + let source_addr = IpAddress::Ipv6(ipv6.header().source_addr().into()); + let dest_addr = IpAddress::Ipv6(ipv6.header().destination_addr().into()); + match ipv6.header().next_header() { + IpNumber::TCP => { + self.process_tcp(data, ether, source_addr, dest_addr, ipv6.payload()) + .await?; + } + + IpNumber::UDP => { + self.process_udp(data, ether, source_addr, dest_addr, ipv6.payload()) + .await?; + } + + _ => {} + } + + Ok(()) + } + pub async fn process_tcp<'a>( &mut self, data: &'a [u8], @@ -190,7 +231,11 @@ impl NatRouter { let handler: Option<&mut Box> = match self.table.inner.entry(key) { Entry::Occupied(entry) => Some(entry.into_mut()), Entry::Vacant(entry) => { - if let Some(handler) = self.factory.nat(key, self.tx_sender.clone()).await { + if let Some(handler) = self + .factory + .nat(key, self.tx_sender.clone(), self.reclaim_sender.clone()) + .await + { Some(entry.insert(handler)) } else { None @@ -201,7 +246,6 @@ impl NatRouter { if let Some(handler) = handler { handler.receive(data).await?; } - Ok(()) } } diff --git a/network/src/proxynat/mod.rs b/network/src/proxynat/mod.rs index 2cddef0..4f5fd26 100644 --- a/network/src/proxynat/mod.rs +++ b/network/src/proxynat/mod.rs @@ -1,5 +1,3 @@ -mod udp; - use async_trait::async_trait; use log::{debug, warn}; @@ -11,6 +9,8 @@ use crate::proxynat::udp::ProxyUdpHandler; use crate::nat::{NatHandler, NatHandlerFactory, NatKey, NatKeyProtocol}; +mod udp; + pub struct ProxyNatHandlerFactory {} impl ProxyNatHandlerFactory { @@ -21,7 +21,12 @@ impl ProxyNatHandlerFactory { #[async_trait] impl NatHandlerFactory for ProxyNatHandlerFactory { - async fn nat(&self, key: NatKey, sender: Sender>) -> Option> { + async fn nat( + &self, + key: NatKey, + tx_sender: Sender>, + reclaim_sender: Sender, + ) -> Option> { debug!("creating proxy nat entry for key: {}", key); match key.protocol { @@ -29,7 +34,7 @@ impl NatHandlerFactory for ProxyNatHandlerFactory { let (rx_sender, rx_receiver) = channel::>(4); let mut handler = ProxyUdpHandler::new(key, rx_sender); - if let Err(error) = handler.spawn(rx_receiver, sender.clone()).await { + if let Err(error) = handler.spawn(rx_receiver, tx_sender, reclaim_sender).await { warn!("unable to spawn udp proxy handler: {}", error); None } else { @@ -45,5 +50,5 @@ impl NatHandlerFactory for ProxyNatHandlerFactory { pub enum ProxyNatSelect { External(usize), Internal(Vec), - Closed, + Close, } diff --git a/network/src/proxynat/udp.rs b/network/src/proxynat/udp.rs index a82a335..a590de9 100644 --- a/network/src/proxynat/udp.rs +++ b/network/src/proxynat/udp.rs @@ -1,13 +1,13 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::{ + net::{IpAddr, SocketAddr}, + time::Duration, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use etherparse::{PacketBuilder, SlicedPacket, UdpSlice}; use log::{debug, warn}; -use smoltcp::{ - phy::{Checksum, ChecksumCapabilities}, - wire::IpAddress, -}; +use smoltcp::wire::IpAddress; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, select, @@ -19,6 +19,8 @@ use crate::nat::{NatHandler, NatKey}; use super::ProxyNatSelect; +const UDP_TIMEOUT_SECS: u64 = 60; + pub struct ProxyUdpHandler { key: NatKey, rx_sender: Sender>, @@ -41,19 +43,22 @@ impl ProxyUdpHandler { &mut self, rx_receiver: Receiver>, tx_sender: Sender>, + reclaim_sender: Sender, ) -> Result<()> { let external_addr = match self.key.external_ip.addr { - IpAddress::Ipv4(addr) => SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(addr.0[0], addr.0[1], addr.0[2], addr.0[3])), - self.key.external_ip.port, - ), - IpAddress::Ipv6(_) => return Err(anyhow!("IPv6 unsupported")), + IpAddress::Ipv4(addr) => { + SocketAddr::new(IpAddr::V4(addr.0.into()), self.key.external_ip.port) + } + IpAddress::Ipv6(addr) => { + SocketAddr::new(IpAddr::V6(addr.0.into()), self.key.external_ip.port) + } }; let socket = UdpStream::connect(external_addr).await?; let key = self.key; tokio::spawn(async move { - if let Err(error) = ProxyUdpHandler::process(key, socket, rx_receiver, tx_sender).await + if let Err(error) = + ProxyUdpHandler::process(key, socket, rx_receiver, tx_sender, reclaim_sender).await { warn!("processing of udp proxy failed: {}", error); } @@ -66,22 +71,20 @@ impl ProxyUdpHandler { mut socket: UdpStream, mut rx_receiver: Receiver>, tx_sender: Sender>, + reclaim_sender: Sender, ) -> Result<()> { - let mut checksum = ChecksumCapabilities::ignored(); - checksum.udp = Checksum::Tx; - checksum.ipv4 = Checksum::Tx; - checksum.tcp = Checksum::Tx; - let mut external_buffer = vec![0u8; 2048]; loop { + let deadline = tokio::time::sleep(Duration::from_secs(UDP_TIMEOUT_SECS)); let selection = select! { x = rx_receiver.recv() => if let Some(data) = x { ProxyNatSelect::Internal(data) } else { - ProxyNatSelect::Closed + ProxyNatSelect::Close }, x = socket.read(&mut external_buffer) => ProxyNatSelect::External(x?), + _ = deadline => ProxyNatSelect::Close, }; match selection { @@ -119,8 +122,14 @@ impl ProxyUdpHandler { let udp = UdpSlice::from_slice(ip.payload)?; socket.write_all(udp.payload()).await?; } - ProxyNatSelect::Closed => warn!("UDP socket closed"), + ProxyNatSelect::Close => { + drop(socket); + reclaim_sender.send(key).await?; + break; + } } } + + Ok(()) } }