krata: reorganize crates

This commit is contained in:
Alex Zenla
2024-03-07 18:12:47 +00:00
parent c0eeab4047
commit 7bc0c95f00
97 changed files with 24 additions and 24 deletions

View File

@ -0,0 +1,40 @@
[package]
name = "kratanet"
version.workspace = true
edition = "2021"
resolver = "2"
[dependencies]
advmac = { path = "../vendor/advmac" }
anyhow = { workspace = true }
async-trait = { workspace = true }
bytes = { workspace = true }
clap = { workspace = true }
env_logger = { workspace = true }
etherparse = { workspace = true }
futures = { workspace = true }
libc = { workspace = true }
log = { workspace = true }
netlink-packet-route = { workspace = true }
rtnetlink = { workspace = true }
smoltcp = { workspace = true }
tokio = { workspace = true }
tokio-tun = { workspace = true }
udp-stream = { workspace = true }
uuid = { workspace = true }
xenstore = { path = "../xen/xenstore" }
[lib]
name = "kratanet"
[[bin]]
name = "kratanet"
path = "bin/network.rs"
[[example]]
name = "ping"
path = "examples/ping.rs"
[[example]]
name = "autonet"
path = "examples/autonet.rs"

View File

@ -0,0 +1,15 @@
use anyhow::Result;
use clap::Parser;
use env_logger::Env;
use kratanet::NetworkService;
#[derive(Parser, Debug)]
struct NetworkArgs {}
#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main() -> Result<()> {
env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init();
let _ = NetworkArgs::parse();
let mut service = NetworkService::new().await?;
service.watch().await
}

View File

@ -0,0 +1,15 @@
use std::time::Duration;
use anyhow::Result;
use kratanet::autonet::AutoNetworkCollector;
use tokio::time::sleep;
#[tokio::main]
async fn main() -> Result<()> {
let mut collector = AutoNetworkCollector::new().await?;
loop {
let changeset = collector.read_changes().await?;
println!("{:?}", changeset);
sleep(Duration::from_secs(2)).await;
}
}

View File

@ -0,0 +1,21 @@
use std::{net::Ipv6Addr, str::FromStr, time::Duration};
use anyhow::Result;
use kratanet::icmp::{IcmpClient, IcmpProtocol};
#[tokio::main]
async fn main() -> Result<()> {
let client = IcmpClient::new(IcmpProtocol::Icmpv6)?;
let payload: [u8; 4] = [12u8, 14u8, 16u8, 32u8];
let result = client
.ping6(
Ipv6Addr::from_str("2606:4700:4700::1111")?,
0,
1,
&payload,
Duration::from_secs(10),
)
.await?;
println!("reply: {:?}", result);
Ok(())
}

View File

@ -0,0 +1,185 @@
use anyhow::{anyhow, Result};
use smoltcp::wire::{EthernetAddress, Ipv4Cidr, Ipv6Cidr};
use std::{collections::HashMap, str::FromStr};
use uuid::Uuid;
use xenstore::client::{XsdClient, XsdInterface, XsdTransaction};
pub struct AutoNetworkCollector {
client: XsdClient,
known: HashMap<Uuid, NetworkMetadata>,
}
#[derive(Debug, Clone)]
pub struct NetworkSide {
pub ipv4: Ipv4Cidr,
pub ipv6: Ipv6Cidr,
pub mac: EthernetAddress,
}
#[derive(Debug, Clone)]
pub struct NetworkMetadata {
pub domid: u32,
pub uuid: Uuid,
pub guest: NetworkSide,
pub gateway: NetworkSide,
}
impl NetworkMetadata {
pub fn interface(&self) -> String {
format!("vif{}.20", self.domid)
}
}
#[derive(Debug, Clone)]
pub struct AutoNetworkChangeset {
pub added: Vec<NetworkMetadata>,
pub removed: Vec<NetworkMetadata>,
}
impl AutoNetworkCollector {
pub async fn new() -> Result<AutoNetworkCollector> {
Ok(AutoNetworkCollector {
client: XsdClient::open().await?,
known: HashMap::new(),
})
}
pub async fn read(&mut self) -> Result<Vec<NetworkMetadata>> {
let mut networks = Vec::new();
let tx = self.client.transaction().await?;
for domid_string in tx.list("/local/domain").await? {
let Ok(domid) = domid_string.parse::<u32>() else {
continue;
};
let dom_path = format!("/local/domain/{}", domid_string);
let Some(uuid_string) = tx.read_string(&format!("{}/krata/uuid", dom_path)).await?
else {
continue;
};
let Ok(uuid) = uuid_string.parse::<Uuid>() else {
continue;
};
let Ok(guest) =
AutoNetworkCollector::read_network_side(uuid, &tx, &dom_path, "guest").await
else {
continue;
};
let Ok(gateway) =
AutoNetworkCollector::read_network_side(uuid, &tx, &dom_path, "gateway").await
else {
continue;
};
networks.push(NetworkMetadata {
domid,
uuid,
guest,
gateway,
});
}
tx.commit().await?;
Ok(networks)
}
async fn read_network_side(
uuid: Uuid,
tx: &XsdTransaction,
dom_path: &str,
side: &str,
) -> Result<NetworkSide> {
let side_path = format!("{}/krata/network/{}", dom_path, side);
let Some(ipv4) = tx.read_string(&format!("{}/ipv4", side_path)).await? else {
return Err(anyhow!(
"krata domain {} is missing {} ipv4 network entry",
uuid,
side
));
};
let Some(ipv6) = tx.read_string(&format!("{}/ipv6", side_path)).await? else {
return Err(anyhow!(
"krata domain {} is missing {} ipv6 network entry",
uuid,
side
));
};
let Some(mac) = tx.read_string(&format!("{}/mac", side_path)).await? else {
return Err(anyhow!(
"krata domain {} is missing {} mac address entry",
uuid,
side
));
};
let Ok(ipv4) = Ipv4Cidr::from_str(&ipv4) else {
return Err(anyhow!(
"krata domain {} has invalid {} ipv4 network cidr entry: {}",
uuid,
side,
ipv4
));
};
let Ok(ipv6) = Ipv6Cidr::from_str(&ipv6) else {
return Err(anyhow!(
"krata domain {} has invalid {} ipv6 network cidr entry: {}",
uuid,
side,
ipv6
));
};
let Ok(mac) = EthernetAddress::from_str(&mac) else {
return Err(anyhow!(
"krata domain {} has invalid {} mac address entry: {}",
uuid,
side,
mac
));
};
Ok(NetworkSide { ipv4, ipv6, mac })
}
pub async fn read_changes(&mut self) -> Result<AutoNetworkChangeset> {
let mut seen: Vec<Uuid> = Vec::new();
let mut added: Vec<NetworkMetadata> = Vec::new();
let mut removed: Vec<NetworkMetadata> = Vec::new();
for network in self.read().await? {
seen.push(network.uuid);
if self.known.contains_key(&network.uuid) {
continue;
}
let _ = self.known.insert(network.uuid, network.clone());
added.push(network);
}
let mut gone: Vec<Uuid> = Vec::new();
for uuid in self.known.keys() {
if seen.contains(uuid) {
continue;
}
gone.push(*uuid);
}
for uuid in &gone {
let Some(network) = self.known.remove(uuid) else {
continue;
};
removed.push(network);
}
Ok(AutoNetworkChangeset { added, removed })
}
pub fn mark_unknown(&mut self, uuid: Uuid) -> Result<bool> {
Ok(self.known.remove(&uuid).is_some())
}
}

View File

@ -0,0 +1,175 @@
use crate::autonet::NetworkMetadata;
use crate::chandev::ChannelDevice;
use crate::nat::Nat;
use crate::proxynat::ProxyNatHandlerFactory;
use crate::raw_socket::{AsyncRawSocketChannel, RawSocketHandle, RawSocketProtocol};
use crate::vbridge::{BridgeJoinHandle, VirtualBridge};
use crate::EXTRA_MTU;
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use futures::TryStreamExt;
use log::{info, trace, warn};
use smoltcp::iface::{Config, Interface, SocketSet};
use smoltcp::phy::Medium;
use smoltcp::time::Instant;
use smoltcp::wire::{HardwareAddress, IpCidr};
use tokio::select;
use tokio::sync::mpsc::{channel, Receiver};
use tokio::task::JoinHandle;
const TX_CHANNEL_BUFFER_LEN: usize = 3000;
#[derive(Clone)]
pub struct NetworkBackend {
metadata: NetworkMetadata,
bridge: VirtualBridge,
}
#[derive(Debug)]
enum NetworkStackSelect {
Receive(Option<BytesMut>),
Send(Option<BytesMut>),
}
struct NetworkStack<'a> {
tx: Receiver<BytesMut>,
kdev: AsyncRawSocketChannel,
udev: ChannelDevice,
interface: Interface,
sockets: SocketSet<'a>,
nat: Nat,
bridge: BridgeJoinHandle,
}
impl NetworkStack<'_> {
async fn poll(&mut self) -> Result<bool> {
let what = select! {
x = self.kdev.receiver.recv() => NetworkStackSelect::Receive(x),
x = self.bridge.from_bridge_receiver.recv() => NetworkStackSelect::Send(x),
x = self.bridge.from_broadcast_receiver.recv() => NetworkStackSelect::Send(x.ok()),
x = self.tx.recv() => NetworkStackSelect::Send(x),
};
match what {
NetworkStackSelect::Receive(Some(packet)) => {
if let Err(error) = self.bridge.to_bridge_sender.try_send(packet.clone()) {
trace!("failed to send guest packet to bridge: {}", error);
}
if let Err(error) = self.nat.receive_sender.try_send(packet.clone()) {
trace!("failed to send guest packet to nat: {}", error);
}
self.udev.rx = Some(packet);
self.interface
.poll(Instant::now(), &mut self.udev, &mut self.sockets);
}
NetworkStackSelect::Send(Some(packet)) => {
if let Err(error) = self.kdev.sender.try_send(packet) {
warn!("failed to transmit packet to interface: {}", error);
}
}
NetworkStackSelect::Receive(None) | NetworkStackSelect::Send(None) => {
return Ok(false);
}
}
Ok(true)
}
}
impl NetworkBackend {
pub fn new(metadata: NetworkMetadata, bridge: VirtualBridge) -> Result<Self> {
Ok(Self { metadata, bridge })
}
pub async fn init(&mut self) -> Result<()> {
let interface = self.metadata.interface();
let (connection, handle, _) = rtnetlink::new_connection()?;
tokio::spawn(connection);
let mut links = handle.link().get().match_name(interface.clone()).execute();
let link = links.try_next().await?;
if link.is_none() {
return Err(anyhow!(
"unable to find network interface named {}",
interface
));
}
let link = link.unwrap();
handle.link().set(link.header.index).up().execute().await?;
Ok(())
}
pub async fn run(&self) -> Result<()> {
let mut stack = self.create_network_stack().await?;
loop {
if !stack.poll().await? {
break;
}
}
Ok(())
}
async fn create_network_stack(&self) -> Result<NetworkStack> {
let interface = self.metadata.interface();
let proxy = Box::new(ProxyNatHandlerFactory::new());
let addresses: Vec<IpCidr> = vec![
self.metadata.gateway.ipv4.into(),
self.metadata.gateway.ipv6.into(),
];
let mut kdev =
RawSocketHandle::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?;
let mtu = kdev.mtu_of_interface(&interface)? + EXTRA_MTU;
let (tx_sender, tx_receiver) = channel::<BytesMut>(TX_CHANNEL_BUFFER_LEN);
let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone());
let mac = self.metadata.gateway.mac;
let nat = Nat::new(mtu, proxy, mac, addresses.clone(), tx_sender.clone())?;
let hardware_addr = HardwareAddress::Ethernet(mac);
let config = Config::new(hardware_addr);
let mut iface = Interface::new(config, &mut udev, Instant::now());
iface.update_ip_addrs(|addrs| {
addrs
.extend_from_slice(&addresses)
.expect("failed to set ip addresses");
});
let sockets = SocketSet::new(vec![]);
let handle = self.bridge.join(self.metadata.guest.mac).await?;
let kdev = AsyncRawSocketChannel::new(mtu, kdev)?;
Ok(NetworkStack {
tx: tx_receiver,
kdev,
udev,
interface: iface,
sockets,
nat,
bridge: handle,
})
}
pub async fn launch(self) -> Result<JoinHandle<()>> {
Ok(tokio::task::spawn(async move {
info!(
"lauched network backend for krata guest {}",
self.metadata.uuid
);
if let Err(error) = self.run().await {
warn!(
"network backend for krata guest {} failed: {}",
self.metadata.uuid, error
);
}
}))
}
}
impl Drop for NetworkBackend {
fn drop(&mut self) {
info!(
"destroyed network backend for krata guest {}",
self.metadata.uuid
);
}
}

