network: rework raw sockets to use channels

This commit is contained in:
Alex Zenla 2024-02-13 14:58:21 +00:00
parent b7db12cf68
commit fdd70dee9b
No known key found for this signature in database
GPG Key ID: 067B238899B51269
3 changed files with 149 additions and 124 deletions

View File

@ -3,7 +3,7 @@ use crate::chandev::ChannelDevice;
use crate::nat::NatRouter; use crate::nat::NatRouter;
use crate::pkt::RecvPacket; use crate::pkt::RecvPacket;
use crate::proxynat::ProxyNatHandlerFactory; use crate::proxynat::ProxyNatHandlerFactory;
use crate::raw_socket::{AsyncRawSocket, RawSocketProtocol}; use crate::raw_socket::{AsyncRawSocketChannel, RawSocketHandle, RawSocketProtocol};
use crate::vbridge::{BridgeJoinHandle, VirtualBridge}; use crate::vbridge::{BridgeJoinHandle, VirtualBridge};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use bytes::BytesMut; use bytes::BytesMut;
@ -14,7 +14,6 @@ use smoltcp::iface::{Config, Interface, SocketSet};
use smoltcp::phy::Medium; use smoltcp::phy::Medium;
use smoltcp::time::Instant; use smoltcp::time::Instant;
use smoltcp::wire::{HardwareAddress, IpCidr}; use smoltcp::wire::{HardwareAddress, IpCidr};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::select; use tokio::select;
use tokio::sync::mpsc::{channel, Receiver}; use tokio::sync::mpsc::{channel, Receiver};
@ -26,16 +25,16 @@ pub struct NetworkBackend {
bridge: VirtualBridge, bridge: VirtualBridge,
} }
enum NetworkStackSelect<'a> { #[derive(Debug)]
Receive(&'a [u8]), enum NetworkStackSelect {
Receive(Option<BytesMut>),
Send(Option<BytesMut>), Send(Option<BytesMut>),
Reclaim, Reclaim,
} }
struct NetworkStack<'a> { struct NetworkStack<'a> {
mtu: usize,
tx: Receiver<BytesMut>, tx: Receiver<BytesMut>,
kdev: AsyncRawSocket, kdev: AsyncRawSocketChannel,
udev: ChannelDevice, udev: ChannelDevice,
interface: Interface, interface: Interface,
sockets: SocketSet<'a>, sockets: SocketSet<'a>,
@ -44,23 +43,23 @@ struct NetworkStack<'a> {
} }
impl NetworkStack<'_> { impl NetworkStack<'_> {
async fn poll(&mut self, buffer: &mut [u8]) -> Result<()> { async fn poll(&mut self) -> Result<()> {
let what = select! { let what = select! {
x = self.kdev.read(buffer) => NetworkStackSelect::Receive(&buffer[0..x?]), x = self.kdev.receiver.recv() => NetworkStackSelect::Receive(x),
x = self.bridge.bridge_rx_receiver.recv() => NetworkStackSelect::Send(x), x = self.bridge.from_bridge_receiver.recv() => NetworkStackSelect::Send(x),
x = self.bridge.broadcast_rx_receiver.recv() => NetworkStackSelect::Send(x.ok()), x = self.bridge.from_broadcast_receiver.recv() => NetworkStackSelect::Send(x.ok()),
x = self.tx.recv() => NetworkStackSelect::Send(x), x = self.tx.recv() => NetworkStackSelect::Send(x),
_ = self.router.process_reclaim() => NetworkStackSelect::Reclaim, _ = self.router.process_reclaim() => NetworkStackSelect::Reclaim,
}; };
match what { match what {
NetworkStackSelect::Receive(packet) => { NetworkStackSelect::Receive(Some(packet)) => {
if let Err(error) = self.bridge.bridge_tx_sender.try_send(packet.into()) { if let Err(error) = self.bridge.to_bridge_sender.try_send(packet.clone()) {
trace!("failed to send guest packet to bridge: {}", error); trace!("failed to send guest packet to bridge: {}", error);
} }
if let Ok(slice) = SlicedPacket::from_ethernet(packet) { if let Ok(slice) = SlicedPacket::from_ethernet(&packet) {
let packet = RecvPacket::new(packet, &slice)?; let packet = RecvPacket::new(&packet, &slice)?;
if let Err(error) = self.router.process(&packet).await { if let Err(error) = self.router.process(&packet).await {
debug!("router failed to process packet: {}", error); debug!("router failed to process packet: {}", error);
} }
@ -71,8 +70,13 @@ impl NetworkStack<'_> {
} }
} }
NetworkStackSelect::Send(Some(packet)) => self.kdev.write_all(&packet).await?, NetworkStackSelect::Send(Some(packet)) => {
if let Err(error) = self.kdev.sender.try_send(packet) {
warn!("failed to transmit packet to interface: {}", error);
}
}
NetworkStackSelect::Receive(None) => {}
NetworkStackSelect::Send(None) => {} NetworkStackSelect::Send(None) => {}
NetworkStackSelect::Reclaim => {} NetworkStackSelect::Reclaim => {}
@ -107,9 +111,8 @@ impl NetworkBackend {
pub async fn run(&self) -> Result<()> { pub async fn run(&self) -> Result<()> {
let mut stack = self.create_network_stack().await?; let mut stack = self.create_network_stack().await?;
let mut buffer = vec![0u8; stack.mtu];
loop { loop {
stack.poll(&mut buffer).await?; stack.poll().await?;
} }
} }
@ -120,7 +123,8 @@ impl NetworkBackend {
self.metadata.gateway.ipv4.into(), self.metadata.gateway.ipv4.into(),
self.metadata.gateway.ipv6.into(), self.metadata.gateway.ipv6.into(),
]; ];
let mut kdev = AsyncRawSocket::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?; let mut kdev =
RawSocketHandle::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?;
let mtu = kdev.mtu_of_interface(&interface)?; let mtu = kdev.mtu_of_interface(&interface)?;
let (tx_sender, tx_receiver) = channel::<BytesMut>(TX_CHANNEL_BUFFER_LEN); let (tx_sender, tx_receiver) = channel::<BytesMut>(TX_CHANNEL_BUFFER_LEN);
let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone()); let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone());
@ -136,8 +140,8 @@ impl NetworkBackend {
}); });
let sockets = SocketSet::new(vec![]); let sockets = SocketSet::new(vec![]);
let handle = self.bridge.join(self.metadata.guest.mac).await?; let handle = self.bridge.join(self.metadata.guest.mac).await?;
let kdev = AsyncRawSocketChannel::new(kdev)?;
Ok(NetworkStack { Ok(NetworkStack {
mtu,
tx: tx_receiver, tx: tx_receiver,
kdev, kdev,
udev, udev,

View File

@ -1,12 +1,18 @@
use anyhow::Result; use anyhow::{anyhow, Result};
use futures::ready; use bytes::BytesMut;
use std::os::fd::IntoRawFd; use log::warn;
use std::io::ErrorKind;
use std::os::fd::{FromRawFd, IntoRawFd};
use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin; use std::sync::Arc;
use std::task::{Context, Poll};
use std::{io, mem}; use std::{io, mem};
use tokio::io::unix::AsyncFd; use tokio::net::UdpSocket;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::select;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 500;
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 500;
#[derive(Debug)] #[derive(Debug)]
pub enum RawSocketProtocol { pub enum RawSocketProtocol {
@ -186,80 +192,99 @@ fn ifreq_ioctl(
Ok(ifreq.ifr_data) Ok(ifreq.ifr_data)
} }
pub struct AsyncRawSocket { pub struct AsyncRawSocketChannel {
inner: AsyncFd<RawSocketHandle>, pub sender: Sender<BytesMut>,
pub receiver: Receiver<BytesMut>,
_task: Arc<JoinHandle<()>>,
} }
impl AsyncRawSocket { enum AsyncRawSocketChannelSelect {
pub fn new(socket: RawSocketHandle) -> Result<Self> { TransmitPacket(Option<BytesMut>),
Ok(Self { Readable(()),
inner: AsyncFd::new(socket)?, }
impl AsyncRawSocketChannel {
pub fn new(socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
let task = AsyncRawSocketChannel::launch(socket, transmit_receiver, receive_sender)?;
Ok(AsyncRawSocketChannel {
sender: transmit_sender,
receiver: receive_receiver,
_task: Arc::new(task),
}) })
} }
pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> { fn launch(
let socket = RawSocketHandle::bound_to_interface(interface, protocol)?; socket: RawSocketHandle,
AsyncRawSocket::new(socket) transmit_receiver: Receiver<BytesMut>,
receive_sender: Sender<BytesMut>,
) -> Result<JoinHandle<()>> {
Ok(tokio::task::spawn(async move {
if let Err(error) =
AsyncRawSocketChannel::process(socket, transmit_receiver, receive_sender).await
{
warn!("failed to process raw socket: {}", error);
}
}))
} }
pub fn mtu_of_interface(&mut self, interface: &str) -> Result<usize> { async fn process(
Ok(self.inner.get_mut().mtu_of_interface(interface)?) socket: RawSocketHandle,
} mut transmit_receiver: Receiver<BytesMut>,
} receive_sender: Sender<BytesMut>,
) -> Result<()> {
let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
let socket = UdpSocket::from_std(socket)?;
impl TryFrom<RawSocketHandle> for AsyncRawSocket {
type Error = anyhow::Error;
fn try_from(value: RawSocketHandle) -> Result<Self, Self::Error> {
Ok(Self {
inner: AsyncFd::new(value)?,
})
}
}
impl AsyncRead for AsyncRawSocket {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop { loop {
let mut guard = ready!(self.inner.poll_read_ready(cx))?; let selection = select! {
x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
};
let unfilled = buf.initialize_unfilled(); match selection {
match guard.try_io(|inner| inner.get_ref().recv(unfilled)) { AsyncRawSocketChannelSelect::Readable(_) => {
Ok(Ok(len)) => { let mut buffer = vec![0; 1500];
buf.advance(len); match socket.try_recv(&mut buffer) {
return Poll::Ready(Ok(())); Ok(len) => {
if len == 0 {
continue;
}
let buffer = (&buffer[0..len]).into();
if let Err(error) = receive_sender.try_send(buffer) {
warn!("raw socket failed to process received packet: {}", error);
}
}
Err(ref error) => {
if error.kind() == ErrorKind::WouldBlock {
continue;
}
return Err(anyhow!("failed to read from raw socket: {}", error));
}
};
}
AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
match socket.try_send(&packet) {
Ok(_len) => {}
Err(ref error) => {
if error.kind() == ErrorKind::WouldBlock {
warn!("failed to transmit: would block");
continue;
}
return Err(anyhow!("failed to write to raw socket: {}", error));
}
};
}
AsyncRawSocketChannelSelect::TransmitPacket(None) => {
break;
} }
Ok(Err(err)) => return Poll::Ready(Err(err)),
Err(_would_block) => continue,
} }
} }
}
} Ok(())
impl AsyncWrite for AsyncRawSocket {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready(cx))?;
match guard.try_io(|inner| inner.get_ref().send(buf)) {
Ok(result) => return Poll::Ready(result),
Err(_would_block) => continue,
}
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
} }
} }

View File

@ -19,19 +19,19 @@ use tokio::{
task::JoinHandle, task::JoinHandle,
}; };
const BRIDGE_TX_QUEUE_LEN: usize = 50; const TO_BRIDGE_QUEUE_LEN: usize = 50;
const BRIDGE_RX_QUEUE_LEN: usize = 50; const FROM_BRIDGE_QUEUE_LEN: usize = 50;
const BROADCAST_RX_QUEUE_LEN: usize = 50; const BROADCAST_QUEUE_LEN: usize = 50;
#[derive(Debug)] #[derive(Debug)]
struct BridgeMember { struct BridgeMember {
pub bridge_rx_sender: Sender<BytesMut>, pub from_bridge_sender: Sender<BytesMut>,
} }
pub struct BridgeJoinHandle { pub struct BridgeJoinHandle {
pub bridge_tx_sender: Sender<BytesMut>, pub to_bridge_sender: Sender<BytesMut>,
pub bridge_rx_receiver: Receiver<BytesMut>, pub from_bridge_receiver: Receiver<BytesMut>,
pub broadcast_rx_receiver: BroadcastReceiver<BytesMut>, pub from_broadcast_receiver: BroadcastReceiver<BytesMut>,
} }
type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>; type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>;
@ -39,8 +39,8 @@ type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>
#[derive(Clone)] #[derive(Clone)]
pub struct VirtualBridge { pub struct VirtualBridge {
members: VirtualBridgeMemberMap, members: VirtualBridgeMemberMap,
bridge_tx_sender: Sender<BytesMut>, to_bridge_sender: Sender<BytesMut>,
broadcast_rx_sender: BroadcastSender<BytesMut>, from_broadcast_sender: BroadcastSender<BytesMut>,
_task: Arc<JoinHandle<()>>, _task: Arc<JoinHandle<()>>,
} }
@ -51,20 +51,20 @@ enum VirtualBridgeSelect {
impl VirtualBridge { impl VirtualBridge {
pub fn new() -> Result<VirtualBridge> { pub fn new() -> Result<VirtualBridge> {
let (bridge_tx_sender, bridge_tx_receiver) = channel::<BytesMut>(BRIDGE_TX_QUEUE_LEN); let (to_bridge_sender, to_bridge_receiver) = channel::<BytesMut>(TO_BRIDGE_QUEUE_LEN);
let (broadcast_rx_sender, broadcast_rx_receiver) = let (from_broadcast_sender, from_broadcast_receiver) =
broadcast_channel(BROADCAST_RX_QUEUE_LEN); broadcast_channel(BROADCAST_QUEUE_LEN);
let members = Arc::new(Mutex::new(HashMap::new())); let members = Arc::new(Mutex::new(HashMap::new()));
let handle = { let handle = {
let members = members.clone(); let members = members.clone();
let broadcast_rx_sender = broadcast_rx_sender.clone(); let broadcast_rx_sender = from_broadcast_sender.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
if let Err(error) = VirtualBridge::process( if let Err(error) = VirtualBridge::process(
members, members,
bridge_tx_receiver, to_bridge_receiver,
broadcast_rx_sender, broadcast_rx_sender,
broadcast_rx_receiver, from_broadcast_receiver,
) )
.await .await
{ {
@ -74,52 +74,48 @@ impl VirtualBridge {
}; };
Ok(VirtualBridge { Ok(VirtualBridge {
bridge_tx_sender, to_bridge_sender,
members, members,
broadcast_rx_sender, from_broadcast_sender,
_task: Arc::new(handle), _task: Arc::new(handle),
}) })
} }
pub async fn join(&self, mac: EthernetAddress) -> Result<BridgeJoinHandle> { pub async fn join(&self, mac: EthernetAddress) -> Result<BridgeJoinHandle> {
let (bridge_rx_sender, bridge_rx_receiver) = channel::<BytesMut>(BRIDGE_RX_QUEUE_LEN); let (from_bridge_sender, from_bridge_receiver) = channel::<BytesMut>(FROM_BRIDGE_QUEUE_LEN);
let member = BridgeMember { bridge_rx_sender }; let member = BridgeMember { from_bridge_sender };
match self.members.lock().await.entry(mac) { match self.members.lock().await.entry(mac) {
Entry::Occupied(_) => { Entry::Occupied(_) => {
return Err(anyhow!( return Err(anyhow!("virtual bridge member {} already exists", mac));
"virtual bridge already has a member with address {}",
mac
));
} }
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
entry.insert(member); entry.insert(member);
} }
}; };
debug!("virtual bridge member has joined: {}", mac); debug!("virtual bridge member {} has joined", mac);
Ok(BridgeJoinHandle { Ok(BridgeJoinHandle {
bridge_rx_receiver, from_bridge_receiver,
broadcast_rx_receiver: self.broadcast_rx_sender.subscribe(), from_broadcast_receiver: self.from_broadcast_sender.subscribe(),
bridge_tx_sender: self.bridge_tx_sender.clone(), to_bridge_sender: self.to_bridge_sender.clone(),
}) })
} }
async fn process( async fn process(
members: VirtualBridgeMemberMap, members: VirtualBridgeMemberMap,
mut bridge_tx_receiver: Receiver<BytesMut>, mut to_bridge_receiver: Receiver<BytesMut>,
broadcast_rx_sender: BroadcastSender<BytesMut>, broadcast_rx_sender: BroadcastSender<BytesMut>,
mut broadcast_rx_receiver: BroadcastReceiver<BytesMut>, mut from_broadcast_receiver: BroadcastReceiver<BytesMut>,
) -> Result<()> { ) -> Result<()> {
loop { loop {
let selection = select! { let selection = select! {
biased; biased;
x = bridge_tx_receiver.recv() => VirtualBridgeSelect::PacketReceived(x), x = from_broadcast_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()),
x = broadcast_rx_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()), x = to_bridge_receiver.recv() => VirtualBridgeSelect::PacketReceived(x),
}; };
match selection { match selection {
VirtualBridgeSelect::PacketReceived(Some(packet)) => { VirtualBridgeSelect::PacketReceived(Some(mut packet)) => {
let mut packet: Vec<u8> = packet.into();
let (header, payload) = match Ethernet2Header::from_slice(&packet) { let (header, payload) = match Ethernet2Header::from_slice(&packet) {
Ok(data) => data, Ok(data) => data,
Err(error) => { Err(error) => {
@ -149,15 +145,15 @@ impl VirtualBridge {
let destination = EthernetAddress(header.destination); let destination = EthernetAddress(header.destination);
if destination.is_multicast() { if destination.is_multicast() {
trace!( trace!(
"broadcasting bridged packet from {}", "broadcasting bridge packet from {}",
EthernetAddress(header.source) EthernetAddress(header.source)
); );
broadcast_rx_sender.send(packet.as_slice().into())?; broadcast_rx_sender.send(packet)?;
continue; continue;
} }
match members.lock().await.get(&destination) { match members.lock().await.get(&destination) {
Some(member) => { Some(member) => {
member.bridge_rx_sender.try_send(packet.as_slice().into())?; member.from_bridge_sender.try_send(packet)?;
trace!( trace!(
"sending bridged packet from {} to {}", "sending bridged packet from {} to {}",
EthernetAddress(header.source), EthernetAddress(header.source),