mirror of
https://github.com/edera-dev/krata.git
synced 2025-08-04 13:41:31 +00:00
network: rework raw sockets to use channels
This commit is contained in:
@ -1,12 +1,18 @@
|
||||
use anyhow::Result;
|
||||
use futures::ready;
|
||||
use std::os::fd::IntoRawFd;
|
||||
use anyhow::{anyhow, Result};
|
||||
use bytes::BytesMut;
|
||||
use log::warn;
|
||||
use std::io::ErrorKind;
|
||||
use std::os::fd::{FromRawFd, IntoRawFd};
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::sync::Arc;
|
||||
use std::{io, mem};
|
||||
use tokio::io::unix::AsyncFd;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 500;
|
||||
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 500;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum RawSocketProtocol {
|
||||
@ -186,80 +192,99 @@ fn ifreq_ioctl(
|
||||
Ok(ifreq.ifr_data)
|
||||
}
|
||||
|
||||
pub struct AsyncRawSocket {
|
||||
inner: AsyncFd<RawSocketHandle>,
|
||||
pub struct AsyncRawSocketChannel {
|
||||
pub sender: Sender<BytesMut>,
|
||||
pub receiver: Receiver<BytesMut>,
|
||||
_task: Arc<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl AsyncRawSocket {
|
||||
pub fn new(socket: RawSocketHandle) -> Result<Self> {
|
||||
Ok(Self {
|
||||
inner: AsyncFd::new(socket)?,
|
||||
enum AsyncRawSocketChannelSelect {
|
||||
TransmitPacket(Option<BytesMut>),
|
||||
Readable(()),
|
||||
}
|
||||
|
||||
impl AsyncRawSocketChannel {
|
||||
pub fn new(socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
|
||||
let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
|
||||
let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
|
||||
let task = AsyncRawSocketChannel::launch(socket, transmit_receiver, receive_sender)?;
|
||||
Ok(AsyncRawSocketChannel {
|
||||
sender: transmit_sender,
|
||||
receiver: receive_receiver,
|
||||
_task: Arc::new(task),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
|
||||
let socket = RawSocketHandle::bound_to_interface(interface, protocol)?;
|
||||
AsyncRawSocket::new(socket)
|
||||
fn launch(
|
||||
socket: RawSocketHandle,
|
||||
transmit_receiver: Receiver<BytesMut>,
|
||||
receive_sender: Sender<BytesMut>,
|
||||
) -> Result<JoinHandle<()>> {
|
||||
Ok(tokio::task::spawn(async move {
|
||||
if let Err(error) =
|
||||
AsyncRawSocketChannel::process(socket, transmit_receiver, receive_sender).await
|
||||
{
|
||||
warn!("failed to process raw socket: {}", error);
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn mtu_of_interface(&mut self, interface: &str) -> Result<usize> {
|
||||
Ok(self.inner.get_mut().mtu_of_interface(interface)?)
|
||||
}
|
||||
}
|
||||
async fn process(
|
||||
socket: RawSocketHandle,
|
||||
mut transmit_receiver: Receiver<BytesMut>,
|
||||
receive_sender: Sender<BytesMut>,
|
||||
) -> Result<()> {
|
||||
let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
|
||||
let socket = UdpSocket::from_std(socket)?;
|
||||
|
||||
impl TryFrom<RawSocketHandle> for AsyncRawSocket {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: RawSocketHandle) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
inner: AsyncFd::new(value)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for AsyncRawSocket {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
loop {
|
||||
let mut guard = ready!(self.inner.poll_read_ready(cx))?;
|
||||
let selection = select! {
|
||||
x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
|
||||
x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
|
||||
};
|
||||
|
||||
let unfilled = buf.initialize_unfilled();
|
||||
match guard.try_io(|inner| inner.get_ref().recv(unfilled)) {
|
||||
Ok(Ok(len)) => {
|
||||
buf.advance(len);
|
||||
return Poll::Ready(Ok(()));
|
||||
match selection {
|
||||
AsyncRawSocketChannelSelect::Readable(_) => {
|
||||
let mut buffer = vec![0; 1500];
|
||||
match socket.try_recv(&mut buffer) {
|
||||
Ok(len) => {
|
||||
if len == 0 {
|
||||
continue;
|
||||
}
|
||||
let buffer = (&buffer[0..len]).into();
|
||||
if let Err(error) = receive_sender.try_send(buffer) {
|
||||
warn!("raw socket failed to process received packet: {}", error);
|
||||
}
|
||||
}
|
||||
|
||||
Err(ref error) => {
|
||||
if error.kind() == ErrorKind::WouldBlock {
|
||||
continue;
|
||||
}
|
||||
return Err(anyhow!("failed to read from raw socket: {}", error));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
|
||||
match socket.try_send(&packet) {
|
||||
Ok(_len) => {}
|
||||
Err(ref error) => {
|
||||
if error.kind() == ErrorKind::WouldBlock {
|
||||
warn!("failed to transmit: would block");
|
||||
continue;
|
||||
}
|
||||
return Err(anyhow!("failed to write to raw socket: {}", error));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
AsyncRawSocketChannelSelect::TransmitPacket(None) => {
|
||||
break;
|
||||
}
|
||||
Ok(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for AsyncRawSocket {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
loop {
|
||||
let mut guard = ready!(self.inner.poll_write_ready(cx))?;
|
||||
|
||||
match guard.try_io(|inner| inner.get_ref().send(buf)) {
|
||||
Ok(result) => return Poll::Ready(result),
|
||||
Err(_would_block) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user