network: don't block icmp proxies on ping awaits

This commit is contained in:
Alex Zenla
2024-02-12 16:11:29 +00:00
parent 59bdd8d80d
commit ddeab7610d
4 changed files with 132 additions and 84 deletions

View File

@ -8,7 +8,7 @@ use crate::vbridge::{BridgeJoinHandle, VirtualBridge};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use etherparse::SlicedPacket; use etherparse::SlicedPacket;
use futures::TryStreamExt; use futures::TryStreamExt;
use log::{debug, info, warn}; use log::{debug, info, trace, warn};
use smoltcp::iface::{Config, Interface, SocketSet}; use smoltcp::iface::{Config, Interface, SocketSet};
use smoltcp::phy::Medium; use smoltcp::phy::Medium;
use smoltcp::time::Instant; use smoltcp::time::Instant;
@ -26,7 +26,6 @@ pub struct NetworkBackend {
enum NetworkStackSelect<'a> { enum NetworkStackSelect<'a> {
Receive(&'a [u8]), Receive(&'a [u8]),
BridgeSend(Option<Vec<u8>>),
Send(Option<Vec<u8>>), Send(Option<Vec<u8>>),
Reclaim, Reclaim,
} }
@ -46,8 +45,8 @@ impl NetworkStack<'_> {
async fn poll(&mut self, buffer: &mut [u8]) -> Result<()> { async fn poll(&mut self, buffer: &mut [u8]) -> Result<()> {
let what = select! { let what = select! {
x = self.kdev.read(buffer) => NetworkStackSelect::Receive(&buffer[0..x?]), x = self.kdev.read(buffer) => NetworkStackSelect::Receive(&buffer[0..x?]),
x = self.bridge.bridge_rx_receiver.recv() => NetworkStackSelect::BridgeSend(x), x = self.bridge.bridge_rx_receiver.recv() => NetworkStackSelect::Send(x),
x = self.bridge.broadcast_rx_receiver.recv() => NetworkStackSelect::BridgeSend(x.ok()), x = self.bridge.broadcast_rx_receiver.recv() => NetworkStackSelect::Send(x.ok()),
x = self.tx.recv() => NetworkStackSelect::Send(x), x = self.tx.recv() => NetworkStackSelect::Send(x),
_ = self.router.process_reclaim() => NetworkStackSelect::Reclaim, _ = self.router.process_reclaim() => NetworkStackSelect::Reclaim,
}; };
@ -55,7 +54,7 @@ impl NetworkStack<'_> {
match what { match what {
NetworkStackSelect::Receive(packet) => { NetworkStackSelect::Receive(packet) => {
if let Err(error) = self.bridge.bridge_tx_sender.try_send(packet.to_vec()) { if let Err(error) = self.bridge.bridge_tx_sender.try_send(packet.to_vec()) {
warn!("failed to send guest packet to bridge: {}", error); trace!("failed to send guest packet to bridge: {}", error);
} }
let slice = SlicedPacket::from_ethernet(packet)?; let slice = SlicedPacket::from_ethernet(packet)?;
@ -69,19 +68,9 @@ impl NetworkStack<'_> {
.poll(Instant::now(), &mut self.udev, &mut self.sockets); .poll(Instant::now(), &mut self.udev, &mut self.sockets);
} }
NetworkStackSelect::BridgeSend(Some(packet)) => { NetworkStackSelect::Send(Some(packet)) => self.kdev.write_all(&packet).await?,
if let Err(error) = self.udev.tx.try_send(packet) {
warn!("failed to send bridge packet to guest: {}", error);
}
}
NetworkStackSelect::BridgeSend(None) => {} NetworkStackSelect::Send(None) => {}
NetworkStackSelect::Send(packet) => {
if let Some(packet) = packet {
self.kdev.write_all(&packet).await?
}
}
NetworkStackSelect::Reclaim => {} NetworkStackSelect::Reclaim => {}
} }

View File

@ -39,13 +39,13 @@ struct IcmpHandlerToken(IpAddr, Option<u16>, u16);
#[derive(Debug)] #[derive(Debug)]
pub enum IcmpReply { pub enum IcmpReply {
Icmp4 { Icmpv4 {
header: Icmpv4Header, header: Icmpv4Header,
echo: IcmpEchoHeader, echo: IcmpEchoHeader,
payload: Vec<u8>, payload: Vec<u8>,
}, },
Icmp6 { Icmpv6 {
header: Icmpv6Header, header: Icmpv6Header,
echo: IcmpEchoHeader, echo: IcmpEchoHeader,
payload: Vec<u8>, payload: Vec<u8>,
@ -53,6 +53,8 @@ pub enum IcmpReply {
} }
type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>; type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>;
#[derive(Clone)]
pub struct IcmpClient { pub struct IcmpClient {
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
handlers: IcmpHandlerMap, handlers: IcmpHandlerMap,
@ -118,7 +120,7 @@ impl IcmpClient {
Some(echo.id), Some(echo.id),
echo.seq, echo.seq,
); );
let reply = IcmpReply::Icmp4 { let reply = IcmpReply::Icmpv4 {
header: icmpv4.header(), header: icmpv4.header(),
echo, echo,
payload: icmpv4.payload().to_vec(), payload: icmpv4.payload().to_vec(),
@ -141,7 +143,7 @@ impl IcmpClient {
let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq); let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq);
let reply = IcmpReply::Icmp6 { let reply = IcmpReply::Icmpv6 {
header: icmpv6.header(), header: icmpv6.header(),
echo, echo,
payload: icmpv6.payload().to_vec(), payload: icmpv6.payload().to_vec(),

View File

@ -50,7 +50,7 @@ impl Display for NatKey {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct NatHandlerContext { pub struct NatHandlerContext {
pub mtu: usize, pub mtu: usize,
pub key: NatKey, pub key: NatKey,

View File

@ -1,12 +1,15 @@
use std::{net::IpAddr, time::Duration}; use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
time::Duration,
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use async_trait::async_trait; use async_trait::async_trait;
use etherparse::{ use etherparse::{
Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type, IpNumber, Ipv4Slice, Ipv6Slice, NetSlice, IcmpEchoHeader, Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type, IpNumber, Ipv4Slice,
PacketBuilder, SlicedPacket, Ipv6Slice, NetSlice, PacketBuilder, SlicedPacket,
}; };
use log::{debug, warn}; use log::{debug, trace, warn};
use smoltcp::wire::IpAddress; use smoltcp::wire::IpAddress;
use tokio::{ use tokio::{
select, select,
@ -104,6 +107,8 @@ impl ProxyIcmpHandler {
} }
} }
context.reclaim().await?;
Ok(()) Ok(())
} }
@ -122,39 +127,22 @@ impl ProxyIcmpHandler {
return Ok(()); return Ok(());
}; };
let Some(IcmpReply::Icmp4 { let context = context.clone();
header: _, let client = client.clone();
let payload = payload.to_vec();
tokio::task::spawn(async move {
if let Err(error) = ProxyIcmpHandler::process_echo_ipv4(
context,
client,
external_ipv4,
echo, echo,
payload, payload,
}) = client
.ping4(
external_ipv4,
echo.id,
echo.seq,
payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
) )
.await? .await
else { {
return Ok(()); trace!("icmp4 echo failed: {}", error);
};
let packet =
PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
(IpAddress::Ipv4(external_addr), IpAddress::Ipv4(client_addr)) => {
packet.ipv4(external_addr.0, client_addr.0, 20)
}
_ => {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet = packet.icmpv4_echo_reply(echo.id, echo.seq);
let mut buffer: Vec<u8> = Vec::new();
packet.write(&mut buffer, &payload)?;
if let Err(error) = context.try_send(buffer) {
debug!("failed to transmit icmp packet: {}", error);
} }
});
} }
Ok(()) Ok(())
} }
@ -174,25 +162,98 @@ impl ProxyIcmpHandler {
return Ok(()); return Ok(());
}; };
let Some(IcmpReply::Icmp6 { let context = context.clone();
let client = client.clone();
let payload = payload.to_vec();
tokio::task::spawn(async move {
if let Err(error) = ProxyIcmpHandler::process_echo_ipv6(
context,
client,
external_ipv6,
echo,
payload,
)
.await
{
trace!("icmp6 echo failed: {}", error);
}
});
}
context.reclaim().await?;
Ok(())
}
async fn process_echo_ipv4(
context: NatHandlerContext,
client: IcmpClient,
external_ipv4: Ipv4Addr,
echo: IcmpEchoHeader,
payload: Vec<u8>,
) -> Result<()> {
let reply = client
.ping4(
external_ipv4,
echo.id,
echo.seq,
&payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await?;
let Some(IcmpReply::Icmpv4 {
header: _, header: _,
echo, echo,
payload, payload,
}) = client }) = reply
.ping6(
external_ipv6,
echo.id,
echo.seq,
payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await?
else { else {
return Ok(()); return Ok(());
}; };
let packet = let packet = PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0); let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
(IpAddress::Ipv4(external_addr), IpAddress::Ipv4(client_addr)) => {
packet.ipv4(external_addr.0, client_addr.0, 20)
}
_ => {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet = packet.icmpv4_echo_reply(echo.id, echo.seq);
let mut buffer: Vec<u8> = Vec::new();
packet.write(&mut buffer, &payload)?;
if let Err(error) = context.try_send(buffer) {
debug!("failed to transmit icmp packet: {}", error);
}
Ok(())
}
async fn process_echo_ipv6(
context: NatHandlerContext,
client: IcmpClient,
external_ipv6: Ipv6Addr,
echo: IcmpEchoHeader,
payload: Vec<u8>,
) -> Result<()> {
let reply = client
.ping6(
external_ipv6,
echo.id,
echo.seq,
&payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await?;
let Some(IcmpReply::Icmpv6 {
header: _,
echo,
payload,
}) = reply
else {
return Ok(());
};
let packet = PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) { let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
(IpAddress::Ipv6(external_addr), IpAddress::Ipv6(client_addr)) => { (IpAddress::Ipv6(external_addr), IpAddress::Ipv6(client_addr)) => {
packet.ipv6(external_addr.0, client_addr.0, 20) packet.ipv6(external_addr.0, client_addr.0, 20)
@ -207,10 +268,6 @@ impl ProxyIcmpHandler {
if let Err(error) = context.try_send(buffer) { if let Err(error) = context.try_send(buffer) {
debug!("failed to transmit icmp packet: {}", error); debug!("failed to transmit icmp packet: {}", error);
} }
}
context.reclaim().await?;
Ok(()) Ok(())
} }
} }