View File

@ -0,0 +1,89 @@
// Referenced https://github.com/vi/wgslirpy/blob/master/crates/libwgslirpy/src/channelized_smoltcp_device.rs
use bytes::BytesMut;
use log::{debug, warn};
use smoltcp::phy::{Checksum, Device, Medium};
use tokio::sync::mpsc::Sender;
const TEAR_OFF_BUFFER_SIZE: usize = 65536;
pub struct ChannelDevice {
pub mtu: usize,
pub medium: Medium,
pub tx: Sender<BytesMut>,
pub rx: Option<BytesMut>,
tear_off_buffer: BytesMut,
}
impl ChannelDevice {
pub fn new(mtu: usize, medium: Medium, tx: Sender<BytesMut>) -> Self {
Self {
mtu,
medium,
tx,
rx: None,
tear_off_buffer: BytesMut::with_capacity(TEAR_OFF_BUFFER_SIZE),
}
}
}
pub struct RxToken(pub BytesMut);
impl Device for ChannelDevice {
type RxToken<'a> = RxToken where Self: 'a;
type TxToken<'a> = &'a mut ChannelDevice where Self: 'a;
fn receive(
&mut self,
_timestamp: smoltcp::time::Instant,
) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
self.rx.take().map(|x| (RxToken(x), self))
}
fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
if self.tx.capacity() == 0 {
debug!("ran out of transmission capacity");
return None;
}
Some(self)
}
fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities {
let mut capabilities = smoltcp::phy::DeviceCapabilities::default();
capabilities.medium = self.medium;
capabilities.max_transmission_unit = self.mtu;
capabilities.checksum = smoltcp::phy::ChecksumCapabilities::ignored();
capabilities.checksum.tcp = Checksum::Tx;
capabilities.checksum.ipv4 = Checksum::Tx;
capabilities.checksum.icmpv4 = Checksum::Tx;
capabilities.checksum.icmpv6 = Checksum::Tx;
capabilities
}
}
impl smoltcp::phy::RxToken for RxToken {
fn consume<R, F>(mut self, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
f(&mut self.0[..])
}
}
impl<'a> smoltcp::phy::TxToken for &'a mut ChannelDevice {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
self.tear_off_buffer.resize(len, 0);
let result = f(&mut self.tear_off_buffer[..]);
let chunk = self.tear_off_buffer.split();
if let Err(error) = self.tx.try_send(chunk) {
warn!("failed to transmit packet: {}", error);
}
if self.tear_off_buffer.capacity() < self.mtu {
self.tear_off_buffer = BytesMut::with_capacity(TEAR_OFF_BUFFER_SIZE);
}
result
}
}

View File

@ -0,0 +1,145 @@
use std::net::{IpAddr, Ipv4Addr};
use advmac::MacAddr6;
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use futures::TryStreamExt;
use log::error;
use smoltcp::wire::EthernetAddress;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
sync::mpsc::channel,
task::JoinHandle,
};
use tokio_tun::Tun;
use crate::vbridge::{BridgeJoinHandle, VirtualBridge};
const RX_BUFFER_QUEUE_LEN: usize = 100;
const HOST_IPV4_ADDR: Ipv4Addr = Ipv4Addr::new(10, 75, 0, 1);
pub struct HostBridge {
task: JoinHandle<()>,
}
impl HostBridge {
pub async fn new(mtu: usize, interface: String, bridge: &VirtualBridge) -> Result<HostBridge> {
let tun = Tun::builder()
.name(&interface)
.tap(true)
.mtu(mtu as i32)
.packet_info(false)
.try_build()?;
let (connection, handle, _) = rtnetlink::new_connection()?;
tokio::spawn(connection);
let mut mac = MacAddr6::random();
mac.set_local(true);
mac.set_multicast(false);
let mut links = handle.link().get().match_name(interface.clone()).execute();
let link = links.try_next().await?;
if link.is_none() {
return Err(anyhow!(
"unable to find network interface named {}",
interface
));
}
let link = link.unwrap();
handle
.address()
.add(link.header.index, IpAddr::V4(HOST_IPV4_ADDR), 16)
.execute()
.await?;
handle
.address()
.add(link.header.index, IpAddr::V6(mac.to_link_local_ipv6()), 10)
.execute()
.await?;
handle
.link()
.set(link.header.index)
.address(mac.to_array().to_vec())
.up()
.execute()
.await?;
let mac = EthernetAddress(mac.to_array());
let bridge_handle = bridge.join(mac).await?;
let task = tokio::task::spawn(async move {
if let Err(error) = HostBridge::process(mtu, tun, bridge_handle).await {
error!("failed to process host bridge: {}", error);
}
});
Ok(HostBridge { task })
}
async fn process(mtu: usize, tun: Tun, mut bridge_handle: BridgeJoinHandle) -> Result<()> {
let (rx_sender, mut rx_receiver) = channel::<BytesMut>(RX_BUFFER_QUEUE_LEN);
let (mut read, mut write) = tokio::io::split(tun);
tokio::task::spawn(async move {
let mut buffer = vec![0u8; mtu];
loop {
let size = match read.read(&mut buffer).await {
Ok(size) => size,
Err(error) => {
error!("failed to read tap device: {}", error);
break;
}
};
match rx_sender.send(buffer[0..size].into()).await {
Ok(_) => {}
Err(error) => {
error!(
"failed to send data from tap device to processor: {}",
error
);
break;
}
}
}
});
loop {
select! {
x = bridge_handle.from_bridge_receiver.recv() => match x {
Some(bytes) => {
write.write_all(&bytes).await?;
},
None => {
break;
}
},
x = bridge_handle.from_broadcast_receiver.recv() => match x {
Ok(bytes) => {
write.write_all(&bytes).await?;
},
Err(error) => {
return Err(error.into());
}
},
x = rx_receiver.recv() => match x {
Some(bytes) => {
bridge_handle.to_bridge_sender.send(bytes).await?;
},
None => {
break;
}
}
};
}
Ok(())
}
}
impl Drop for HostBridge {
fn drop(&mut self) {
self.task.abort();
}
}

250
crates/kratanet/src/icmp.rs Normal file
View File

