diff --git a/network/bin/network.rs b/network/bin/network.rs index 460b5df..77ec2bd 100644 --- a/network/bin/network.rs +++ b/network/bin/network.rs @@ -1,3 +1,6 @@ +use std::str::FromStr; + +use advmac::MacAddr6; use anyhow::Result; use clap::Parser; use env_logger::Env; @@ -10,13 +13,23 @@ struct NetworkArgs { #[arg(long, default_value = "fe80::1/10")] ipv6_network: String, + + #[arg(long)] + force_mac_address: Option, } #[tokio::main] async fn main() -> Result<()> { env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init(); let args = NetworkArgs::parse(); - let mut service = NetworkService::new(args.ipv4_network, args.ipv6_network)?; + + let force_mac_address = if let Some(mac_str) = args.force_mac_address { + Some(MacAddr6::from_str(&mac_str)?) + } else { + None + }; + + let mut service = NetworkService::new(args.ipv4_network, args.ipv6_network, force_mac_address)?; service.watch().await?; Ok(()) } diff --git a/network/src/backend.rs b/network/src/backend.rs index 0c297ea..b99a29f 100644 --- a/network/src/backend.rs +++ b/network/src/backend.rs @@ -7,6 +7,7 @@ use anyhow::{anyhow, Result}; use futures::TryStreamExt; use log::warn; use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::Medium; use smoltcp::time::Instant; use smoltcp::wire::{HardwareAddress, IpCidr}; use std::str::FromStr; @@ -19,6 +20,7 @@ use tokio::sync::mpsc::{channel, Receiver}; pub struct NetworkBackend { ipv4: String, ipv6: String, + force_mac_address: Option, interface: String, } @@ -72,10 +74,16 @@ impl NetworkStack<'_> { } impl NetworkBackend { - pub fn new(ipv4: &str, ipv6: &str, interface: &str) -> Result { + pub fn new( + ipv4: &str, + ipv6: &str, + force_mac_address: &Option, + interface: &str, + ) -> Result { Ok(Self { ipv4: ipv4.to_string(), ipv6: ipv6.to_string(), + force_mac_address: *force_mac_address, interface: interface.to_string(), }) } @@ -121,10 +129,10 @@ impl NetworkBackend { AsyncRawSocket::bound_to_interface(&self.interface, RawSocketProtocol::Ethernet)?; let mtu = kdev.mtu_of_interface(&self.interface)?; let (tx_sender, tx_receiver) = channel::>(4); - let mut udev = ChannelDevice::new(mtu, tx_sender.clone()); - let mac = MacAddr6::random(); + let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone()); + let mac = self.force_mac_address.unwrap_or_else(MacAddr6::random); let mac = smoltcp::wire::EthernetAddress(mac.to_array()); - let nat = NatRouter::new(proxy, mac, addresses.clone(), tx_sender.clone()); + let nat = NatRouter::new(mtu, proxy, mac, addresses.clone(), tx_sender.clone()); let mac = HardwareAddress::Ethernet(mac); let config = Config::new(mac); let mut iface = Interface::new(config, &mut udev, Instant::now()); diff --git a/network/src/chandev.rs b/network/src/chandev.rs index 2a2cc49..17bd4f1 100644 --- a/network/src/chandev.rs +++ b/network/src/chandev.rs @@ -1,17 +1,23 @@ // Referenced https://github.com/vi/wgslirpy/blob/master/crates/libwgslirpy/src/channelized_smoltcp_device.rs -use log::warn; -use smoltcp::phy::{Checksum, Device}; +use log::{debug, warn}; +use smoltcp::phy::{Checksum, Device, Medium}; use tokio::sync::mpsc::Sender; pub struct ChannelDevice { pub mtu: usize, + pub medium: Medium, pub tx: Sender>, pub rx: Option>, } impl ChannelDevice { - pub fn new(mtu: usize, tx: Sender>) -> Self { - Self { mtu, tx, rx: None } + pub fn new(mtu: usize, medium: Medium, tx: Sender>) -> Self { + Self { + mtu, + medium, + tx, + rx: None, + } } } @@ -30,7 +36,7 @@ impl Device for ChannelDevice { fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { if self.tx.capacity() == 0 { - warn!("ran out of transmission capacity"); + debug!("ran out of transmission capacity"); return None; } Some(self) @@ -38,7 +44,7 @@ impl Device for ChannelDevice { fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities { let mut capabilities = smoltcp::phy::DeviceCapabilities::default(); - capabilities.medium = smoltcp::phy::Medium::Ethernet; + capabilities.medium = self.medium; capabilities.max_transmission_unit = self.mtu; capabilities.checksum = smoltcp::phy::ChecksumCapabilities::ignored(); capabilities.checksum.tcp = Checksum::Tx; diff --git a/network/src/lib.rs b/network/src/lib.rs index 9f82b46..e30caa3 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -1,3 +1,4 @@ +use advmac::MacAddr6; use anyhow::Result; use futures::TryStreamExt; use log::{error, info, warn}; @@ -18,11 +19,20 @@ pub mod raw_socket; pub struct NetworkService { pub ipv4: String, pub ipv6: String, + pub force_mac_address: Option, } impl NetworkService { - pub fn new(ipv4: String, ipv6: String) -> Result { - Ok(NetworkService { ipv4, ipv6 }) + pub fn new( + ipv4: String, + ipv6: String, + force_mac_address: Option, + ) -> Result { + Ok(NetworkService { + ipv4, + ipv6, + force_mac_address, + }) } } @@ -78,7 +88,8 @@ impl NetworkService { spawned: Arc>>, ) -> Result<()> { let interface = interface.to_string(); - let mut network = NetworkBackend::new(&self.ipv4, &self.ipv6, &interface)?; + let mut network = + NetworkBackend::new(&self.ipv4, &self.ipv6, &self.force_mac_address, &interface)?; info!("initializing network backend for interface {}", interface); network.init().await?; tokio::time::sleep(Duration::from_secs(1)).await; diff --git a/network/src/nat.rs b/network/src/nat.rs index 6e7003b..4ad3fe2 100644 --- a/network/src/nat.rs +++ b/network/src/nat.rs @@ -14,7 +14,6 @@ use etherparse::NetSlice; use etherparse::SlicedPacket; use etherparse::TcpHeaderSlice; use etherparse::UdpHeaderSlice; -use log::warn; use log::{debug, trace}; use smoltcp::wire::EthernetAddress; use smoltcp::wire::IpAddress; @@ -53,7 +52,9 @@ impl Display for NatKey { } } +#[derive(Debug)] pub struct NatHandlerContext { + pub mtu: usize, pub key: NatKey, tx_sender: Sender>, reclaim_sender: Sender, @@ -64,13 +65,10 @@ impl NatHandlerContext { self.tx_sender.try_send(buffer)?; Ok(()) } -} -impl Drop for NatHandlerContext { - fn drop(&mut self) { - if let Err(error) = self.reclaim_sender.try_send(self.key) { - warn!("failed to reclaim nat key: {}", error); - } + pub async fn reclaim(&self) -> Result<()> { + self.reclaim_sender.try_send(self.key)?; + Ok(()) } } @@ -103,6 +101,7 @@ impl NatTable { } pub struct NatRouter { + mtu: usize, local_mac: EthernetAddress, local_cidrs: Vec, factory: Box, @@ -114,6 +113,7 @@ pub struct NatRouter { impl NatRouter { pub fn new( + mtu: usize, factory: Box, local_mac: EthernetAddress, local_cidrs: Vec, @@ -121,6 +121,7 @@ impl NatRouter { ) -> Self { let (reclaim_sender, reclaim_receiver) = channel(4); Self { + mtu, local_mac, local_cidrs, factory, @@ -335,6 +336,7 @@ impl NatRouter { } let context = NatHandlerContext { + mtu: self.mtu, key, tx_sender: self.tx_sender.clone(), reclaim_sender: self.reclaim_sender.clone(), diff --git a/network/src/proxynat/icmp.rs b/network/src/proxynat/icmp.rs index 7400c65..d234a8e 100644 --- a/network/src/proxynat/icmp.rs +++ b/network/src/proxynat/icmp.rs @@ -205,6 +205,8 @@ impl ProxyIcmpHandler { } } + context.reclaim().await?; + Ok(()) } } diff --git a/network/src/proxynat/mod.rs b/network/src/proxynat/mod.rs index 2f38ee7..d59d281 100644 --- a/network/src/proxynat/mod.rs +++ b/network/src/proxynat/mod.rs @@ -10,8 +10,10 @@ use crate::proxynat::udp::ProxyUdpHandler; use crate::nat::{NatHandler, NatHandlerFactory, NatKeyProtocol}; use self::icmp::ProxyIcmpHandler; +use self::tcp::ProxyTcpHandler; mod icmp; +mod tcp; mod udp; pub struct ProxyNatHandlerFactory {} @@ -56,7 +58,17 @@ impl NatHandlerFactory for ProxyNatHandlerFactory { } } - _ => None, + NatKeyProtocol::Tcp => { + let (rx_sender, rx_receiver) = channel::>(4); + let mut handler = ProxyTcpHandler::new(rx_sender); + + if let Err(error) = handler.spawn(context, rx_receiver).await { + warn!("unable to spawn tcp proxy handler: {}", error); + None + } else { + Some(Box::new(handler)) + } + } } } } diff --git a/network/src/proxynat/tcp.rs b/network/src/proxynat/tcp.rs new file mode 100644 index 0000000..4154a26 --- /dev/null +++ b/network/src/proxynat/tcp.rs @@ -0,0 +1,396 @@ +use std::{ + net::{IpAddr, SocketAddr}, + time::Duration, +}; + +use anyhow::Result; +use async_trait::async_trait; +use etherparse::{EtherType, Ethernet2Header}; +use log::{debug, warn}; +use smoltcp::{ + iface::{Config, Interface, SocketSet, SocketStorage}, + phy::Medium, + socket::tcp::{self, SocketBuffer, State}, + time::Instant, + wire::{HardwareAddress, IpAddress, IpCidr}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + select, + sync::mpsc::channel, +}; +use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender}; + +use crate::{ + chandev::ChannelDevice, + nat::{NatHandler, NatHandlerContext}, +}; + +const TCP_BUFFER_SIZE: usize = 65535; +const TCP_ACCEPT_TIMEOUT_SECS: u64 = 120; + +pub struct ProxyTcpHandler { + rx_sender: Sender>, +} + +#[async_trait] +impl NatHandler for ProxyTcpHandler { + async fn receive(&self, data: &[u8]) -> Result<()> { + self.rx_sender.try_send(data.to_vec())?; + Ok(()) + } +} + +#[derive(Debug)] +enum ProxyTcpAcceptSelect { + Internal(Vec), + TxIpPacket(Vec), + TimePassed, + DoNothing, + Close, +} + +#[derive(Debug)] +enum ProxyTcpDataSelect { + ExternalRecv(usize), + ExternalSent(usize), + InternalRecv(Vec), + TxIpPacket(Vec), + TimePassed, + DoNothing, + Close, +} + +impl ProxyTcpHandler { + pub fn new(rx_sender: Sender>) -> Self { + ProxyTcpHandler { rx_sender } + } + + pub async fn spawn( + &mut self, + context: NatHandlerContext, + rx_receiver: Receiver>, + ) -> Result<()> { + let external_addr = match context.key.external_ip.addr { + IpAddress::Ipv4(addr) => { + SocketAddr::new(IpAddr::V4(addr.0.into()), context.key.external_ip.port) + } + IpAddress::Ipv6(addr) => { + SocketAddr::new(IpAddr::V6(addr.0.into()), context.key.external_ip.port) + } + }; + + let socket = TcpStream::connect(external_addr).await?; + tokio::spawn(async move { + if let Err(error) = ProxyTcpHandler::process(context, socket, rx_receiver).await { + warn!("processing of tcp proxy failed: {}", error); + } + }); + Ok(()) + } + + async fn process( + context: NatHandlerContext, + mut external_socket: TcpStream, + mut rx_receiver: Receiver>, + ) -> Result<()> { + let (ip_sender, mut ip_receiver) = channel::>(4); + let mut external_buffer = vec![0u8; TCP_BUFFER_SIZE]; + + let mut device = ChannelDevice::new(context.mtu, Medium::Ip, ip_sender.clone()); + let config = Config::new(HardwareAddress::Ip); + + let tcp_rx_buffer = SocketBuffer::new(vec![0; TCP_BUFFER_SIZE]); + let tcp_tx_buffer = SocketBuffer::new(vec![0; TCP_BUFFER_SIZE]); + let internal_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); + let mut iface = Interface::new(config, &mut device, Instant::now()); + + iface.update_ip_addrs(|addrs| { + let _ = addrs.push(IpCidr::new(context.key.external_ip.addr, 0)); + }); + + let mut sockets = SocketSet::new([SocketStorage::EMPTY]); + let internal_socket_handle = sockets.add(internal_socket); + let (mut external_r, mut external_w) = external_socket.split(); + + { + let socket = sockets.get_mut::(internal_socket_handle); + socket.connect( + iface.context(), + context.key.client_ip, + context.key.external_ip, + )?; + } + + iface.poll(Instant::now(), &mut device, &mut sockets); + + let mut sleeper: Option = None; + loop { + let socket = sockets.get_mut::(internal_socket_handle); + if socket.is_active() && socket.state() != State::SynSent { + break; + } + + if socket.state() == State::Closed { + break; + } + + let deadline = tokio::time::sleep(Duration::from_secs(TCP_ACCEPT_TIMEOUT_SECS)); + let selection = if let Some(sleep) = sleeper.take() { + select! { + biased; + x = rx_receiver.recv() => if let Some(data) = x { + ProxyTcpAcceptSelect::Internal(data) + } else { + ProxyTcpAcceptSelect::Close + }, + x = ip_receiver.recv() => if let Some(data) = x { + ProxyTcpAcceptSelect::TxIpPacket(data) + } else { + ProxyTcpAcceptSelect::Close + }, + _ = sleep => ProxyTcpAcceptSelect::TimePassed, + _ = deadline => ProxyTcpAcceptSelect::Close, + } + } else { + select! { + biased; + x = rx_receiver.recv() => if let Some(data) = x { + ProxyTcpAcceptSelect::Internal(data) + } else { + ProxyTcpAcceptSelect::Close + }, + x = ip_receiver.recv() => if let Some(data) = x { + ProxyTcpAcceptSelect::TxIpPacket(data) + } else { + ProxyTcpAcceptSelect::Close + }, + _ = std::future::ready(()) => ProxyTcpAcceptSelect::DoNothing, + _ = deadline => ProxyTcpAcceptSelect::Close, + } + }; + match selection { + ProxyTcpAcceptSelect::TimePassed => { + iface.poll(Instant::now(), &mut device, &mut sockets); + } + + ProxyTcpAcceptSelect::DoNothing => { + sleeper = Some(tokio::time::sleep(Duration::from_millis(50))); + } + + ProxyTcpAcceptSelect::Internal(data) => { + let (_, payload) = Ethernet2Header::from_slice(&data)?; + device.rx = Some(payload.to_vec()); + iface.poll(Instant::now(), &mut device, &mut sockets); + } + + ProxyTcpAcceptSelect::TxIpPacket(payload) => { + let mut buffer: Vec = Vec::new(); + let header = Ethernet2Header { + source: context.key.local_mac.0, + destination: context.key.client_mac.0, + ether_type: match context.key.external_ip.addr { + IpAddress::Ipv4(_) => EtherType::IPV4, + IpAddress::Ipv6(_) => EtherType::IPV6, + }, + }; + header.write(&mut buffer)?; + buffer.extend_from_slice(&payload); + if let Err(error) = context.try_send(buffer) { + debug!("failed to transmit tcp packet: {}", error); + } + } + + ProxyTcpAcceptSelect::Close => { + break; + } + } + } + + let accepted = if sockets + .get_mut::(internal_socket_handle) + .is_active() + { + debug!("failed to accept tcp connection from client"); + true + } else { + true + }; + + let mut already_shutdown = false; + let mut sleeper: Option = None; + loop { + if !accepted { + break; + } + + let socket = sockets.get_mut::(internal_socket_handle); + + match socket.state() { + State::Closed + | State::Listen + | State::Closing + | State::LastAck + | State::TimeWait => { + break; + } + State::FinWait1 + | State::SynSent + | State::CloseWait + | State::FinWait2 + | State::SynReceived + | State::Established => {} + } + + let bytes_to_client = if socket.can_send() { + socket.send_capacity() - socket.send_queue() + } else { + 0 + }; + + let (bytes_to_external, do_shutdown) = if socket.may_recv() { + if let Ok(data) = socket.peek(TCP_BUFFER_SIZE) { + if data.is_empty() { + (None, false) + } else { + (Some(data), false) + } + } else { + (None, false) + } + } else if !already_shutdown && matches!(socket.state(), State::CloseWait) { + (None, true) + } else { + (None, false) + }; + let selection = if let Some(sleep) = sleeper.take() { + if !do_shutdown { + select! { + biased; + x = rx_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::InternalRecv(data) + } else { + ProxyTcpDataSelect::Close + }, + x = ip_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::TxIpPacket(data) + } else { + ProxyTcpDataSelect::Close + }, + x = external_w.write(bytes_to_external.unwrap_or(b"")), if bytes_to_external.is_some() => ProxyTcpDataSelect::ExternalSent(x?), + x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?), + _ = sleep => ProxyTcpDataSelect::TimePassed, + } + } else { + select! { + biased; + x = rx_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::InternalRecv(data) + } else { + ProxyTcpDataSelect::Close + }, + x = ip_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::TxIpPacket(data) + } else { + ProxyTcpDataSelect::Close + }, + _ = external_w.shutdown() => ProxyTcpDataSelect::ExternalSent(0), + x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?), + _ = sleep => ProxyTcpDataSelect::TimePassed, + } + } + } else if !do_shutdown { + select! { + biased; + x = rx_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::InternalRecv(data) + } else { + ProxyTcpDataSelect::Close + }, + x = ip_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::TxIpPacket(data) + } else { + ProxyTcpDataSelect::Close + }, + x = external_w.write(bytes_to_external.unwrap_or(b"")), if bytes_to_external.is_some() => ProxyTcpDataSelect::ExternalSent(x?), + x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?), + _ = std::future::ready(()) => ProxyTcpDataSelect::DoNothing, + } + } else { + select! { + biased; + x = rx_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::InternalRecv(data) + } else { + ProxyTcpDataSelect::Close + }, + x = ip_receiver.recv() => if let Some(data) = x { + ProxyTcpDataSelect::TxIpPacket(data) + } else { + ProxyTcpDataSelect::Close + }, + _ = external_w.shutdown() => ProxyTcpDataSelect::ExternalSent(0), + x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?), + _ = std::future::ready(()) => ProxyTcpDataSelect::DoNothing, + } + }; + match selection { + ProxyTcpDataSelect::ExternalRecv(size) => { + if size == 0 { + socket.close(); + } else { + socket.send_slice(&external_buffer[..size])?; + } + } + + ProxyTcpDataSelect::ExternalSent(size) => { + if size == 0 { + already_shutdown = true; + } else { + socket.recv(|_| (size, ()))?; + } + } + + ProxyTcpDataSelect::InternalRecv(data) => { + let (_, payload) = Ethernet2Header::from_slice(&data)?; + device.rx = Some(payload.to_vec()); + iface.poll(Instant::now(), &mut device, &mut sockets); + } + + ProxyTcpDataSelect::TxIpPacket(payload) => { + let mut buffer: Vec = Vec::new(); + let header = Ethernet2Header { + source: context.key.local_mac.0, + destination: context.key.client_mac.0, + ether_type: match context.key.external_ip.addr { + IpAddress::Ipv4(_) => EtherType::IPV4, + IpAddress::Ipv6(_) => EtherType::IPV6, + }, + }; + header.write(&mut buffer)?; + buffer.extend_from_slice(&payload); + if let Err(error) = context.try_send(buffer) { + debug!("failed to transmit tcp packet: {}", error); + } + } + + ProxyTcpDataSelect::TimePassed => { + iface.poll(Instant::now(), &mut device, &mut sockets); + } + + ProxyTcpDataSelect::DoNothing => { + sleeper = Some(tokio::time::sleep(Duration::from_millis(50))); + } + + ProxyTcpDataSelect::Close => { + break; + } + } + } + + context.reclaim().await?; + + Ok(()) + } +} diff --git a/network/src/proxynat/udp.rs b/network/src/proxynat/udp.rs index d437d1b..bc1dd1a 100644 --- a/network/src/proxynat/udp.rs +++ b/network/src/proxynat/udp.rs @@ -128,6 +128,8 @@ impl ProxyUdpHandler { } } + context.reclaim().await?; + Ok(()) } }