From fdd70dee9bcc30221d040d71c7ce447770500b1d Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Tue, 13 Feb 2024 14:58:21 +0000 Subject: [PATCH] network: rework raw sockets to use channels --- network/src/backend.rs | 42 +++++----- network/src/raw_socket.rs | 163 ++++++++++++++++++++++---------------- network/src/vbridge.rs | 68 ++++++++-------- 3 files changed, 149 insertions(+), 124 deletions(-) diff --git a/network/src/backend.rs b/network/src/backend.rs index d1b13c8..8fa5384 100644 --- a/network/src/backend.rs +++ b/network/src/backend.rs @@ -3,7 +3,7 @@ use crate::chandev::ChannelDevice; use crate::nat::NatRouter; use crate::pkt::RecvPacket; use crate::proxynat::ProxyNatHandlerFactory; -use crate::raw_socket::{AsyncRawSocket, RawSocketProtocol}; +use crate::raw_socket::{AsyncRawSocketChannel, RawSocketHandle, RawSocketProtocol}; use crate::vbridge::{BridgeJoinHandle, VirtualBridge}; use anyhow::{anyhow, Result}; use bytes::BytesMut; @@ -14,7 +14,6 @@ use smoltcp::iface::{Config, Interface, SocketSet}; use smoltcp::phy::Medium; use smoltcp::time::Instant; use smoltcp::wire::{HardwareAddress, IpCidr}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::select; use tokio::sync::mpsc::{channel, Receiver}; @@ -26,16 +25,16 @@ pub struct NetworkBackend { bridge: VirtualBridge, } -enum NetworkStackSelect<'a> { - Receive(&'a [u8]), +#[derive(Debug)] +enum NetworkStackSelect { + Receive(Option), Send(Option), Reclaim, } struct NetworkStack<'a> { - mtu: usize, tx: Receiver, - kdev: AsyncRawSocket, + kdev: AsyncRawSocketChannel, udev: ChannelDevice, interface: Interface, sockets: SocketSet<'a>, @@ -44,23 +43,23 @@ struct NetworkStack<'a> { } impl NetworkStack<'_> { - async fn poll(&mut self, buffer: &mut [u8]) -> Result<()> { + async fn poll(&mut self) -> Result<()> { let what = select! { - x = self.kdev.read(buffer) => NetworkStackSelect::Receive(&buffer[0..x?]), - x = self.bridge.bridge_rx_receiver.recv() => NetworkStackSelect::Send(x), - x = self.bridge.broadcast_rx_receiver.recv() => NetworkStackSelect::Send(x.ok()), + x = self.kdev.receiver.recv() => NetworkStackSelect::Receive(x), + x = self.bridge.from_bridge_receiver.recv() => NetworkStackSelect::Send(x), + x = self.bridge.from_broadcast_receiver.recv() => NetworkStackSelect::Send(x.ok()), x = self.tx.recv() => NetworkStackSelect::Send(x), _ = self.router.process_reclaim() => NetworkStackSelect::Reclaim, }; match what { - NetworkStackSelect::Receive(packet) => { - if let Err(error) = self.bridge.bridge_tx_sender.try_send(packet.into()) { + NetworkStackSelect::Receive(Some(packet)) => { + if let Err(error) = self.bridge.to_bridge_sender.try_send(packet.clone()) { trace!("failed to send guest packet to bridge: {}", error); } - if let Ok(slice) = SlicedPacket::from_ethernet(packet) { - let packet = RecvPacket::new(packet, &slice)?; + if let Ok(slice) = SlicedPacket::from_ethernet(&packet) { + let packet = RecvPacket::new(&packet, &slice)?; if let Err(error) = self.router.process(&packet).await { debug!("router failed to process packet: {}", error); } @@ -71,8 +70,13 @@ impl NetworkStack<'_> { } } - NetworkStackSelect::Send(Some(packet)) => self.kdev.write_all(&packet).await?, + NetworkStackSelect::Send(Some(packet)) => { + if let Err(error) = self.kdev.sender.try_send(packet) { + warn!("failed to transmit packet to interface: {}", error); + } + } + NetworkStackSelect::Receive(None) => {} NetworkStackSelect::Send(None) => {} NetworkStackSelect::Reclaim => {} @@ -107,9 +111,8 @@ impl NetworkBackend { pub async fn run(&self) -> Result<()> { let mut stack = self.create_network_stack().await?; - let mut buffer = vec![0u8; stack.mtu]; loop { - stack.poll(&mut buffer).await?; + stack.poll().await?; } } @@ -120,7 +123,8 @@ impl NetworkBackend { self.metadata.gateway.ipv4.into(), self.metadata.gateway.ipv6.into(), ]; - let mut kdev = AsyncRawSocket::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?; + let mut kdev = + RawSocketHandle::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?; let mtu = kdev.mtu_of_interface(&interface)?; let (tx_sender, tx_receiver) = channel::(TX_CHANNEL_BUFFER_LEN); let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone()); @@ -136,8 +140,8 @@ impl NetworkBackend { }); let sockets = SocketSet::new(vec![]); let handle = self.bridge.join(self.metadata.guest.mac).await?; + let kdev = AsyncRawSocketChannel::new(kdev)?; Ok(NetworkStack { - mtu, tx: tx_receiver, kdev, udev, diff --git a/network/src/raw_socket.rs b/network/src/raw_socket.rs index ac823a4..d02b918 100644 --- a/network/src/raw_socket.rs +++ b/network/src/raw_socket.rs @@ -1,12 +1,18 @@ -use anyhow::Result; -use futures::ready; -use std::os::fd::IntoRawFd; +use anyhow::{anyhow, Result}; +use bytes::BytesMut; +use log::warn; +use std::io::ErrorKind; +use std::os::fd::{FromRawFd, IntoRawFd}; use std::os::unix::io::{AsRawFd, RawFd}; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::sync::Arc; use std::{io, mem}; -use tokio::io::unix::AsyncFd; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::task::JoinHandle; + +const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 500; +const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 500; #[derive(Debug)] pub enum RawSocketProtocol { @@ -186,80 +192,99 @@ fn ifreq_ioctl( Ok(ifreq.ifr_data) } -pub struct AsyncRawSocket { - inner: AsyncFd, +pub struct AsyncRawSocketChannel { + pub sender: Sender, + pub receiver: Receiver, + _task: Arc>, } -impl AsyncRawSocket { - pub fn new(socket: RawSocketHandle) -> Result { - Ok(Self { - inner: AsyncFd::new(socket)?, +enum AsyncRawSocketChannelSelect { + TransmitPacket(Option), + Readable(()), +} + +impl AsyncRawSocketChannel { + pub fn new(socket: RawSocketHandle) -> Result { + let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN); + let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN); + let task = AsyncRawSocketChannel::launch(socket, transmit_receiver, receive_sender)?; + Ok(AsyncRawSocketChannel { + sender: transmit_sender, + receiver: receive_receiver, + _task: Arc::new(task), }) } - pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result { - let socket = RawSocketHandle::bound_to_interface(interface, protocol)?; - AsyncRawSocket::new(socket) + fn launch( + socket: RawSocketHandle, + transmit_receiver: Receiver, + receive_sender: Sender, + ) -> Result> { + Ok(tokio::task::spawn(async move { + if let Err(error) = + AsyncRawSocketChannel::process(socket, transmit_receiver, receive_sender).await + { + warn!("failed to process raw socket: {}", error); + } + })) } - pub fn mtu_of_interface(&mut self, interface: &str) -> Result { - Ok(self.inner.get_mut().mtu_of_interface(interface)?) - } -} + async fn process( + socket: RawSocketHandle, + mut transmit_receiver: Receiver, + receive_sender: Sender, + ) -> Result<()> { + let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) }; + let socket = UdpSocket::from_std(socket)?; -impl TryFrom for AsyncRawSocket { - type Error = anyhow::Error; - - fn try_from(value: RawSocketHandle) -> Result { - Ok(Self { - inner: AsyncFd::new(value)?, - }) - } -} - -impl AsyncRead for AsyncRawSocket { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { loop { - let mut guard = ready!(self.inner.poll_read_ready(cx))?; + let selection = select! { + x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x), + x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?), + }; - let unfilled = buf.initialize_unfilled(); - match guard.try_io(|inner| inner.get_ref().recv(unfilled)) { - Ok(Ok(len)) => { - buf.advance(len); - return Poll::Ready(Ok(())); + match selection { + AsyncRawSocketChannelSelect::Readable(_) => { + let mut buffer = vec![0; 1500]; + match socket.try_recv(&mut buffer) { + Ok(len) => { + if len == 0 { + continue; + } + let buffer = (&buffer[0..len]).into(); + if let Err(error) = receive_sender.try_send(buffer) { + warn!("raw socket failed to process received packet: {}", error); + } + } + + Err(ref error) => { + if error.kind() == ErrorKind::WouldBlock { + continue; + } + return Err(anyhow!("failed to read from raw socket: {}", error)); + } + }; + } + + AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => { + match socket.try_send(&packet) { + Ok(_len) => {} + Err(ref error) => { + if error.kind() == ErrorKind::WouldBlock { + warn!("failed to transmit: would block"); + continue; + } + return Err(anyhow!("failed to write to raw socket: {}", error)); + } + }; + } + + AsyncRawSocketChannelSelect::TransmitPacket(None) => { + break; } - Ok(Err(err)) => return Poll::Ready(Err(err)), - Err(_would_block) => continue, } } - } -} - -impl AsyncWrite for AsyncRawSocket { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - loop { - let mut guard = ready!(self.inner.poll_write_ready(cx))?; - - match guard.try_io(|inner| inner.get_ref().send(buf)) { - Ok(result) => return Poll::Ready(result), - Err(_would_block) => continue, - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + + Ok(()) } } diff --git a/network/src/vbridge.rs b/network/src/vbridge.rs index 9ba3ff5..a5fab8e 100644 --- a/network/src/vbridge.rs +++ b/network/src/vbridge.rs @@ -19,19 +19,19 @@ use tokio::{ task::JoinHandle, }; -const BRIDGE_TX_QUEUE_LEN: usize = 50; -const BRIDGE_RX_QUEUE_LEN: usize = 50; -const BROADCAST_RX_QUEUE_LEN: usize = 50; +const TO_BRIDGE_QUEUE_LEN: usize = 50; +const FROM_BRIDGE_QUEUE_LEN: usize = 50; +const BROADCAST_QUEUE_LEN: usize = 50; #[derive(Debug)] struct BridgeMember { - pub bridge_rx_sender: Sender, + pub from_bridge_sender: Sender, } pub struct BridgeJoinHandle { - pub bridge_tx_sender: Sender, - pub bridge_rx_receiver: Receiver, - pub broadcast_rx_receiver: BroadcastReceiver, + pub to_bridge_sender: Sender, + pub from_bridge_receiver: Receiver, + pub from_broadcast_receiver: BroadcastReceiver, } type VirtualBridgeMemberMap = Arc>>; @@ -39,8 +39,8 @@ type VirtualBridgeMemberMap = Arc>> #[derive(Clone)] pub struct VirtualBridge { members: VirtualBridgeMemberMap, - bridge_tx_sender: Sender, - broadcast_rx_sender: BroadcastSender, + to_bridge_sender: Sender, + from_broadcast_sender: BroadcastSender, _task: Arc>, } @@ -51,20 +51,20 @@ enum VirtualBridgeSelect { impl VirtualBridge { pub fn new() -> Result { - let (bridge_tx_sender, bridge_tx_receiver) = channel::(BRIDGE_TX_QUEUE_LEN); - let (broadcast_rx_sender, broadcast_rx_receiver) = - broadcast_channel(BROADCAST_RX_QUEUE_LEN); + let (to_bridge_sender, to_bridge_receiver) = channel::(TO_BRIDGE_QUEUE_LEN); + let (from_broadcast_sender, from_broadcast_receiver) = + broadcast_channel(BROADCAST_QUEUE_LEN); let members = Arc::new(Mutex::new(HashMap::new())); let handle = { let members = members.clone(); - let broadcast_rx_sender = broadcast_rx_sender.clone(); + let broadcast_rx_sender = from_broadcast_sender.clone(); tokio::task::spawn(async move { if let Err(error) = VirtualBridge::process( members, - bridge_tx_receiver, + to_bridge_receiver, broadcast_rx_sender, - broadcast_rx_receiver, + from_broadcast_receiver, ) .await { @@ -74,52 +74,48 @@ impl VirtualBridge { }; Ok(VirtualBridge { - bridge_tx_sender, + to_bridge_sender, members, - broadcast_rx_sender, + from_broadcast_sender, _task: Arc::new(handle), }) } pub async fn join(&self, mac: EthernetAddress) -> Result { - let (bridge_rx_sender, bridge_rx_receiver) = channel::(BRIDGE_RX_QUEUE_LEN); - let member = BridgeMember { bridge_rx_sender }; + let (from_bridge_sender, from_bridge_receiver) = channel::(FROM_BRIDGE_QUEUE_LEN); + let member = BridgeMember { from_bridge_sender }; match self.members.lock().await.entry(mac) { Entry::Occupied(_) => { - return Err(anyhow!( - "virtual bridge already has a member with address {}", - mac - )); + return Err(anyhow!("virtual bridge member {} already exists", mac)); } Entry::Vacant(entry) => { entry.insert(member); } }; - debug!("virtual bridge member has joined: {}", mac); + debug!("virtual bridge member {} has joined", mac); Ok(BridgeJoinHandle { - bridge_rx_receiver, - broadcast_rx_receiver: self.broadcast_rx_sender.subscribe(), - bridge_tx_sender: self.bridge_tx_sender.clone(), + from_bridge_receiver, + from_broadcast_receiver: self.from_broadcast_sender.subscribe(), + to_bridge_sender: self.to_bridge_sender.clone(), }) } async fn process( members: VirtualBridgeMemberMap, - mut bridge_tx_receiver: Receiver, + mut to_bridge_receiver: Receiver, broadcast_rx_sender: BroadcastSender, - mut broadcast_rx_receiver: BroadcastReceiver, + mut from_broadcast_receiver: BroadcastReceiver, ) -> Result<()> { loop { let selection = select! { biased; - x = bridge_tx_receiver.recv() => VirtualBridgeSelect::PacketReceived(x), - x = broadcast_rx_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()), + x = from_broadcast_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()), + x = to_bridge_receiver.recv() => VirtualBridgeSelect::PacketReceived(x), }; match selection { - VirtualBridgeSelect::PacketReceived(Some(packet)) => { - let mut packet: Vec = packet.into(); + VirtualBridgeSelect::PacketReceived(Some(mut packet)) => { let (header, payload) = match Ethernet2Header::from_slice(&packet) { Ok(data) => data, Err(error) => { @@ -149,15 +145,15 @@ impl VirtualBridge { let destination = EthernetAddress(header.destination); if destination.is_multicast() { trace!( - "broadcasting bridged packet from {}", + "broadcasting bridge packet from {}", EthernetAddress(header.source) ); - broadcast_rx_sender.send(packet.as_slice().into())?; + broadcast_rx_sender.send(packet)?; continue; } match members.lock().await.get(&destination) { Some(member) => { - member.bridge_rx_sender.try_send(packet.as_slice().into())?; + member.from_bridge_sender.try_send(packet)?; trace!( "sending bridged packet from {} to {}", EthernetAddress(header.source),