@ -0,0 +1,250 @@
use crate::raw_socket::{RawSocketHandle, RawSocketProtocol};
use anyhow::{anyhow, Result};
use etherparse::{
IcmpEchoHeader, Icmpv4Header, Icmpv4Slice, Icmpv4Type, Icmpv6Header, Icmpv6Slice, Icmpv6Type,
IpNumber, NetSlice, SlicedPacket,
};
use log::warn;
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::fd::{FromRawFd, IntoRawFd},
sync::Arc,
time::Duration,
};
use tokio::{
net::UdpSocket,
sync::{oneshot, Mutex},
task::JoinHandle,
time::timeout,
};
#[derive(Debug)]
pub enum IcmpProtocol {
Icmpv4,
Icmpv6,
}
impl IcmpProtocol {
pub fn to_socket_protocol(&self) -> RawSocketProtocol {
match self {
IcmpProtocol::Icmpv4 => RawSocketProtocol::Icmpv4,
IcmpProtocol::Icmpv6 => RawSocketProtocol::Icmpv6,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct IcmpHandlerToken(IpAddr, Option<u16>, u16);
#[derive(Debug)]
pub enum IcmpReply {
Icmpv4 {
header: Icmpv4Header,
echo: IcmpEchoHeader,
payload: Vec<u8>,
},
Icmpv6 {
header: Icmpv6Header,
echo: IcmpEchoHeader,
payload: Vec<u8>,
},
}
type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>;
#[derive(Clone)]
pub struct IcmpClient {
socket: Arc<UdpSocket>,
handlers: IcmpHandlerMap,
task: Arc<JoinHandle<Result<()>>>,
}
impl IcmpClient {
pub fn new(protocol: IcmpProtocol) -> Result<IcmpClient> {
let handle = RawSocketHandle::new(protocol.to_socket_protocol())?;
let socket = unsafe { std::net::UdpSocket::from_raw_fd(handle.into_raw_fd()) };
let socket: Arc<UdpSocket> = Arc::new(socket.try_into()?);
let handlers = Arc::new(Mutex::new(HashMap::new()));
let task = Arc::new(tokio::task::spawn(IcmpClient::process(
protocol,
socket.clone(),
handlers.clone(),
)));
Ok(IcmpClient {
socket,
handlers,
task,
})
}
async fn process(
protocol: IcmpProtocol,
socket: Arc<UdpSocket>,
handlers: IcmpHandlerMap,
) -> Result<()> {
let mut buffer = vec![0u8; 2048];
loop {
let (size, addr) = socket.recv_from(&mut buffer).await?;
let packet = &buffer[0..size];
let (token, reply) = match protocol {
IcmpProtocol::Icmpv4 => {
let sliced = match SlicedPacket::from_ip(packet) {
Ok(sliced) => sliced,
Err(error) => {
warn!("received icmp packet but failed to parse it: {}", error);
continue;
}
};
let Some(NetSlice::Ipv4(ipv4)) = sliced.net else {
continue;
};
if ipv4.header().protocol() != IpNumber::ICMP {
continue;
}
let Ok(icmpv4) = Icmpv4Slice::from_slice(ipv4.payload().payload) else {
continue;
};
let Icmpv4Type::EchoReply(echo) = icmpv4.header().icmp_type else {
continue;
};
let token = IcmpHandlerToken(
IpAddr::V4(ipv4.header().source_addr()),
Some(echo.id),
echo.seq,
);
let reply = IcmpReply::Icmpv4 {
header: icmpv4.header(),
echo,
payload: icmpv4.payload().to_vec(),
};
(token, reply)
}
IcmpProtocol::Icmpv6 => {
let Ok(icmpv6) = Icmpv6Slice::from_slice(packet) else {
continue;
};
let Icmpv6Type::EchoReply(echo) = icmpv6.header().icmp_type else {
continue;
};
let SocketAddr::V6(addr) = addr else {
continue;
};
let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq);
let reply = IcmpReply::Icmpv6 {
header: icmpv6.header(),
echo,
payload: icmpv6.payload().to_vec(),
};
(token, reply)
}
};
if let Some(sender) = handlers.lock().await.remove(&token) {
let _ = sender.send(reply);
}
}
}
async fn add_handler(&self, token: IcmpHandlerToken) -> Result<oneshot::Receiver<IcmpReply>> {
let (tx, rx) = oneshot::channel();
if self
.handlers
.lock()
.await
.insert(token.clone(), tx)
.is_some()
{
return Err(anyhow!("duplicate icmp request: {:?}", token));
}
Ok(rx)
}
async fn remove_handler(&self, token: IcmpHandlerToken) -> Result<()> {
self.handlers.lock().await.remove(&token);
Ok(())
}
pub async fn ping4(
&self,
addr: Ipv4Addr,
id: u16,
seq: u16,
payload: &[u8],
deadline: Duration,
) -> Result<Option<IcmpReply>> {
let token = IcmpHandlerToken(IpAddr::V4(addr), Some(id), seq);
let rx = self.add_handler(token.clone()).await?;
let echo = IcmpEchoHeader { id, seq };
let mut header = Icmpv4Header::new(Icmpv4Type::EchoRequest(echo));
header.update_checksum(payload);
let mut buffer: Vec<u8> = Vec::new();
header.write(&mut buffer)?;
buffer.extend_from_slice(payload);
self.socket
.send_to(&buffer, SocketAddr::V4(SocketAddrV4::new(addr, 0)))
.await?;
let result = timeout(deadline, rx).await;
self.remove_handler(token).await?;
let reply = match result {
Ok(Ok(packet)) => Some(packet),
Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
Err(_) => None,
};
Ok(reply)
}
pub async fn ping6(
&self,
addr: Ipv6Addr,
id: u16,
seq: u16,
payload: &[u8],
deadline: Duration,
) -> Result<Option<IcmpReply>> {
let token = IcmpHandlerToken(IpAddr::V6(addr), Some(id), seq);
let rx = self.add_handler(token.clone()).await?;
let echo = IcmpEchoHeader { id, seq };
let header = Icmpv6Header::new(Icmpv6Type::EchoRequest(echo));
let mut buffer: Vec<u8> = Vec::new();
header.write(&mut buffer)?;
buffer.extend_from_slice(payload);
self.socket
.send_to(&buffer, SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)))
.await?;
let result = timeout(deadline, rx).await;
self.remove_handler(token).await?;
let reply = match result {
Ok(Ok(packet)) => Some(packet),
Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
Err(_) => None,
};
Ok(reply)
}
}
impl Drop for IcmpClient {
fn drop(&mut self) {
if Arc::strong_count(&self.task) <= 1 {
self.task.abort();
}
}
}

117
crates/kratanet/src/lib.rs Normal file
View File

@ -0,0 +1,117 @@
use std::{collections::HashMap, time::Duration};
use anyhow::Result;
use autonet::{AutoNetworkChangeset, AutoNetworkCollector, NetworkMetadata};
use futures::{future::join_all, TryFutureExt};
use hbridge::HostBridge;
use log::warn;
use tokio::{task::JoinHandle, time::sleep};
use uuid::Uuid;
use vbridge::VirtualBridge;
use crate::backend::NetworkBackend;
pub mod autonet;
pub mod backend;
pub mod chandev;
pub mod hbridge;
pub mod icmp;
pub mod nat;
pub mod pkt;
pub mod proxynat;
pub mod raw_socket;
pub mod vbridge;
const HOST_BRIDGE_MTU: usize = 1500;
pub const EXTRA_MTU: usize = 20;
pub struct NetworkService {
pub backends: HashMap<Uuid, JoinHandle<()>>,
pub bridge: VirtualBridge,
pub hbridge: HostBridge,
}
impl NetworkService {
pub async fn new() -> Result<NetworkService> {
let bridge = VirtualBridge::new()?;
let hbridge =
HostBridge::new(HOST_BRIDGE_MTU + EXTRA_MTU, "krata0".to_string(), &bridge).await?;
Ok(NetworkService {
backends: HashMap::new(),
bridge,
hbridge,
})
}
}
impl NetworkService {
pub async fn watch(&mut self) -> Result<()> {
let mut collector = AutoNetworkCollector::new().await?;
loop {
let changeset = collector.read_changes().await?;
self.process_network_changeset(&mut collector, changeset)
.await?;
sleep(Duration::from_secs(2)).await;
}
}
async fn process_network_changeset(
&mut self,
collector: &mut AutoNetworkCollector,
changeset: AutoNetworkChangeset,
) -> Result<()> {
for removal in &changeset.removed {
if let Some(handle) = self.backends.remove(&removal.uuid) {
handle.abort();
}
}
let futures = changeset
.added
.iter()
.map(|metadata| {
self.add_network_backend(metadata)
.map_err(|x| (metadata.clone(), x))
})
.collect::<Vec<_>>();
sleep(Duration::from_secs(1)).await;
let mut failed: Vec<Uuid> = Vec::new();
let mut launched: Vec<(Uuid, JoinHandle<()>)> = Vec::new();
let results = join_all(futures).await;
for result in results {
match result {
Ok(launch) => {
launched.push(launch);
}
Err((metadata, error)) => {
warn!(
"failed to launch network backend for krata guest {}: {}",
metadata.uuid, error
);
failed.push(metadata.uuid);
}
};
}
for (uuid, handle) in launched {
self.backends.insert(uuid, handle);
}
for uuid in failed {
collector.mark_unknown(uuid)?;
}
Ok(())
}
async fn add_network_backend(
&self,
metadata: &NetworkMetadata,
) -> Result<(Uuid, JoinHandle<()>)> {
let mut network = NetworkBackend::new(metadata.clone(), self.bridge.clone())?;
network.init().await?;
Ok((metadata.uuid, network.launch().await?))
}
}

View File

@ -0,0 +1,36 @@
use anyhow::Result;
use async_trait::async_trait;
use bytes::BytesMut;
use tokio::sync::mpsc::Sender;
use super::key::NatKey;
#[derive(Debug, Clone)]
pub struct NatHandlerContext {
pub mtu: usize,
pub key: NatKey,
pub transmit_sender: Sender<BytesMut>,
pub reclaim_sender: Sender<NatKey>,
}
impl NatHandlerContext {
pub fn try_transmit(&self, buffer: BytesMut) -> Result<()> {
self.transmit_sender.try_send(buffer)?;
Ok(())
}
pub async fn reclaim(&self) -> Result<()> {
self.reclaim_sender.try_send(self.key)?;
Ok(())
}
}
#[async_trait]
pub trait NatHandler: Send {
async fn receive(&self, packet: &[u8]) -> Result<bool>;
}
#[async_trait]
pub trait NatHandlerFactory: Send {
async fn nat(&self, context: NatHandlerContext) -> Option<Box<dyn NatHandler>>;
}

View File

@ -0,0 +1,29 @@
use std::fmt::Display;
use smoltcp::wire::{EthernetAddress, IpEndpoint};
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub enum NatKeyProtocol {
Tcp,
Udp,
Icmp,
}
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct NatKey {
pub protocol: NatKeyProtocol,
pub client_mac: EthernetAddress,
pub local_mac: EthernetAddress,
pub client_ip: IpEndpoint,
pub external_ip: IpEndpoint,
}
impl Display for NatKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{} -> {} {:?} {} -> {}",
self.client_mac, self.local_mac, self.protocol, self.client_ip, self.external_ip
)
}
}

View File

@ -0,0 +1,42 @@
use anyhow::Result;
use tokio::sync::mpsc::Sender;
use self::handler::NatHandlerFactory;
use self::processor::NatProcessor;
use bytes::BytesMut;
use smoltcp::wire::EthernetAddress;
use smoltcp::wire::IpCidr;
use tokio::task::JoinHandle;
pub mod handler;
pub mod key;
pub mod processor;
pub mod table;
pub struct Nat {
pub receive_sender: Sender<BytesMut>,
task: JoinHandle<()>,
}
impl Nat {
pub fn new(
mtu: usize,
factory: Box<dyn NatHandlerFactory>,
local_mac: EthernetAddress,
local_cidrs: Vec<IpCidr>,
transmit_sender: Sender<BytesMut>,
) -> Result<Self> {
let (receive_sender, task) =
NatProcessor::launch(mtu, factory, local_mac, local_cidrs, transmit_sender)?;
Ok(Self {
receive_sender,
task,
})
}
}
impl Drop for Nat {
fn drop(&mut self) {
self.task.abort();
}
}

