From 21707daa98005a9b8ad78079133da450464bfa67 Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Tue, 13 Feb 2024 18:01:52 +0000 Subject: [PATCH] network: split nat into separate mods --- network/src/backend.rs | 33 ++-- network/src/nat/handler.rs | 36 ++++ network/src/nat/key.rs | 29 ++++ network/src/nat/mod.rs | 42 +++++ network/src/{nat.rs => nat/processor.rs} | 203 +++++++++++------------ network/src/nat/table.rs | 21 +++ network/src/proxynat/icmp.rs | 6 +- network/src/proxynat/mod.rs | 4 +- network/src/proxynat/tcp.rs | 8 +- network/src/proxynat/udp.rs | 4 +- network/src/raw_socket.rs | 10 +- 11 files changed, 250 insertions(+), 146 deletions(-) create mode 100644 network/src/nat/handler.rs create mode 100644 network/src/nat/key.rs create mode 100644 network/src/nat/mod.rs rename network/src/{nat.rs => nat/processor.rs} (78%) create mode 100644 network/src/nat/table.rs diff --git a/network/src/backend.rs b/network/src/backend.rs index 9fc6589..c97e44c 100644 --- a/network/src/backend.rs +++ b/network/src/backend.rs @@ -1,15 +1,13 @@ use crate::autonet::NetworkMetadata; use crate::chandev::ChannelDevice; -use crate::nat::NatRouter; -use crate::pkt::RecvPacket; +use crate::nat::Nat; use crate::proxynat::ProxyNatHandlerFactory; use crate::raw_socket::{AsyncRawSocketChannel, RawSocketHandle, RawSocketProtocol}; use crate::vbridge::{BridgeJoinHandle, VirtualBridge}; use anyhow::{anyhow, Result}; use bytes::BytesMut; -use etherparse::SlicedPacket; use futures::TryStreamExt; -use log::{debug, info, trace, warn}; +use log::{info, trace, warn}; use smoltcp::iface::{Config, Interface, SocketSet}; use smoltcp::phy::Medium; use smoltcp::time::Instant; @@ -30,7 +28,6 @@ pub struct NetworkBackend { enum NetworkStackSelect { Receive(Option), Send(Option), - Reclaim, } struct NetworkStack<'a> { @@ -39,7 +36,7 @@ struct NetworkStack<'a> { udev: ChannelDevice, interface: Interface, sockets: SocketSet<'a>, - router: NatRouter, + nat: Nat, bridge: BridgeJoinHandle, } @@ -50,7 +47,6 @@ impl NetworkStack<'_> { x = self.bridge.from_bridge_receiver.recv() => NetworkStackSelect::Send(x), x = self.bridge.from_broadcast_receiver.recv() => NetworkStackSelect::Send(x.ok()), x = self.tx.recv() => NetworkStackSelect::Send(x), - _ = self.router.process_reclaim() => NetworkStackSelect::Reclaim, }; match what { @@ -59,16 +55,13 @@ impl NetworkStack<'_> { trace!("failed to send guest packet to bridge: {}", error); } - if let Ok(slice) = SlicedPacket::from_ethernet(&packet) { - let packet = RecvPacket::new(&packet, &slice)?; - if let Err(error) = self.router.process(&packet).await { - debug!("router failed to process packet: {}", error); - } - - self.udev.rx = Some(packet.raw.into()); - self.interface - .poll(Instant::now(), &mut self.udev, &mut self.sockets); + if let Err(error) = self.nat.receive_sender.try_send(packet.clone()) { + trace!("failed to send guest packet to nat: {}", error); } + + self.udev.rx = Some(packet); + self.interface + .poll(Instant::now(), &mut self.udev, &mut self.sockets); } NetworkStackSelect::Send(Some(packet)) => { @@ -80,8 +73,6 @@ impl NetworkStack<'_> { NetworkStackSelect::Receive(None) | NetworkStackSelect::Send(None) => { return Ok(false); } - - NetworkStackSelect::Reclaim => {} } Ok(true) @@ -134,7 +125,7 @@ impl NetworkBackend { let (tx_sender, tx_receiver) = channel::(TX_CHANNEL_BUFFER_LEN); let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone()); let mac = self.metadata.gateway.mac; - let nat = NatRouter::new(mtu, proxy, mac, addresses.clone(), tx_sender.clone()); + let nat = Nat::new(mtu, proxy, mac, addresses.clone(), tx_sender.clone())?; let hardware_addr = HardwareAddress::Ethernet(mac); let config = Config::new(hardware_addr); let mut iface = Interface::new(config, &mut udev, Instant::now()); @@ -145,14 +136,14 @@ impl NetworkBackend { }); let sockets = SocketSet::new(vec![]); let handle = self.bridge.join(self.metadata.guest.mac).await?; - let kdev = AsyncRawSocketChannel::new(kdev)?; + let kdev = AsyncRawSocketChannel::new(mtu, kdev)?; Ok(NetworkStack { tx: tx_receiver, kdev, udev, interface: iface, sockets, - router: nat, + nat, bridge: handle, }) } diff --git a/network/src/nat/handler.rs b/network/src/nat/handler.rs new file mode 100644 index 0000000..e1943e8 --- /dev/null +++ b/network/src/nat/handler.rs @@ -0,0 +1,36 @@ +use anyhow::Result; +use async_trait::async_trait; +use bytes::BytesMut; +use tokio::sync::mpsc::Sender; + +use super::key::NatKey; + +#[derive(Debug, Clone)] +pub struct NatHandlerContext { + pub mtu: usize, + pub key: NatKey, + pub transmit_sender: Sender, + pub reclaim_sender: Sender, +} + +impl NatHandlerContext { + pub fn try_transmit(&self, buffer: BytesMut) -> Result<()> { + self.transmit_sender.try_send(buffer)?; + Ok(()) + } + + pub async fn reclaim(&self) -> Result<()> { + self.reclaim_sender.try_send(self.key)?; + Ok(()) + } +} + +#[async_trait] +pub trait NatHandler: Send { + async fn receive(&self, packet: &[u8]) -> Result; +} + +#[async_trait] +pub trait NatHandlerFactory: Send { + async fn nat(&self, context: NatHandlerContext) -> Option>; +} diff --git a/network/src/nat/key.rs b/network/src/nat/key.rs new file mode 100644 index 0000000..ac1447d --- /dev/null +++ b/network/src/nat/key.rs @@ -0,0 +1,29 @@ +use std::fmt::Display; + +use smoltcp::wire::{EthernetAddress, IpEndpoint}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub enum NatKeyProtocol { + Tcp, + Udp, + Icmp, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct NatKey { + pub protocol: NatKeyProtocol, + pub client_mac: EthernetAddress, + pub local_mac: EthernetAddress, + pub client_ip: IpEndpoint, + pub external_ip: IpEndpoint, +} + +impl Display for NatKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} -> {} {:?} {} -> {}", + self.client_mac, self.local_mac, self.protocol, self.client_ip, self.external_ip + ) + } +} diff --git a/network/src/nat/mod.rs b/network/src/nat/mod.rs new file mode 100644 index 0000000..ca0a404 --- /dev/null +++ b/network/src/nat/mod.rs @@ -0,0 +1,42 @@ +use anyhow::Result; +use tokio::sync::mpsc::Sender; + +use self::handler::NatHandlerFactory; +use self::processor::NatProcessor; +use bytes::BytesMut; +use smoltcp::wire::EthernetAddress; +use smoltcp::wire::IpCidr; +use tokio::task::JoinHandle; + +pub mod handler; +pub mod key; +pub mod processor; +pub mod table; + +pub struct Nat { + pub receive_sender: Sender, + task: JoinHandle<()>, +} + +impl Nat { + pub fn new( + mtu: usize, + factory: Box, + local_mac: EthernetAddress, + local_cidrs: Vec, + transmit_sender: Sender, + ) -> Result { + let (receive_sender, task) = + NatProcessor::launch(mtu, factory, local_mac, local_cidrs, transmit_sender)?; + Ok(Self { + receive_sender, + task, + }) + } +} + +impl Drop for Nat { + fn drop(&mut self) { + self.task.abort(); + } +} diff --git a/network/src/nat.rs b/network/src/nat/processor.rs similarity index 78% rename from network/src/nat.rs rename to network/src/nat/processor.rs index cb3e485..1f348e9 100644 --- a/network/src/nat.rs +++ b/network/src/nat/processor.rs @@ -1,7 +1,6 @@ use crate::pkt::RecvPacket; use crate::pkt::RecvPacketIp; use anyhow::Result; -use async_trait::async_trait; use bytes::BytesMut; use etherparse::Icmpv4Header; use etherparse::Icmpv4Type; @@ -11,126 +10,110 @@ use etherparse::IpNumber; use etherparse::IpPayloadSlice; use etherparse::Ipv4Slice; use etherparse::Ipv6Slice; +use etherparse::SlicedPacket; use etherparse::TcpHeaderSlice; use etherparse::UdpHeaderSlice; +use log::warn; use log::{debug, trace}; use smoltcp::wire::EthernetAddress; use smoltcp::wire::IpAddress; use smoltcp::wire::IpCidr; use smoltcp::wire::IpEndpoint; use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::fmt::Display; +use tokio::select; use tokio::sync::mpsc::channel; use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Sender; +use tokio::task::JoinHandle; + +use super::handler::NatHandler; +use super::handler::NatHandlerContext; +use super::handler::NatHandlerFactory; +use super::key::NatKey; +use super::key::NatKeyProtocol; +use super::table::NatTable; const RECLAIM_CHANNEL_QUEUE_LEN: usize = 10; +const RECEIVE_CHANNEL_QUEUE_LEN: usize = 30; -#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -pub enum NatKeyProtocol { - Tcp, - Udp, - Icmp, -} - -#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -pub struct NatKey { - pub protocol: NatKeyProtocol, - pub client_mac: EthernetAddress, - pub local_mac: EthernetAddress, - pub client_ip: IpEndpoint, - pub external_ip: IpEndpoint, -} - -impl Display for NatKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{} -> {} {:?} {} -> {}", - self.client_mac, self.local_mac, self.protocol, self.client_ip, self.external_ip - ) - } -} - -#[derive(Debug, Clone)] -pub struct NatHandlerContext { - pub mtu: usize, - pub key: NatKey, - tx_sender: Sender, - reclaim_sender: Sender, -} - -impl NatHandlerContext { - pub fn try_send(&self, buffer: BytesMut) -> Result<()> { - self.tx_sender.try_send(buffer)?; - Ok(()) - } - - pub async fn reclaim(&self) -> Result<()> { - self.reclaim_sender.try_send(self.key)?; - Ok(()) - } -} - -#[async_trait] -pub trait NatHandler: Send { - async fn receive(&self, packet: &[u8]) -> Result; -} - -#[async_trait] -pub trait NatHandlerFactory: Send { - async fn nat(&self, context: NatHandlerContext) -> Option>; -} - -pub struct NatTable { - inner: HashMap>, -} - -impl Default for NatTable { - fn default() -> Self { - Self::new() - } -} - -impl NatTable { - pub fn new() -> Self { - Self { - inner: HashMap::new(), - } - } -} - -pub struct NatRouter { +pub struct NatProcessor { mtu: usize, local_mac: EthernetAddress, local_cidrs: Vec, - factory: Box, table: NatTable, - tx_sender: Sender, + factory: Box, + transmit_sender: Sender, reclaim_sender: Sender, reclaim_receiver: Receiver, + receive_receiver: Receiver, } -impl NatRouter { - pub fn new( +enum NatProcessorSelect { + Reclaim(Option), + ReceivedPacket(Option), +} + +impl NatProcessor { + pub fn launch( mtu: usize, factory: Box, local_mac: EthernetAddress, local_cidrs: Vec, - tx_sender: Sender, - ) -> Self { + transmit_sender: Sender, + ) -> Result<(Sender, JoinHandle<()>)> { let (reclaim_sender, reclaim_receiver) = channel(RECLAIM_CHANNEL_QUEUE_LEN); - Self { + let (receive_sender, receive_receiver) = channel(RECEIVE_CHANNEL_QUEUE_LEN); + let mut processor = Self { mtu, local_mac, local_cidrs, factory, table: NatTable::new(), - tx_sender, + transmit_sender, reclaim_sender, + receive_receiver, reclaim_receiver, + }; + + let handle = tokio::task::spawn(async move { + if let Err(error) = processor.process().await { + warn!("nat processing failed: {}", error); + } + }); + + Ok((receive_sender, handle)) + } + + pub async fn process(&mut self) -> Result<()> { + loop { + let selection = select! { + x = self.reclaim_receiver.recv() => NatProcessorSelect::Reclaim(x), + x = self.receive_receiver.recv() => NatProcessorSelect::ReceivedPacket(x), + }; + + match selection { + NatProcessorSelect::Reclaim(Some(key)) => { + if self.table.inner.remove(&key).is_some() { + debug!("reclaimed nat key: {}", key); + } + } + + NatProcessorSelect::ReceivedPacket(Some(packet)) => { + if let Ok(slice) = SlicedPacket::from_ethernet(&packet) { + let Ok(packet) = RecvPacket::new(&packet, &slice) else { + continue; + }; + + self.process_packet(&packet).await?; + } + } + + NatProcessorSelect::ReceivedPacket(None) | NatProcessorSelect::Reclaim(None) => { + break + } + } } + Ok(()) } pub async fn process_reclaim(&mut self) -> Result> { @@ -146,7 +129,7 @@ impl NatRouter { }) } - pub async fn process<'a>(&mut self, packet: &RecvPacket<'a>) -> Result<()> { + pub async fn process_packet<'a>(&mut self, packet: &RecvPacket<'a>) -> Result<()> { let Some(ether) = packet.ether else { return Ok(()); }; @@ -180,7 +163,7 @@ impl NatRouter { let context = NatHandlerContext { mtu: self.mtu, key, - tx_sender: self.tx_sender.clone(), + transmit_sender: self.transmit_sender.clone(), reclaim_sender: self.reclaim_sender.clone(), }; let handler: Option<&mut Box> = match self.table.inner.entry(key) { @@ -251,28 +234,6 @@ impl NatRouter { }) } - pub fn extract_key_tcp<'a>( - &mut self, - packet: &RecvPacket<'a>, - source_addr: IpAddress, - dest_addr: IpAddress, - payload: &IpPayloadSlice<'a>, - ) -> Result> { - let Some(ether) = packet.ether else { - return Ok(None); - }; - let header = TcpHeaderSlice::from_slice(payload.payload)?; - let source = IpEndpoint::new(source_addr, header.source_port()); - let dest = IpEndpoint::new(dest_addr, header.destination_port()); - Ok(Some(NatKey { - protocol: NatKeyProtocol::Tcp, - client_mac: EthernetAddress(ether.source()), - local_mac: EthernetAddress(ether.destination()), - client_ip: source, - external_ip: dest, - })) - } - pub fn extract_key_udp<'a>( &mut self, packet: &RecvPacket<'a>, @@ -344,4 +305,26 @@ impl NatRouter { external_ip: dest, })) } + + pub fn extract_key_tcp<'a>( + &mut self, + packet: &RecvPacket<'a>, + source_addr: IpAddress, + dest_addr: IpAddress, + payload: &IpPayloadSlice<'a>, + ) -> Result> { + let Some(ether) = packet.ether else { + return Ok(None); + }; + let header = TcpHeaderSlice::from_slice(payload.payload)?; + let source = IpEndpoint::new(source_addr, header.source_port()); + let dest = IpEndpoint::new(dest_addr, header.destination_port()); + Ok(Some(NatKey { + protocol: NatKeyProtocol::Tcp, + client_mac: EthernetAddress(ether.source()), + local_mac: EthernetAddress(ether.destination()), + client_ip: source, + external_ip: dest, + })) + } } diff --git a/network/src/nat/table.rs b/network/src/nat/table.rs new file mode 100644 index 0000000..79e4a5a --- /dev/null +++ b/network/src/nat/table.rs @@ -0,0 +1,21 @@ +use std::collections::HashMap; + +use super::{handler::NatHandler, key::NatKey}; + +pub struct NatTable { + pub inner: HashMap>, +} + +impl Default for NatTable { + fn default() -> Self { + Self::new() + } +} + +impl NatTable { + pub fn new() -> Self { + Self { + inner: HashMap::new(), + } + } +} diff --git a/network/src/proxynat/icmp.rs b/network/src/proxynat/icmp.rs index 2254aac..e73653f 100644 --- a/network/src/proxynat/icmp.rs +++ b/network/src/proxynat/icmp.rs @@ -19,7 +19,7 @@ use tokio::{ use crate::{ icmp::{IcmpClient, IcmpProtocol, IcmpReply}, - nat::{NatHandler, NatHandlerContext}, + nat::handler::{NatHandler, NatHandlerContext}, }; const ICMP_PING_TIMEOUT_SECS: u64 = 20; @@ -223,7 +223,7 @@ impl ProxyIcmpHandler { let mut writer = buffer.writer(); packet.write(&mut writer, &payload)?; let buffer = writer.into_inner(); - if let Err(error) = context.try_send(buffer) { + if let Err(error) = context.try_transmit(buffer) { debug!("failed to transmit icmp packet: {}", error); } Ok(()) @@ -268,7 +268,7 @@ impl ProxyIcmpHandler { let mut writer = buffer.writer(); packet.write(&mut writer, &payload)?; let buffer = writer.into_inner(); - if let Err(error) = context.try_send(buffer) { + if let Err(error) = context.try_transmit(buffer) { debug!("failed to transmit icmp packet: {}", error); } Ok(()) diff --git a/network/src/proxynat/mod.rs b/network/src/proxynat/mod.rs index 40784f4..590206b 100644 --- a/network/src/proxynat/mod.rs +++ b/network/src/proxynat/mod.rs @@ -5,10 +5,10 @@ use log::warn; use tokio::sync::mpsc::channel; -use crate::nat::NatHandlerContext; use crate::proxynat::udp::ProxyUdpHandler; -use crate::nat::{NatHandler, NatHandlerFactory, NatKeyProtocol}; +use crate::nat::handler::{NatHandler, NatHandlerContext, NatHandlerFactory}; +use crate::nat::key::NatKeyProtocol; use self::icmp::ProxyIcmpHandler; use self::tcp::ProxyTcpHandler; diff --git a/network/src/proxynat/tcp.rs b/network/src/proxynat/tcp.rs index 4fa487a..d7516ca 100644 --- a/network/src/proxynat/tcp.rs +++ b/network/src/proxynat/tcp.rs @@ -25,7 +25,7 @@ use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender}; use crate::{ chandev::ChannelDevice, - nat::{NatHandler, NatHandlerContext}, + nat::handler::{NatHandler, NatHandlerContext}, }; const TCP_BUFFER_SIZE: usize = 65535; @@ -216,7 +216,7 @@ impl ProxyTcpHandler { }; buffer.extend_from_slice(&header.to_bytes()); buffer.extend_from_slice(&payload); - if let Err(error) = context.try_send(buffer) { + if let Err(error) = context.try_transmit(buffer) { debug!("failed to transmit tcp packet: {}", error); } } @@ -389,7 +389,7 @@ impl ProxyTcpHandler { }; buffer.extend_from_slice(&header.to_bytes()); buffer.extend_from_slice(&payload); - if let Err(error) = context.try_send(buffer) { + if let Err(error) = context.try_transmit(buffer) { debug!("failed to transmit tcp packet: {}", error); } } @@ -449,7 +449,7 @@ impl ProxyTcpHandler { }; buffer.extend_from_slice(&header.to_bytes()); buffer.extend_from_slice(&payload); - if let Err(error) = context.try_send(buffer) { + if let Err(error) = context.try_transmit(buffer) { debug!("failed to transmit tcp packet: {}", error); } } diff --git a/network/src/proxynat/udp.rs b/network/src/proxynat/udp.rs index d902c25..9d7a6fc 100644 --- a/network/src/proxynat/udp.rs +++ b/network/src/proxynat/udp.rs @@ -16,7 +16,7 @@ use tokio::{ use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender}; use udp_stream::UdpStream; -use crate::nat::{NatHandler, NatHandlerContext}; +use crate::nat::handler::{NatHandler, NatHandlerContext}; const UDP_TIMEOUT_SECS: u64 = 60; @@ -111,7 +111,7 @@ impl ProxyUdpHandler { let mut writer = buffer.writer(); packet.write(&mut writer, data)?; let buffer = writer.into_inner(); - if let Err(error) = context.try_send(buffer) { + if let Err(error) = context.try_transmit(buffer) { debug!("failed to transmit udp packet: {}", error); } } diff --git a/network/src/raw_socket.rs b/network/src/raw_socket.rs index 9677ebf..83805ae 100644 --- a/network/src/raw_socket.rs +++ b/network/src/raw_socket.rs @@ -204,10 +204,10 @@ enum AsyncRawSocketChannelSelect { } impl AsyncRawSocketChannel { - pub fn new(socket: RawSocketHandle) -> Result { + pub fn new(mtu: usize, socket: RawSocketHandle) -> Result { let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN); let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN); - let task = AsyncRawSocketChannel::launch(socket, transmit_receiver, receive_sender)?; + let task = AsyncRawSocketChannel::launch(mtu, socket, transmit_receiver, receive_sender)?; Ok(AsyncRawSocketChannel { sender: transmit_sender, receiver: receive_receiver, @@ -216,13 +216,14 @@ impl AsyncRawSocketChannel { } fn launch( + mtu: usize, socket: RawSocketHandle, transmit_receiver: Receiver, receive_sender: Sender, ) -> Result> { Ok(tokio::task::spawn(async move { if let Err(error) = - AsyncRawSocketChannel::process(socket, transmit_receiver, receive_sender).await + AsyncRawSocketChannel::process(mtu, socket, transmit_receiver, receive_sender).await { warn!("failed to process raw socket: {}", error); } @@ -230,6 +231,7 @@ impl AsyncRawSocketChannel { } async fn process( + mtu: usize, socket: RawSocketHandle, mut transmit_receiver: Receiver, receive_sender: Sender, @@ -237,6 +239,7 @@ impl AsyncRawSocketChannel { let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) }; let socket = UdpSocket::from_std(socket)?; + let mut buffer = vec![0; mtu]; loop { let selection = select! { x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x), @@ -245,7 +248,6 @@ impl AsyncRawSocketChannel { match selection { AsyncRawSocketChannelSelect::Readable(_) => { - let mut buffer = vec![0; 1500]; match socket.try_recv(&mut buffer) { Ok(len) => { if len == 0 {