network: rework raw sockets to use channels

This commit is contained in:
Alex Zenla
2024-02-13 14:58:21 +00:00
parent b7db12cf68
commit fdd70dee9b
3 changed files with 149 additions and 124 deletions

View File

@ -1,12 +1,18 @@
use anyhow::Result;
use futures::ready;
use std::os::fd::IntoRawFd;
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use log::warn;
use std::io::ErrorKind;
use std::os::fd::{FromRawFd, IntoRawFd};
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::Arc;
use std::{io, mem};
use tokio::io::unix::AsyncFd;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 500;
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 500;
#[derive(Debug)]
pub enum RawSocketProtocol {
@ -186,80 +192,99 @@ fn ifreq_ioctl(
Ok(ifreq.ifr_data)
}
pub struct AsyncRawSocket {
inner: AsyncFd<RawSocketHandle>,
pub struct AsyncRawSocketChannel {
pub sender: Sender<BytesMut>,
pub receiver: Receiver<BytesMut>,
_task: Arc<JoinHandle<()>>,
}
impl AsyncRawSocket {
pub fn new(socket: RawSocketHandle) -> Result<Self> {
Ok(Self {
inner: AsyncFd::new(socket)?,
enum AsyncRawSocketChannelSelect {
TransmitPacket(Option<BytesMut>),
Readable(()),
}
impl AsyncRawSocketChannel {
pub fn new(socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
let task = AsyncRawSocketChannel::launch(socket, transmit_receiver, receive_sender)?;
Ok(AsyncRawSocketChannel {
sender: transmit_sender,
receiver: receive_receiver,
_task: Arc::new(task),
})
}
pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
let socket = RawSocketHandle::bound_to_interface(interface, protocol)?;
AsyncRawSocket::new(socket)
fn launch(
socket: RawSocketHandle,
transmit_receiver: Receiver<BytesMut>,
receive_sender: Sender<BytesMut>,
) -> Result<JoinHandle<()>> {
Ok(tokio::task::spawn(async move {
if let Err(error) =
AsyncRawSocketChannel::process(socket, transmit_receiver, receive_sender).await
{
warn!("failed to process raw socket: {}", error);
}
}))
}
pub fn mtu_of_interface(&mut self, interface: &str) -> Result<usize> {
Ok(self.inner.get_mut().mtu_of_interface(interface)?)
}
}
async fn process(
socket: RawSocketHandle,
mut transmit_receiver: Receiver<BytesMut>,
receive_sender: Sender<BytesMut>,
) -> Result<()> {
let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
let socket = UdpSocket::from_std(socket)?;
impl TryFrom<RawSocketHandle> for AsyncRawSocket {
type Error = anyhow::Error;
fn try_from(value: RawSocketHandle) -> Result<Self, Self::Error> {
Ok(Self {
inner: AsyncFd::new(value)?,
})
}
}
impl AsyncRead for AsyncRawSocket {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
loop {
let mut guard = ready!(self.inner.poll_read_ready(cx))?;
let selection = select! {
x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
};
let unfilled = buf.initialize_unfilled();
match guard.try_io(|inner| inner.get_ref().recv(unfilled)) {
Ok(Ok(len)) => {
buf.advance(len);
return Poll::Ready(Ok(()));
match selection {
AsyncRawSocketChannelSelect::Readable(_) => {
let mut buffer = vec![0; 1500];
match socket.try_recv(&mut buffer) {
Ok(len) => {
if len == 0 {
continue;
}
let buffer = (&buffer[0..len]).into();
if let Err(error) = receive_sender.try_send(buffer) {
warn!("raw socket failed to process received packet: {}", error);
}
}
Err(ref error) => {
if error.kind() == ErrorKind::WouldBlock {
continue;
}
return Err(anyhow!("failed to read from raw socket: {}", error));
}
};
}
AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
match socket.try_send(&packet) {
Ok(_len) => {}
Err(ref error) => {
if error.kind() == ErrorKind::WouldBlock {
warn!("failed to transmit: would block");
continue;
}
return Err(anyhow!("failed to write to raw socket: {}", error));
}
};
}
AsyncRawSocketChannelSelect::TransmitPacket(None) => {
break;
}
Ok(Err(err)) => return Poll::Ready(Err(err)),
Err(_would_block) => continue,
}
}
}
}
impl AsyncWrite for AsyncRawSocket {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
let mut guard = ready!(self.inner.poll_write_ready(cx))?;
match guard.try_io(|inner| inner.get_ref().send(buf)) {
Ok(result) => return Poll::Ready(result),
Err(_would_block) => continue,
}
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
Ok(())
}
}