View File

@ -0,0 +1,330 @@
use crate::pkt::RecvPacket;
use crate::pkt::RecvPacketIp;
use anyhow::Result;
use bytes::BytesMut;
use etherparse::Icmpv4Header;
use etherparse::Icmpv4Type;
use etherparse::Icmpv6Header;
use etherparse::Icmpv6Type;
use etherparse::IpNumber;
use etherparse::IpPayloadSlice;
use etherparse::Ipv4Slice;
use etherparse::Ipv6Slice;
use etherparse::SlicedPacket;
use etherparse::TcpHeaderSlice;
use etherparse::UdpHeaderSlice;
use log::warn;
use log::{debug, trace};
use smoltcp::wire::EthernetAddress;
use smoltcp::wire::IpAddress;
use smoltcp::wire::IpCidr;
use smoltcp::wire::IpEndpoint;
use std::collections::hash_map::Entry;
use tokio::select;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::task::JoinHandle;
use super::handler::NatHandler;
use super::handler::NatHandlerContext;
use super::handler::NatHandlerFactory;
use super::key::NatKey;
use super::key::NatKeyProtocol;
use super::table::NatTable;
const RECEIVE_CHANNEL_QUEUE_LEN: usize = 3000;
const RECLAIM_CHANNEL_QUEUE_LEN: usize = 30;
pub struct NatProcessor {
mtu: usize,
local_mac: EthernetAddress,
local_cidrs: Vec<IpCidr>,
table: NatTable,
factory: Box<dyn NatHandlerFactory>,
transmit_sender: Sender<BytesMut>,
reclaim_sender: Sender<NatKey>,
reclaim_receiver: Receiver<NatKey>,
receive_receiver: Receiver<BytesMut>,
}
enum NatProcessorSelect {
Reclaim(Option<NatKey>),
ReceivedPacket(Option<BytesMut>),
}
impl NatProcessor {
pub fn launch(
mtu: usize,
factory: Box<dyn NatHandlerFactory>,
local_mac: EthernetAddress,
local_cidrs: Vec<IpCidr>,
transmit_sender: Sender<BytesMut>,
) -> Result<(Sender<BytesMut>, JoinHandle<()>)> {
let (reclaim_sender, reclaim_receiver) = channel(RECLAIM_CHANNEL_QUEUE_LEN);
let (receive_sender, receive_receiver) = channel(RECEIVE_CHANNEL_QUEUE_LEN);
let mut processor = Self {
mtu,
local_mac,
local_cidrs,
factory,
table: NatTable::new(),
transmit_sender,
reclaim_sender,
receive_receiver,
reclaim_receiver,
};
let handle = tokio::task::spawn(async move {
if let Err(error) = processor.process().await {
warn!("nat processing failed: {}", error);
}
});
Ok((receive_sender, handle))
}
pub async fn process(&mut self) -> Result<()> {
loop {
let selection = select! {
x = self.reclaim_receiver.recv() => NatProcessorSelect::Reclaim(x),
x = self.receive_receiver.recv() => NatProcessorSelect::ReceivedPacket(x),
};
match selection {
NatProcessorSelect::Reclaim(Some(key)) => {
if self.table.inner.remove(&key).is_some() {
debug!("reclaimed nat key: {}", key);
}
}
NatProcessorSelect::ReceivedPacket(Some(packet)) => {
if let Ok(slice) = SlicedPacket::from_ethernet(&packet) {
let Ok(packet) = RecvPacket::new(&packet, &slice) else {
continue;
};
self.process_packet(&packet).await?;
}
}
NatProcessorSelect::ReceivedPacket(None) | NatProcessorSelect::Reclaim(None) => {
break
}
}
}
Ok(())
}
pub async fn process_reclaim(&mut self) -> Result<Option<NatKey>> {
Ok(if let Some(key) = self.reclaim_receiver.recv().await {
if self.table.inner.remove(&key).is_some() {
debug!("reclaimed nat key: {}", key);
Some(key)
} else {
None
}
} else {
None
})
}
pub async fn process_packet<'a>(&mut self, packet: &RecvPacket<'a>) -> Result<()> {
let Some(ether) = packet.ether else {
return Ok(());
};
let mac = EthernetAddress(ether.destination());
if mac != self.local_mac {
trace!(
"received packet with destination {} which is not the local mac {}",
mac,
self.local_mac
);
return Ok(());
}
let key = match packet.ip {
Some(RecvPacketIp::Ipv4(ipv4)) => self.extract_key_ipv4(packet, ipv4)?,
Some(RecvPacketIp::Ipv6(ipv6)) => self.extract_key_ipv6(packet, ipv6)?,
_ => None,
};
let Some(key) = key else {
return Ok(());
};
for cidr in &self.local_cidrs {
if cidr.contains_addr(&key.external_ip.addr) {
return Ok(());
}
}
let context = NatHandlerContext {
mtu: self.mtu,
key,
transmit_sender: self.transmit_sender.clone(),
reclaim_sender: self.reclaim_sender.clone(),
};
let handler: Option<&mut Box<dyn NatHandler>> = match self.table.inner.entry(key) {
Entry::Occupied(entry) => Some(entry.into_mut()),
Entry::Vacant(entry) => {
if let Some(handler) = self.factory.nat(context).await {
debug!("creating nat entry for key: {}", key);
Some(entry.insert(handler))
} else {
None
}
}
};
if let Some(handler) = handler {
if !handler.receive(packet.raw).await? {
self.reclaim_sender.try_send(key)?;
}
}
Ok(())
}
pub fn extract_key_ipv4<'a>(
&mut self,
packet: &RecvPacket<'a>,
ipv4: &Ipv4Slice<'a>,
) -> Result<Option<NatKey>> {
let source_addr = IpAddress::Ipv4(ipv4.header().source_addr().into());
let dest_addr = IpAddress::Ipv4(ipv4.header().destination_addr().into());
Ok(match ipv4.header().protocol() {
IpNumber::TCP => {
self.extract_key_tcp(packet, source_addr, dest_addr, ipv4.payload())?
}
IpNumber::UDP => {
self.extract_key_udp(packet, source_addr, dest_addr, ipv4.payload())?
}
IpNumber::ICMP => {
self.extract_key_icmpv4(packet, source_addr, dest_addr, ipv4.payload())?
}
_ => None,
})
}
pub fn extract_key_ipv6<'a>(
&mut self,
packet: &RecvPacket<'a>,
ipv6: &Ipv6Slice<'a>,
) -> Result<Option<NatKey>> {
let source_addr = IpAddress::Ipv6(ipv6.header().source_addr().into());
let dest_addr = IpAddress::Ipv6(ipv6.header().destination_addr().into());
Ok(match ipv6.header().next_header() {
IpNumber::TCP => {
self.extract_key_tcp(packet, source_addr, dest_addr, ipv6.payload())?
}
IpNumber::UDP => {
self.extract_key_udp(packet, source_addr, dest_addr, ipv6.payload())?
}
IpNumber::IPV6_ICMP => {
self.extract_key_icmpv6(packet, source_addr, dest_addr, ipv6.payload())?
}
_ => None,
})
}
pub fn extract_key_udp<'a>(
&mut self,
packet: &RecvPacket<'a>,
source_addr: IpAddress,
dest_addr: IpAddress,
payload: &IpPayloadSlice<'a>,
) -> Result<Option<NatKey>> {
let Some(ether) = packet.ether else {
return Ok(None);
};
let header = UdpHeaderSlice::from_slice(payload.payload)?;
let source = IpEndpoint::new(source_addr, header.source_port());
let dest = IpEndpoint::new(dest_addr, header.destination_port());
Ok(Some(NatKey {
protocol: NatKeyProtocol::Udp,
client_mac: EthernetAddress(ether.source()),
local_mac: EthernetAddress(ether.destination()),
client_ip: source,
external_ip: dest,
}))
}
pub fn extract_key_icmpv4<'a>(
&mut self,
packet: &RecvPacket<'a>,
source_addr: IpAddress,
dest_addr: IpAddress,
payload: &IpPayloadSlice<'a>,
) -> Result<Option<NatKey>> {
let Some(ether) = packet.ether else {
return Ok(None);
};
let (header, _) = Icmpv4Header::from_slice(payload.payload)?;
let Icmpv4Type::EchoRequest(_) = header.icmp_type else {
return Ok(None);
};
let source = IpEndpoint::new(source_addr, 0);
let dest = IpEndpoint::new(dest_addr, 0);
Ok(Some(NatKey {
protocol: NatKeyProtocol::Icmp,
client_mac: EthernetAddress(ether.source()),
local_mac: EthernetAddress(ether.destination()),
client_ip: source,
external_ip: dest,
}))
}
pub fn extract_key_icmpv6<'a>(
&mut self,
packet: &RecvPacket<'a>,
source_addr: IpAddress,
dest_addr: IpAddress,
payload: &IpPayloadSlice<'a>,
) -> Result<Option<NatKey>> {
let Some(ether) = packet.ether else {
return Ok(None);
};
let (header, _) = Icmpv6Header::from_slice(payload.payload)?;
let Icmpv6Type::EchoRequest(_) = header.icmp_type else {
return Ok(None);
};
let source = IpEndpoint::new(source_addr, 0);
let dest = IpEndpoint::new(dest_addr, 0);
Ok(Some(NatKey {
protocol: NatKeyProtocol::Icmp,
client_mac: EthernetAddress(ether.source()),
local_mac: EthernetAddress(ether.destination()),
client_ip: source,
external_ip: dest,
}))
}
pub fn extract_key_tcp<'a>(
&mut self,
packet: &RecvPacket<'a>,
source_addr: IpAddress,
dest_addr: IpAddress,
payload: &IpPayloadSlice<'a>,
) -> Result<Option<NatKey>> {
let Some(ether) = packet.ether else {
return Ok(None);
};
let header = TcpHeaderSlice::from_slice(payload.payload)?;
let source = IpEndpoint::new(source_addr, header.source_port());
let dest = IpEndpoint::new(dest_addr, header.destination_port());
Ok(Some(NatKey {
protocol: NatKeyProtocol::Tcp,
client_mac: EthernetAddress(ether.source()),
local_mac: EthernetAddress(ether.destination()),
client_ip: source,
external_ip: dest,
}))
}
}

View File

@ -0,0 +1,21 @@
use std::collections::HashMap;
use super::{handler::NatHandler, key::NatKey};
pub struct NatTable {
pub inner: HashMap<NatKey, Box<dyn NatHandler>>,
}
impl Default for NatTable {
fn default() -> Self {
Self::new()
}
}
impl NatTable {
pub fn new() -> Self {
Self {
inner: HashMap::new(),
}
}
}

View File

@ -0,0 +1,37 @@
use anyhow::Result;
use etherparse::{Ethernet2Slice, Ipv4Slice, Ipv6Slice, LinkSlice, NetSlice, SlicedPacket};
pub enum RecvPacketIp<'a> {
Ipv4(&'a Ipv4Slice<'a>),
Ipv6(&'a Ipv6Slice<'a>),
}
pub struct RecvPacket<'a> {
pub raw: &'a [u8],
pub slice: &'a SlicedPacket<'a>,
pub ether: Option<&'a Ethernet2Slice<'a>>,
pub ip: Option<RecvPacketIp<'a>>,
}
impl RecvPacket<'_> {
pub fn new<'a>(raw: &'a [u8], slice: &'a SlicedPacket<'a>) -> Result<RecvPacket<'a>> {
let ether = match slice.link {
Some(LinkSlice::Ethernet2(ref ether)) => Some(ether),
_ => None,
};
let ip = match slice.net {
Some(NetSlice::Ipv4(ref ipv4)) => Some(RecvPacketIp::Ipv4(ipv4)),
Some(NetSlice::Ipv6(ref ipv6)) => Some(RecvPacketIp::Ipv6(ipv6)),
_ => None,
};
let packet = RecvPacket {
raw,
slice,
ether,
ip,
};
Ok(packet)
}
}

View File

@ -0,0 +1,276 @@
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
time::Duration,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use etherparse::{
IcmpEchoHeader, Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type, IpNumber, Ipv4Slice,
Ipv6Slice, NetSlice, PacketBuilder, SlicedPacket,
};
use log::{debug, trace, warn};
use smoltcp::wire::IpAddress;
use tokio::{
select,
sync::mpsc::{Receiver, Sender},
};
use crate::{
icmp::{IcmpClient, IcmpProtocol, IcmpReply},
nat::handler::{NatHandler, NatHandlerContext},
};
const ICMP_PING_TIMEOUT_SECS: u64 = 20;
const ICMP_TIMEOUT_SECS: u64 = 30;
pub struct ProxyIcmpHandler {
rx_sender: Sender<BytesMut>,
}
#[async_trait]
impl NatHandler for ProxyIcmpHandler {
async fn receive(&self, data: &[u8]) -> Result<bool> {
if self.rx_sender.is_closed() {
Ok(true)
} else {
self.rx_sender.try_send(data.into())?;
Ok(true)
}
}
}
enum ProxyIcmpSelect {
Internal(BytesMut),
Close,
}
impl ProxyIcmpHandler {
pub fn new(rx_sender: Sender<BytesMut>) -> Self {
ProxyIcmpHandler { rx_sender }
}
pub async fn spawn(
&mut self,
context: NatHandlerContext,
rx_receiver: Receiver<BytesMut>,
) -> Result<()> {
let client = IcmpClient::new(match context.key.external_ip.addr {
IpAddress::Ipv4(_) => IcmpProtocol::Icmpv4,
IpAddress::Ipv6(_) => IcmpProtocol::Icmpv6,
})?;
tokio::spawn(async move {
if let Err(error) = ProxyIcmpHandler::process(client, rx_receiver, context).await {
warn!("processing of icmp proxy failed: {}", error);
}
});
Ok(())
}
async fn process(
client: IcmpClient,
mut rx_receiver: Receiver<BytesMut>,
context: NatHandlerContext,
) -> Result<()> {
loop {
let deadline = tokio::time::sleep(Duration::from_secs(ICMP_TIMEOUT_SECS));
let selection = select! {
x = rx_receiver.recv() => if let Some(data) = x {
ProxyIcmpSelect::Internal(data)
} else {
ProxyIcmpSelect::Close
},
_ = deadline => ProxyIcmpSelect::Close,
};
match selection {
ProxyIcmpSelect::Internal(data) => {
let packet = SlicedPacket::from_ethernet(&data)?;
let Some(ref net) = packet.net else {
continue;
};
match net {
NetSlice::Ipv4(ipv4) => {
ProxyIcmpHandler::process_ipv4(&context, ipv4, &client).await?
}
NetSlice::Ipv6(ipv6) => {
ProxyIcmpHandler::process_ipv6(&context, ipv6, &client).await?
}
}
}
ProxyIcmpSelect::Close => {
break;
}
}
}
context.reclaim().await?;
Ok(())
}
async fn process_ipv4(
context: &NatHandlerContext,
ipv4: &Ipv4Slice<'_>,
client: &IcmpClient,
) -> Result<()> {
if ipv4.header().protocol() != IpNumber::ICMP {
return Ok(());
}
let (header, payload) = Icmpv4Header::from_slice(ipv4.payload().payload)?;
if let Icmpv4Type::EchoRequest(echo) = header.icmp_type {
let IpAddr::V4(external_ipv4) = context.key.external_ip.addr.into() else {
return Ok(());
};
let context = context.clone();
let client = client.clone();
let payload = payload.to_vec();
tokio::task::spawn(async move {
if let Err(error) = ProxyIcmpHandler::process_echo_ipv4(
context,
client,
external_ipv4,
echo,
payload,
)
.await
{
trace!("icmp4 echo failed: {}", error);
}
});
}
Ok(())
}
async fn process_ipv6(
context: &NatHandlerContext,
ipv6: &Ipv6Slice<'_>,
client: &IcmpClient,
) -> Result<()> {
if ipv6.header().next_header() != IpNumber::IPV6_ICMP {
return Ok(());
}
let (header, payload) = Icmpv6Header::from_slice(ipv6.payload().payload)?;
if let Icmpv6Type::EchoRequest(echo) = header.icmp_type {
let IpAddr::V6(external_ipv6) = context.key.external_ip.addr.into() else {
return Ok(());
};
let context = context.clone();
let client = client.clone();
let payload = payload.to_vec();
tokio::task::spawn(async move {
if let Err(error) = ProxyIcmpHandler::process_echo_ipv6(
context,
client,
external_ipv6,
echo,
payload,
)
.await
{
trace!("icmp6 echo failed: {}", error);
}
});
}
Ok(())
}
async fn process_echo_ipv4(
context: NatHandlerContext,
client: IcmpClient,
external_ipv4: Ipv4Addr,
echo: IcmpEchoHeader,
payload: Vec<u8>,
) -> Result<()> {
let reply = client
.ping4(
external_ipv4,
echo.id,
echo.seq,
&payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await?;
let Some(IcmpReply::Icmpv4 {
header: _,
echo,
payload,
}) = reply
else {
return Ok(());
};
let packet = PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
(IpAddress::Ipv4(external_addr), IpAddress::Ipv4(client_addr)) => {
packet.ipv4(external_addr.0, client_addr.0, 20)
}
_ => {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet = packet.icmpv4_echo_reply(echo.id, echo.seq);
let buffer = BytesMut::with_capacity(packet.size(payload.len()));
let mut writer = buffer.writer();
packet.write(&mut writer, &payload)?;
let buffer = writer.into_inner();
if let Err(error) = context.try_transmit(buffer) {
debug!("failed to transmit icmp packet: {}", error);
}
Ok(())
}
async fn process_echo_ipv6(
context: NatHandlerContext,
client: IcmpClient,
external_ipv6: Ipv6Addr,
echo: IcmpEchoHeader,
payload: Vec<u8>,
) -> Result<()> {
let reply = client
.ping6(
external_ipv6,
echo.id,
echo.seq,
&payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await?;
let Some(IcmpReply::Icmpv6 {
header: _,
echo,
payload,
}) = reply
else {
return Ok(());
};
let packet = PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
(IpAddress::Ipv6(external_addr), IpAddress::Ipv6(client_addr)) => {
packet.ipv6(external_addr.0, client_addr.0, 20)
}
_ => {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet = packet.icmpv6_echo_reply(echo.id, echo.seq);
let buffer = BytesMut::with_capacity(packet.size(payload.len()));
let mut writer = buffer.writer();
packet.write(&mut writer, &payload)?;
let buffer = writer.into_inner();
if let Err(error) = context.try_transmit(buffer) {
debug!("failed to transmit icmp packet: {}", error);
}
Ok(())
}
}

View File

@ -0,0 +1,77 @@
use async_trait::async_trait;
use bytes::BytesMut;
use log::warn;
use tokio::sync::mpsc::channel;
use crate::proxynat::udp::ProxyUdpHandler;
use crate::nat::handler::{NatHandler, NatHandlerContext, NatHandlerFactory};
use crate::nat::key::NatKeyProtocol;
use self::icmp::ProxyIcmpHandler;
use self::tcp::ProxyTcpHandler;
mod icmp;
mod tcp;
mod udp;
const RX_CHANNEL_QUEUE_LEN: usize = 1000;
pub struct ProxyNatHandlerFactory {}
impl Default for ProxyNatHandlerFactory {
fn default() -> Self {
Self::new()
}
}
impl ProxyNatHandlerFactory {
pub fn new() -> Self {
Self {}
}
}
#[async_trait]
impl NatHandlerFactory for ProxyNatHandlerFactory {
async fn nat(&self, context: NatHandlerContext) -> Option<Box<dyn NatHandler>> {
match context.key.protocol {
NatKeyProtocol::Udp => {
let (rx_sender, rx_receiver) = channel::<BytesMut>(RX_CHANNEL_QUEUE_LEN);
let mut handler = ProxyUdpHandler::new(rx_sender);
if let Err(error) = handler.spawn(context, rx_receiver).await {
warn!("unable to spawn udp proxy handler: {}", error);
None
} else {
Some(Box::new(handler))
}
}
NatKeyProtocol::Icmp => {
let (rx_sender, rx_receiver) = channel::<BytesMut>(RX_CHANNEL_QUEUE_LEN);
let mut handler = ProxyIcmpHandler::new(rx_sender);
if let Err(error) = handler.spawn(context, rx_receiver).await {
warn!("unable to spawn icmp proxy handler: {}", error);
None
} else {
Some(Box::new(handler))
}
}
NatKeyProtocol::Tcp => {
let (rx_sender, rx_receiver) = channel::<BytesMut>(RX_CHANNEL_QUEUE_LEN);
let mut handler = ProxyTcpHandler::new(rx_sender);
if let Err(error) = handler.spawn(context, rx_receiver).await {
warn!("unable to spawn tcp proxy handler: {}", error);
None
} else {
Some(Box::new(handler))
}
}
}
}
}

View File

@ -0,0 +1,466 @@
use std::{
net::{IpAddr, SocketAddr},
time::Duration,
};
use anyhow::Result;
use async_trait::async_trait;
use bytes::BytesMut;
use etherparse::{EtherType, Ethernet2Header};
use log::{debug, warn};
use smoltcp::{
iface::{Config, Interface, SocketSet, SocketStorage},
phy::Medium,
socket::tcp::{self, SocketBuffer, State},
time::Instant,
wire::{HardwareAddress, IpAddress, IpCidr},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
select,
sync::mpsc::channel,
};
use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender};
use crate::{
chandev::ChannelDevice,
nat::handler::{NatHandler, NatHandlerContext},
};
const TCP_BUFFER_SIZE: usize = 65535;
const TCP_IP_BUFFER_QUEUE_LEN: usize = 3000;
const TCP_ACCEPT_TIMEOUT_SECS: u64 = 120;
const TCP_DANGLE_TIMEOUT_SECS: u64 = 10;
pub struct ProxyTcpHandler {
rx_sender: Sender<BytesMut>,
}
#[async_trait]
impl NatHandler for ProxyTcpHandler {
async fn receive(&self, data: &[u8]) -> Result<bool> {
if self.rx_sender.is_closed() {
Ok(false)
} else {
self.rx_sender.try_send(data.into())?;
Ok(true)
}
}
}
#[derive(Debug)]
enum ProxyTcpAcceptSelect {
Internal(BytesMut),
TxIpPacket(BytesMut),
TimePassed,
DoNothing,
Close,
}
#[derive(Debug)]
enum ProxyTcpDataSelect {
ExternalRecv(usize),
ExternalSent(usize),
InternalRecv(BytesMut),
TxIpPacket(BytesMut),
TimePassed,
DoNothing,
Close,
}
#[derive(Debug)]
enum ProxyTcpFinishSelect {
InternalRecv(BytesMut),
TxIpPacket(BytesMut),
Close,
}
impl ProxyTcpHandler {
pub fn new(rx_sender: Sender<BytesMut>) -> Self {
ProxyTcpHandler { rx_sender }
}
pub async fn spawn(
&mut self,
context: NatHandlerContext,
rx_receiver: Receiver<BytesMut>,
) -> Result<()> {
let external_addr = match context.key.external_ip.addr {
IpAddress::Ipv4(addr) => {
SocketAddr::new(IpAddr::V4(addr.0.into()), context.key.external_ip.port)
}
IpAddress::Ipv6(addr) => {
SocketAddr::new(IpAddr::V6(addr.0.into()), context.key.external_ip.port)
}
};
let socket = TcpStream::connect(external_addr).await?;
tokio::spawn(async move {
if let Err(error) = ProxyTcpHandler::process(context, socket, rx_receiver).await {
warn!("processing of tcp proxy failed: {}", error);
}
});
Ok(())
}
async fn process(
context: NatHandlerContext,
mut external_socket: TcpStream,
mut rx_receiver: Receiver<BytesMut>,
) -> Result<()> {
let (ip_sender, mut ip_receiver) = channel::<BytesMut>(TCP_IP_BUFFER_QUEUE_LEN);
let mut external_buffer = vec![0u8; TCP_BUFFER_SIZE];
let mut device = ChannelDevice::new(
context.mtu - Ethernet2Header::LEN,
Medium::Ip,
ip_sender.clone(),
);
let config = Config::new(HardwareAddress::Ip);
let tcp_rx_buffer = SocketBuffer::new(vec![0; TCP_BUFFER_SIZE]);
let tcp_tx_buffer = SocketBuffer::new(vec![0; TCP_BUFFER_SIZE]);
let internal_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
let mut iface = Interface::new(config, &mut device, Instant::now());
iface.update_ip_addrs(|addrs| {
let _ = addrs.push(IpCidr::new(context.key.external_ip.addr, 0));
});
let mut sockets = SocketSet::new([SocketStorage::EMPTY]);
let internal_socket_handle = sockets.add(internal_socket);
let (mut external_r, mut external_w) = external_socket.split();
{
let socket = sockets.get_mut::<tcp::Socket>(internal_socket_handle);
socket.connect(
iface.context(),
context.key.client_ip,
context.key.external_ip,
)?;
}
iface.poll(Instant::now(), &mut device, &mut sockets);
let mut sleeper: Option<tokio::time::Sleep> = None;
loop {
let socket = sockets.get_mut::<tcp::Socket>(internal_socket_handle);
if socket.is_active() && socket.state() != State::SynSent {
break;
}
if socket.state() == State::Closed {
break;
}
let deadline = tokio::time::sleep(Duration::from_secs(TCP_ACCEPT_TIMEOUT_SECS));
let selection = if let Some(sleep) = sleeper.take() {
select! {
biased;
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpAcceptSelect::Internal(data)
} else {
ProxyTcpAcceptSelect::Close
},
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpAcceptSelect::TxIpPacket(data)
} else {
ProxyTcpAcceptSelect::Close
},
_ = sleep => ProxyTcpAcceptSelect::TimePassed,
_ = deadline => ProxyTcpAcceptSelect::Close,
}
} else {
select! {
biased;
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpAcceptSelect::Internal(data)
} else {
ProxyTcpAcceptSelect::Close
},
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpAcceptSelect::TxIpPacket(data)
} else {
ProxyTcpAcceptSelect::Close
},
_ = std::future::ready(()) => ProxyTcpAcceptSelect::DoNothing,
_ = deadline => ProxyTcpAcceptSelect::Close,
}
};
match selection {
ProxyTcpAcceptSelect::TimePassed => {
iface.poll(Instant::now(), &mut device, &mut sockets);
}
ProxyTcpAcceptSelect::DoNothing => {
sleeper = Some(tokio::time::sleep(Duration::from_micros(100)));
}
ProxyTcpAcceptSelect::Internal(data) => {
let (_, payload) = Ethernet2Header::from_slice(&data)?;
device.rx = Some(payload.into());
iface.poll(Instant::now(), &mut device, &mut sockets);
}
ProxyTcpAcceptSelect::TxIpPacket(payload) => {
let mut buffer = BytesMut::with_capacity(Ethernet2Header::LEN + payload.len());
let header = Ethernet2Header {
source: context.key.local_mac.0,
destination: context.key.client_mac.0,
ether_type: match context.key.external_ip.addr {
IpAddress::Ipv4(_) => EtherType::IPV4,
IpAddress::Ipv6(_) => EtherType::IPV6,
},
};
buffer.extend_from_slice(&header.to_bytes());
buffer.extend_from_slice(&payload);
if let Err(error) = context.try_transmit(buffer) {
debug!("failed to transmit tcp packet: {}", error);
}
}
ProxyTcpAcceptSelect::Close => {
break;
}
}
}
let accepted = if sockets
.get_mut::<tcp::Socket>(internal_socket_handle)
.is_active()
{
true
} else {
debug!("failed to accept tcp connection from client");
false
};
let mut already_shutdown = false;
let mut sleeper: Option<tokio::time::Sleep> = None;
loop {
if !accepted {
break;
}
let socket = sockets.get_mut::<tcp::Socket>(internal_socket_handle);
match socket.state() {
State::Closed
| State::Listen
| State::Closing
| State::LastAck
| State::TimeWait => {
break;
}
State::FinWait1
| State::SynSent
| State::CloseWait
| State::FinWait2
| State::SynReceived
| State::Established => {}
}
let bytes_to_client = if socket.can_send() {
socket.send_capacity() - socket.send_queue()
} else {
0
};
let (bytes_to_external, do_shutdown) = if socket.may_recv() {
if let Ok(data) = socket.peek(TCP_BUFFER_SIZE) {
if data.is_empty() {
(None, false)
} else {
(Some(data), false)
}
} else {
(None, false)
}
} else if !already_shutdown && matches!(socket.state(), State::CloseWait) {
(None, true)
} else {
(None, false)
};
let selection = if let Some(sleep) = sleeper.take() {
if !do_shutdown {
select! {
biased;
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::TxIpPacket(data)
} else {
ProxyTcpDataSelect::Close
},
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::InternalRecv(data)
} else {
ProxyTcpDataSelect::Close
},
x = external_w.write(bytes_to_external.unwrap_or(b"")), if bytes_to_external.is_some() => ProxyTcpDataSelect::ExternalSent(x?),
x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
_ = sleep => ProxyTcpDataSelect::TimePassed,
}
} else {
select! {
biased;
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::TxIpPacket(data)
} else {
ProxyTcpDataSelect::Close
},
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::InternalRecv(data)
} else {
ProxyTcpDataSelect::Close
},
_ = external_w.shutdown() => ProxyTcpDataSelect::ExternalSent(0),
x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
_ = sleep => ProxyTcpDataSelect::TimePassed,
}
}
} else if !do_shutdown {
select! {
biased;
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::TxIpPacket(data)
} else {
ProxyTcpDataSelect::Close
},
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::InternalRecv(data)
} else {
ProxyTcpDataSelect::Close
},
x = external_w.write(bytes_to_external.unwrap_or(b"")), if bytes_to_external.is_some() => ProxyTcpDataSelect::ExternalSent(x?),
x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
_ = std::future::ready(()) => ProxyTcpDataSelect::DoNothing,
}
} else {
select! {
biased;
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::TxIpPacket(data)
} else {
ProxyTcpDataSelect::Close
},
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpDataSelect::InternalRecv(data)
} else {
ProxyTcpDataSelect::Close
},
_ = external_w.shutdown() => ProxyTcpDataSelect::ExternalSent(0),
x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
_ = std::future::ready(()) => ProxyTcpDataSelect::DoNothing,
}
};
match selection {
ProxyTcpDataSelect::ExternalRecv(size) => {
if size == 0 {
socket.close();
} else {
socket.send_slice(&external_buffer[..size])?;
}
}
ProxyTcpDataSelect::ExternalSent(size) => {
if size == 0 {
already_shutdown = true;
} else {
socket.recv(|_| (size, ()))?;
}
}
ProxyTcpDataSelect::InternalRecv(data) => {
let (_, payload) = Ethernet2Header::from_slice(&data)?;
device.rx = Some(payload.into());
iface.poll(Instant::now(), &mut device, &mut sockets);
}
ProxyTcpDataSelect::TxIpPacket(payload) => {
let mut buffer = BytesMut::with_capacity(Ethernet2Header::LEN + payload.len());
let header = Ethernet2Header {
source: context.key.local_mac.0,
destination: context.key.client_mac.0,
ether_type: match context.key.external_ip.addr {
IpAddress::Ipv4(_) => EtherType::IPV4,
IpAddress::Ipv6(_) => EtherType::IPV6,
},
};
buffer.extend_from_slice(&header.to_bytes());
buffer.extend_from_slice(&payload);
if let Err(error) = context.try_transmit(buffer) {
debug!("failed to transmit tcp packet: {}", error);
}
}
ProxyTcpDataSelect::TimePassed => {
iface.poll(Instant::now(), &mut device, &mut sockets);
}
ProxyTcpDataSelect::DoNothing => {
sleeper = Some(tokio::time::sleep(Duration::from_micros(100)));
}
ProxyTcpDataSelect::Close => {
break;
}
}
}
let _ = external_socket.shutdown().await;
drop(external_socket);
loop {
let deadline = tokio::time::sleep(Duration::from_secs(TCP_DANGLE_TIMEOUT_SECS));
tokio::pin!(deadline);
let selection = select! {
biased;
x = ip_receiver.recv() => if let Some(data) = x {
ProxyTcpFinishSelect::TxIpPacket(data)
} else {
ProxyTcpFinishSelect::Close
},
x = rx_receiver.recv() => if let Some(data) = x {
ProxyTcpFinishSelect::InternalRecv(data)
} else {
ProxyTcpFinishSelect::Close
},
_ = deadline => ProxyTcpFinishSelect::Close,
};
match selection {
ProxyTcpFinishSelect::InternalRecv(data) => {
let (_, payload) = Ethernet2Header::from_slice(&data)?;
device.rx = Some(payload.into());
iface.poll(Instant::now(), &mut device, &mut sockets);
}
ProxyTcpFinishSelect::TxIpPacket(payload) => {
let mut buffer = BytesMut::with_capacity(Ethernet2Header::LEN + payload.len());
let header = Ethernet2Header {
source: context.key.local_mac.0,
destination: context.key.client_mac.0,
ether_type: match context.key.external_ip.addr {
IpAddress::Ipv4(_) => EtherType::IPV4,
IpAddress::Ipv6(_) => EtherType::IPV6,
},
};
buffer.extend_from_slice(&header.to_bytes());
buffer.extend_from_slice(&payload);
if let Err(error) = context.try_transmit(buffer) {
debug!("failed to transmit tcp packet: {}", error);
}
}
ProxyTcpFinishSelect::Close => {
break;
}
}
}
context.reclaim().await?;
Ok(())
}
}

View File

@ -0,0 +1,142 @@
use std::{
net::{IpAddr, SocketAddr},
time::Duration,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use etherparse::{PacketBuilder, SlicedPacket, UdpSlice};
use log::{debug, warn};
use smoltcp::wire::IpAddress;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
};
use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender};
use udp_stream::UdpStream;
use crate::nat::handler::{NatHandler, NatHandlerContext};
const UDP_TIMEOUT_SECS: u64 = 60;
pub struct ProxyUdpHandler {
rx_sender: Sender<BytesMut>,
}
#[async_trait]
impl NatHandler for ProxyUdpHandler {
async fn receive(&self, data: &[u8]) -> Result<bool> {
if self.rx_sender.is_closed() {
Ok(true)
} else {
self.rx_sender.try_send(data.into())?;
Ok(true)
}
}
}
enum ProxyUdpSelect {
External(usize),
Internal(BytesMut),
Close,
}
impl ProxyUdpHandler {
pub fn new(rx_sender: Sender<BytesMut>) -> Self {
ProxyUdpHandler { rx_sender }
}
pub async fn spawn(
&mut self,
context: NatHandlerContext,
rx_receiver: Receiver<BytesMut>,
) -> Result<()> {
let external_addr = match context.key.external_ip.addr {
IpAddress::Ipv4(addr) => {
SocketAddr::new(IpAddr::V4(addr.0.into()), context.key.external_ip.port)
}
IpAddress::Ipv6(addr) => {
SocketAddr::new(IpAddr::V6(addr.0.into()), context.key.external_ip.port)
}
};
let socket = UdpStream::connect(external_addr).await?;
tokio::spawn(async move {
if let Err(error) = ProxyUdpHandler::process(context, socket, rx_receiver).await {
warn!("processing of udp proxy failed: {}", error);
}
});
Ok(())
}
async fn process(
context: NatHandlerContext,
mut socket: UdpStream,
mut rx_receiver: Receiver<BytesMut>,
) -> Result<()> {
let mut external_buffer = vec![0u8; 2048];
loop {
let deadline = tokio::time::sleep(Duration::from_secs(UDP_TIMEOUT_SECS));
let selection = select! {
x = rx_receiver.recv() => if let Some(data) = x {
ProxyUdpSelect::Internal(data)
} else {
ProxyUdpSelect::Close
},
x = socket.read(&mut external_buffer) => ProxyUdpSelect::External(x?),
_ = deadline => ProxyUdpSelect::Close,
};
match selection {
ProxyUdpSelect::External(size) => {
let data = &external_buffer[0..size];
let packet =
PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
(IpAddress::Ipv4(external_addr), IpAddress::Ipv4(client_addr)) => {
packet.ipv4(external_addr.0, client_addr.0, 20)
}
(IpAddress::Ipv6(external_addr), IpAddress::Ipv6(client_addr)) => {
packet.ipv6(external_addr.0, client_addr.0, 20)
}
_ => {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet =
packet.udp(context.key.external_ip.port, context.key.client_ip.port);
let buffer = BytesMut::with_capacity(packet.size(data.len()));
let mut writer = buffer.writer();
packet.write(&mut writer, data)?;
let buffer = writer.into_inner();
if let Err(error) = context.try_transmit(buffer) {
debug!("failed to transmit udp packet: {}", error);
}
}
ProxyUdpSelect::Internal(data) => {
let packet = SlicedPacket::from_ethernet(&data)?;
let Some(ref net) = packet.net else {
continue;
};
let Some(ip) = net.ip_payload_ref() else {
continue;
};
let udp = UdpSlice::from_slice(ip.payload)?;
socket.write_all(udp.payload()).await?;
}
ProxyUdpSelect::Close => {
drop(socket);
break;
}
}
}
context.reclaim().await?;
Ok(())
}
}

View File

@ -0,0 +1,299 @@
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use log::{debug, warn};
use std::io::ErrorKind;
use std::os::fd::{FromRawFd, IntoRawFd};
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::Arc;
use std::{io, mem};
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 3000;
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 3000;
#[derive(Debug)]
pub enum RawSocketProtocol {
Icmpv4,
Icmpv6,
Ethernet,
}
impl RawSocketProtocol {
pub fn to_socket_domain(&self) -> i32 {
match self {
RawSocketProtocol::Icmpv4 => libc::AF_INET,
RawSocketProtocol::Icmpv6 => libc::AF_INET6,
RawSocketProtocol::Ethernet => libc::AF_PACKET,
}
}
pub fn to_socket_protocol(&self) -> u16 {
match self {
RawSocketProtocol::Icmpv4 => libc::IPPROTO_ICMP as u16,
RawSocketProtocol::Icmpv6 => libc::IPPROTO_ICMPV6 as u16,
RawSocketProtocol::Ethernet => (libc::ETH_P_ALL as u16).to_be(),
}
}
pub fn to_socket_type(&self) -> i32 {
libc::SOCK_RAW
}
}
const SIOCGIFINDEX: libc::c_ulong = 0x8933;
const SIOCGIFMTU: libc::c_ulong = 0x8921;
#[derive(Debug)]
pub struct RawSocketHandle {
protocol: RawSocketProtocol,
lower: libc::c_int,
}
impl AsRawFd for RawSocketHandle {
fn as_raw_fd(&self) -> RawFd {
self.lower
}
}
impl IntoRawFd for RawSocketHandle {
fn into_raw_fd(self) -> RawFd {
let fd = self.lower;
mem::forget(self);
fd
}
}
impl RawSocketHandle {
pub fn new(protocol: RawSocketProtocol) -> io::Result<RawSocketHandle> {
let lower = unsafe {
let lower = libc::socket(
protocol.to_socket_domain(),
protocol.to_socket_type() | libc::SOCK_NONBLOCK,
protocol.to_socket_protocol() as i32,
);
if lower == -1 {
return Err(io::Error::last_os_error());
}
lower
};
Ok(RawSocketHandle { protocol, lower })
}
pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
let mut socket = RawSocketHandle::new(protocol)?;
socket.bind_to_interface(interface)?;
Ok(socket)
}
pub fn bind_to_interface(&mut self, interface: &str) -> io::Result<()> {
let mut ifreq = ifreq_for(interface);
let sockaddr = libc::sockaddr_ll {
sll_family: libc::AF_PACKET as u16,
sll_protocol: self.protocol.to_socket_protocol(),
sll_ifindex: ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFINDEX)?,
sll_hatype: 1,
sll_pkttype: 0,
sll_halen: 6,
sll_addr: [0; 8],
};
unsafe {
let res = libc::bind(
self.lower,
&sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr,
mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
);
if res == -1 {
return Err(io::Error::last_os_error());
}
}
Ok(())
}
pub fn mtu_of_interface(&mut self, interface: &str) -> io::Result<usize> {
let mut ifreq = ifreq_for(interface);
ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFMTU).map(|mtu| mtu as usize)
}
pub fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
unsafe {
let len = libc::recv(
self.lower,
buffer.as_mut_ptr() as *mut libc::c_void,
buffer.len(),
0,
);
if len == -1 {
return Err(io::Error::last_os_error());
}
Ok(len as usize)
}
}
pub fn send(&self, buffer: &[u8]) -> io::Result<usize> {
unsafe {
let len = libc::send(
self.lower,
buffer.as_ptr() as *const libc::c_void,
buffer.len(),
0,
);
if len == -1 {
return Err(io::Error::last_os_error());
}
Ok(len as usize)
}
}
}
impl Drop for RawSocketHandle {
fn drop(&mut self) {
unsafe {
libc::close(self.lower);
}
}
}
#[repr(C)]
#[derive(Debug)]
struct Ifreq {
ifr_name: [libc::c_char; libc::IF_NAMESIZE],
ifr_data: libc::c_int,
}
fn ifreq_for(name: &str) -> Ifreq {
let mut ifreq = Ifreq {
ifr_name: [0; libc::IF_NAMESIZE],
ifr_data: 0,
};
for (i, byte) in name.as_bytes().iter().enumerate() {
ifreq.ifr_name[i] = *byte as libc::c_char
}
ifreq
}
fn ifreq_ioctl(
lower: libc::c_int,
ifreq: &mut Ifreq,
cmd: libc::c_ulong,
) -> io::Result<libc::c_int> {
unsafe {
let res = libc::ioctl(lower, cmd as _, ifreq as *mut Ifreq);
if res == -1 {
return Err(io::Error::last_os_error());
}
}
Ok(ifreq.ifr_data)
}
pub struct AsyncRawSocketChannel {
pub sender: Sender<BytesMut>,
pub receiver: Receiver<BytesMut>,
_task: Arc<JoinHandle<()>>,
}
enum AsyncRawSocketChannelSelect {
TransmitPacket(Option<BytesMut>),
Readable(()),
}
impl AsyncRawSocketChannel {
pub fn new(mtu: usize, socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
let task = AsyncRawSocketChannel::launch(mtu, socket, transmit_receiver, receive_sender)?;
Ok(AsyncRawSocketChannel {
sender: transmit_sender,
receiver: receive_receiver,
_task: Arc::new(task),
})
}
fn launch(
mtu: usize,
socket: RawSocketHandle,
transmit_receiver: Receiver<BytesMut>,
receive_sender: Sender<BytesMut>,
) -> Result<JoinHandle<()>> {
Ok(tokio::task::spawn(async move {
if let Err(error) =
AsyncRawSocketChannel::process(mtu, socket, transmit_receiver, receive_sender).await
{
warn!("failed to process raw socket: {}", error);
}
}))
}
async fn process(
mtu: usize,
socket: RawSocketHandle,
mut transmit_receiver: Receiver<BytesMut>,
receive_sender: Sender<BytesMut>,
) -> Result<()> {
let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
let socket = UdpSocket::from_std(socket)?;
let mut buffer = vec![0; mtu];
loop {
let selection = select! {
x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
};
match selection {
AsyncRawSocketChannelSelect::Readable(_) => {
match socket.try_recv(&mut buffer) {
Ok(len) => {
if len == 0 {
continue;
}
let buffer = (&buffer[0..len]).into();
if let Err(error) = receive_sender.try_send(buffer) {
debug!(
"failed to process received packet from raw socket: {}",
error
);
}
}
Err(ref error) => {
if error.kind() == ErrorKind::WouldBlock {
continue;
}
return Err(anyhow!("failed to read from raw socket: {}", error));
}
};
}
AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
match socket.try_send(&packet) {
Ok(_len) => {}
Err(ref error) => {
if error.kind() == ErrorKind::WouldBlock {
debug!("failed to transmit: would block");
continue;
}
return Err(anyhow!(
"failed to write {} bytes to raw socket: {}",
packet.len(),
error
));
}
};
}
AsyncRawSocketChannelSelect::TransmitPacket(None) => {
break;
}
}
}
Ok(())
}
}

View File

@ -0,0 +1,212 @@
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use etherparse::{EtherType, Ethernet2Header, IpNumber, Ipv4Header, Ipv6Header, TcpHeader};
use log::{debug, trace, warn};
use smoltcp::wire::EthernetAddress;
use std::{
collections::{hash_map::Entry, HashMap},
sync::Arc,
};
use tokio::sync::broadcast::{
channel as broadcast_channel, Receiver as BroadcastReceiver, Sender as BroadcastSender,
};
use tokio::{
select,
sync::{
mpsc::{channel, Receiver, Sender},
Mutex,
},
task::JoinHandle,
};
const TO_BRIDGE_QUEUE_LEN: usize = 3000;
const FROM_BRIDGE_QUEUE_LEN: usize = 3000;
const BROADCAST_QUEUE_LEN: usize = 3000;
const MEMBER_LEAVE_QUEUE_LEN: usize = 30;
#[derive(Debug)]
struct BridgeMember {
pub from_bridge_sender: Sender<BytesMut>,
}
pub struct BridgeJoinHandle {
mac: EthernetAddress,
pub to_bridge_sender: Sender<BytesMut>,
pub from_bridge_receiver: Receiver<BytesMut>,
pub from_broadcast_receiver: BroadcastReceiver<BytesMut>,
member_leave_sender: Sender<EthernetAddress>,
}
impl Drop for BridgeJoinHandle {
fn drop(&mut self) {
if let Err(error) = self.member_leave_sender.try_send(self.mac) {
warn!(
"virtual bridge member {} failed to leave: {}",
self.mac, error
);
}
}
}
type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>;
#[derive(Clone)]
pub struct VirtualBridge {
to_bridge_sender: Sender<BytesMut>,
from_broadcast_sender: BroadcastSender<BytesMut>,
member_leave_sender: Sender<EthernetAddress>,
members: VirtualBridgeMemberMap,
_task: Arc<JoinHandle<()>>,
}
enum VirtualBridgeSelect {
BroadcastSent(Option<BytesMut>),
PacketReceived(Option<BytesMut>),
MemberLeave(Option<EthernetAddress>),
}
impl VirtualBridge {
pub fn new() -> Result<VirtualBridge> {
let (to_bridge_sender, to_bridge_receiver) = channel::<BytesMut>(TO_BRIDGE_QUEUE_LEN);
let (member_leave_sender, member_leave_reciever) =
channel::<EthernetAddress>(MEMBER_LEAVE_QUEUE_LEN);
let (from_broadcast_sender, from_broadcast_receiver) =
broadcast_channel(BROADCAST_QUEUE_LEN);
let members = Arc::new(Mutex::new(HashMap::new()));
let handle = {
let members = members.clone();
let broadcast_rx_sender = from_broadcast_sender.clone();
tokio::task::spawn(async move {
if let Err(error) = VirtualBridge::process(
members,
member_leave_reciever,
to_bridge_receiver,
broadcast_rx_sender,
from_broadcast_receiver,
)
.await
{
warn!("virtual bridge processing task failed: {}", error);
}
})
};
Ok(VirtualBridge {
to_bridge_sender,
from_broadcast_sender,
member_leave_sender,
members,
_task: Arc::new(handle),
})
}
pub async fn join(&self, mac: EthernetAddress) -> Result<BridgeJoinHandle> {
let (from_bridge_sender, from_bridge_receiver) = channel::<BytesMut>(FROM_BRIDGE_QUEUE_LEN);
let member = BridgeMember { from_bridge_sender };
match self.members.lock().await.entry(mac) {
Entry::Occupied(_) => {
return Err(anyhow!("virtual bridge member {} already exists", mac));
}
Entry::Vacant(entry) => {
entry.insert(member);
}
};
debug!("virtual bridge member {} has joined", mac);
Ok(BridgeJoinHandle {
mac,
member_leave_sender: self.member_leave_sender.clone(),
from_bridge_receiver,
from_broadcast_receiver: self.from_broadcast_sender.subscribe(),
to_bridge_sender: self.to_bridge_sender.clone(),
})
}
async fn process(
members: VirtualBridgeMemberMap,
mut member_leave_reciever: Receiver<EthernetAddress>,
mut to_bridge_receiver: Receiver<BytesMut>,
broadcast_rx_sender: BroadcastSender<BytesMut>,
mut from_broadcast_receiver: BroadcastReceiver<BytesMut>,
) -> Result<()> {
loop {
let selection = select! {
biased;
x = from_broadcast_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()),
x = to_bridge_receiver.recv() => VirtualBridgeSelect::PacketReceived(x),
x = member_leave_reciever.recv() => VirtualBridgeSelect::MemberLeave(x),
};
match selection {
VirtualBridgeSelect::PacketReceived(Some(mut packet)) => {
let (header, payload) = match Ethernet2Header::from_slice(&packet) {
Ok(data) => data,
Err(error) => {
debug!("virtual bridge failed to parse ethernet header: {}", error);
continue;
}
};
// recalculate TCP checksums when routing packets.
// the xen network backend / frontend drivers for linux
// use checksum offloading but since we bypass some layers
// of the kernel we have to do it ourselves.
if header.ether_type == EtherType::IPV4 {
let (ipv4, payload) = Ipv4Header::from_slice(payload)?;
if ipv4.protocol == IpNumber::TCP {
let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
tcp.checksum = tcp.calc_checksum_ipv4(&ipv4, payload)?;
let tcp_header_offset = Ethernet2Header::LEN + ipv4.header_len();
let tcp_header_bytes = tcp.to_bytes();
for (i, b) in tcp_header_bytes.iter().enumerate() {
packet[tcp_header_offset + i] = *b;
}
}
} else if header.ether_type == EtherType::IPV6 {
let (ipv6, payload) = Ipv6Header::from_slice(payload)?;
if ipv6.next_header == IpNumber::TCP {
let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
tcp.checksum = tcp.calc_checksum_ipv6(&ipv6, payload)?;
let tcp_header_offset = Ethernet2Header::LEN + ipv6.header_len();
let tcp_header_bytes = tcp.to_bytes();
for (i, b) in tcp_header_bytes.iter().enumerate() {
packet[tcp_header_offset + i] = *b;
}
}
}
let destination = EthernetAddress(header.destination);
if destination.is_multicast() {
broadcast_rx_sender.send(packet)?;
continue;
}
match members.lock().await.get(&destination) {
Some(member) => {
member.from_bridge_sender.try_send(packet)?;
trace!(
"sending bridged packet from {} to {}",
EthernetAddress(header.source),
EthernetAddress(header.destination)
);
}
None => {
trace!("no bridge member with address: {}", destination);
}
}
}
VirtualBridgeSelect::MemberLeave(Some(mac)) => {
if members.lock().await.remove(&mac).is_some() {
debug!("virtual bridge member {} has left", mac);
}
}
VirtualBridgeSelect::PacketReceived(None) => break,
VirtualBridgeSelect::MemberLeave(None) => {}
VirtualBridgeSelect::BroadcastSent(_) => {}
}
}
Ok(())
}
}