mirror of
				https://github.com/edera-dev/krata.git
				synced 2025-11-03 23:29:39 +00:00 
			
		
		
		
	krata: reorganize crates
This commit is contained in:
		
							
								
								
									
										185
									
								
								crates/kratanet/src/autonet.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								crates/kratanet/src/autonet.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,185 @@
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use smoltcp::wire::{EthernetAddress, Ipv4Cidr, Ipv6Cidr};
 | 
			
		||||
use std::{collections::HashMap, str::FromStr};
 | 
			
		||||
use uuid::Uuid;
 | 
			
		||||
use xenstore::client::{XsdClient, XsdInterface, XsdTransaction};
 | 
			
		||||
 | 
			
		||||
pub struct AutoNetworkCollector {
 | 
			
		||||
    client: XsdClient,
 | 
			
		||||
    known: HashMap<Uuid, NetworkMetadata>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct NetworkSide {
 | 
			
		||||
    pub ipv4: Ipv4Cidr,
 | 
			
		||||
    pub ipv6: Ipv6Cidr,
 | 
			
		||||
    pub mac: EthernetAddress,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct NetworkMetadata {
 | 
			
		||||
    pub domid: u32,
 | 
			
		||||
    pub uuid: Uuid,
 | 
			
		||||
    pub guest: NetworkSide,
 | 
			
		||||
    pub gateway: NetworkSide,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkMetadata {
 | 
			
		||||
    pub fn interface(&self) -> String {
 | 
			
		||||
        format!("vif{}.20", self.domid)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct AutoNetworkChangeset {
 | 
			
		||||
    pub added: Vec<NetworkMetadata>,
 | 
			
		||||
    pub removed: Vec<NetworkMetadata>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AutoNetworkCollector {
 | 
			
		||||
    pub async fn new() -> Result<AutoNetworkCollector> {
 | 
			
		||||
        Ok(AutoNetworkCollector {
 | 
			
		||||
            client: XsdClient::open().await?,
 | 
			
		||||
            known: HashMap::new(),
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn read(&mut self) -> Result<Vec<NetworkMetadata>> {
 | 
			
		||||
        let mut networks = Vec::new();
 | 
			
		||||
        let tx = self.client.transaction().await?;
 | 
			
		||||
        for domid_string in tx.list("/local/domain").await? {
 | 
			
		||||
            let Ok(domid) = domid_string.parse::<u32>() else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let dom_path = format!("/local/domain/{}", domid_string);
 | 
			
		||||
            let Some(uuid_string) = tx.read_string(&format!("{}/krata/uuid", dom_path)).await?
 | 
			
		||||
            else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let Ok(uuid) = uuid_string.parse::<Uuid>() else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let Ok(guest) =
 | 
			
		||||
                AutoNetworkCollector::read_network_side(uuid, &tx, &dom_path, "guest").await
 | 
			
		||||
            else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let Ok(gateway) =
 | 
			
		||||
                AutoNetworkCollector::read_network_side(uuid, &tx, &dom_path, "gateway").await
 | 
			
		||||
            else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            networks.push(NetworkMetadata {
 | 
			
		||||
                domid,
 | 
			
		||||
                uuid,
 | 
			
		||||
                guest,
 | 
			
		||||
                gateway,
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
        tx.commit().await?;
 | 
			
		||||
        Ok(networks)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn read_network_side(
 | 
			
		||||
        uuid: Uuid,
 | 
			
		||||
        tx: &XsdTransaction,
 | 
			
		||||
        dom_path: &str,
 | 
			
		||||
        side: &str,
 | 
			
		||||
    ) -> Result<NetworkSide> {
 | 
			
		||||
        let side_path = format!("{}/krata/network/{}", dom_path, side);
 | 
			
		||||
        let Some(ipv4) = tx.read_string(&format!("{}/ipv4", side_path)).await? else {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "krata domain {} is missing {} ipv4 network entry",
 | 
			
		||||
                uuid,
 | 
			
		||||
                side
 | 
			
		||||
            ));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let Some(ipv6) = tx.read_string(&format!("{}/ipv6", side_path)).await? else {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "krata domain {} is missing {} ipv6 network entry",
 | 
			
		||||
                uuid,
 | 
			
		||||
                side
 | 
			
		||||
            ));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let Some(mac) = tx.read_string(&format!("{}/mac", side_path)).await? else {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "krata domain {} is missing {} mac address entry",
 | 
			
		||||
                uuid,
 | 
			
		||||
                side
 | 
			
		||||
            ));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let Ok(ipv4) = Ipv4Cidr::from_str(&ipv4) else {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "krata domain {} has invalid {} ipv4 network cidr entry: {}",
 | 
			
		||||
                uuid,
 | 
			
		||||
                side,
 | 
			
		||||
                ipv4
 | 
			
		||||
            ));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let Ok(ipv6) = Ipv6Cidr::from_str(&ipv6) else {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "krata domain {} has invalid {} ipv6 network cidr entry: {}",
 | 
			
		||||
                uuid,
 | 
			
		||||
                side,
 | 
			
		||||
                ipv6
 | 
			
		||||
            ));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let Ok(mac) = EthernetAddress::from_str(&mac) else {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "krata domain {} has invalid {} mac address entry: {}",
 | 
			
		||||
                uuid,
 | 
			
		||||
                side,
 | 
			
		||||
                mac
 | 
			
		||||
            ));
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        Ok(NetworkSide { ipv4, ipv6, mac })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn read_changes(&mut self) -> Result<AutoNetworkChangeset> {
 | 
			
		||||
        let mut seen: Vec<Uuid> = Vec::new();
 | 
			
		||||
        let mut added: Vec<NetworkMetadata> = Vec::new();
 | 
			
		||||
        let mut removed: Vec<NetworkMetadata> = Vec::new();
 | 
			
		||||
 | 
			
		||||
        for network in self.read().await? {
 | 
			
		||||
            seen.push(network.uuid);
 | 
			
		||||
            if self.known.contains_key(&network.uuid) {
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
            let _ = self.known.insert(network.uuid, network.clone());
 | 
			
		||||
            added.push(network);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let mut gone: Vec<Uuid> = Vec::new();
 | 
			
		||||
        for uuid in self.known.keys() {
 | 
			
		||||
            if seen.contains(uuid) {
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
            gone.push(*uuid);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for uuid in &gone {
 | 
			
		||||
            let Some(network) = self.known.remove(uuid) else {
 | 
			
		||||
                continue;
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            removed.push(network);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(AutoNetworkChangeset { added, removed })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn mark_unknown(&mut self, uuid: Uuid) -> Result<bool> {
 | 
			
		||||
        Ok(self.known.remove(&uuid).is_some())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										175
									
								
								crates/kratanet/src/backend.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								crates/kratanet/src/backend.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,175 @@
 | 
			
		||||
use crate::autonet::NetworkMetadata;
 | 
			
		||||
use crate::chandev::ChannelDevice;
 | 
			
		||||
use crate::nat::Nat;
 | 
			
		||||
use crate::proxynat::ProxyNatHandlerFactory;
 | 
			
		||||
use crate::raw_socket::{AsyncRawSocketChannel, RawSocketHandle, RawSocketProtocol};
 | 
			
		||||
use crate::vbridge::{BridgeJoinHandle, VirtualBridge};
 | 
			
		||||
use crate::EXTRA_MTU;
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use futures::TryStreamExt;
 | 
			
		||||
use log::{info, trace, warn};
 | 
			
		||||
use smoltcp::iface::{Config, Interface, SocketSet};
 | 
			
		||||
use smoltcp::phy::Medium;
 | 
			
		||||
use smoltcp::time::Instant;
 | 
			
		||||
use smoltcp::wire::{HardwareAddress, IpCidr};
 | 
			
		||||
use tokio::select;
 | 
			
		||||
use tokio::sync::mpsc::{channel, Receiver};
 | 
			
		||||
use tokio::task::JoinHandle;
 | 
			
		||||
 | 
			
		||||
const TX_CHANNEL_BUFFER_LEN: usize = 3000;
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct NetworkBackend {
 | 
			
		||||
    metadata: NetworkMetadata,
 | 
			
		||||
    bridge: VirtualBridge,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
enum NetworkStackSelect {
 | 
			
		||||
    Receive(Option<BytesMut>),
 | 
			
		||||
    Send(Option<BytesMut>),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct NetworkStack<'a> {
 | 
			
		||||
    tx: Receiver<BytesMut>,
 | 
			
		||||
    kdev: AsyncRawSocketChannel,
 | 
			
		||||
    udev: ChannelDevice,
 | 
			
		||||
    interface: Interface,
 | 
			
		||||
    sockets: SocketSet<'a>,
 | 
			
		||||
    nat: Nat,
 | 
			
		||||
    bridge: BridgeJoinHandle,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkStack<'_> {
 | 
			
		||||
    async fn poll(&mut self) -> Result<bool> {
 | 
			
		||||
        let what = select! {
 | 
			
		||||
            x = self.kdev.receiver.recv() => NetworkStackSelect::Receive(x),
 | 
			
		||||
            x = self.bridge.from_bridge_receiver.recv() => NetworkStackSelect::Send(x),
 | 
			
		||||
            x = self.bridge.from_broadcast_receiver.recv() => NetworkStackSelect::Send(x.ok()),
 | 
			
		||||
            x = self.tx.recv() => NetworkStackSelect::Send(x),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        match what {
 | 
			
		||||
            NetworkStackSelect::Receive(Some(packet)) => {
 | 
			
		||||
                if let Err(error) = self.bridge.to_bridge_sender.try_send(packet.clone()) {
 | 
			
		||||
                    trace!("failed to send guest packet to bridge: {}", error);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                if let Err(error) = self.nat.receive_sender.try_send(packet.clone()) {
 | 
			
		||||
                    trace!("failed to send guest packet to nat: {}", error);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                self.udev.rx = Some(packet);
 | 
			
		||||
                self.interface
 | 
			
		||||
                    .poll(Instant::now(), &mut self.udev, &mut self.sockets);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            NetworkStackSelect::Send(Some(packet)) => {
 | 
			
		||||
                if let Err(error) = self.kdev.sender.try_send(packet) {
 | 
			
		||||
                    warn!("failed to transmit packet to interface: {}", error);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            NetworkStackSelect::Receive(None) | NetworkStackSelect::Send(None) => {
 | 
			
		||||
                return Ok(false);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(true)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkBackend {
 | 
			
		||||
    pub fn new(metadata: NetworkMetadata, bridge: VirtualBridge) -> Result<Self> {
 | 
			
		||||
        Ok(Self { metadata, bridge })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn init(&mut self) -> Result<()> {
 | 
			
		||||
        let interface = self.metadata.interface();
 | 
			
		||||
        let (connection, handle, _) = rtnetlink::new_connection()?;
 | 
			
		||||
        tokio::spawn(connection);
 | 
			
		||||
 | 
			
		||||
        let mut links = handle.link().get().match_name(interface.clone()).execute();
 | 
			
		||||
        let link = links.try_next().await?;
 | 
			
		||||
        if link.is_none() {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "unable to find network interface named {}",
 | 
			
		||||
                interface
 | 
			
		||||
            ));
 | 
			
		||||
        }
 | 
			
		||||
        let link = link.unwrap();
 | 
			
		||||
        handle.link().set(link.header.index).up().execute().await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn run(&self) -> Result<()> {
 | 
			
		||||
        let mut stack = self.create_network_stack().await?;
 | 
			
		||||
        loop {
 | 
			
		||||
            if !stack.poll().await? {
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn create_network_stack(&self) -> Result<NetworkStack> {
 | 
			
		||||
        let interface = self.metadata.interface();
 | 
			
		||||
        let proxy = Box::new(ProxyNatHandlerFactory::new());
 | 
			
		||||
        let addresses: Vec<IpCidr> = vec![
 | 
			
		||||
            self.metadata.gateway.ipv4.into(),
 | 
			
		||||
            self.metadata.gateway.ipv6.into(),
 | 
			
		||||
        ];
 | 
			
		||||
        let mut kdev =
 | 
			
		||||
            RawSocketHandle::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?;
 | 
			
		||||
        let mtu = kdev.mtu_of_interface(&interface)? + EXTRA_MTU;
 | 
			
		||||
        let (tx_sender, tx_receiver) = channel::<BytesMut>(TX_CHANNEL_BUFFER_LEN);
 | 
			
		||||
        let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone());
 | 
			
		||||
        let mac = self.metadata.gateway.mac;
 | 
			
		||||
        let nat = Nat::new(mtu, proxy, mac, addresses.clone(), tx_sender.clone())?;
 | 
			
		||||
        let hardware_addr = HardwareAddress::Ethernet(mac);
 | 
			
		||||
        let config = Config::new(hardware_addr);
 | 
			
		||||
        let mut iface = Interface::new(config, &mut udev, Instant::now());
 | 
			
		||||
        iface.update_ip_addrs(|addrs| {
 | 
			
		||||
            addrs
 | 
			
		||||
                .extend_from_slice(&addresses)
 | 
			
		||||
                .expect("failed to set ip addresses");
 | 
			
		||||
        });
 | 
			
		||||
        let sockets = SocketSet::new(vec![]);
 | 
			
		||||
        let handle = self.bridge.join(self.metadata.guest.mac).await?;
 | 
			
		||||
        let kdev = AsyncRawSocketChannel::new(mtu, kdev)?;
 | 
			
		||||
        Ok(NetworkStack {
 | 
			
		||||
            tx: tx_receiver,
 | 
			
		||||
            kdev,
 | 
			
		||||
            udev,
 | 
			
		||||
            interface: iface,
 | 
			
		||||
            sockets,
 | 
			
		||||
            nat,
 | 
			
		||||
            bridge: handle,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn launch(self) -> Result<JoinHandle<()>> {
 | 
			
		||||
        Ok(tokio::task::spawn(async move {
 | 
			
		||||
            info!(
 | 
			
		||||
                "lauched network backend for krata guest {}",
 | 
			
		||||
                self.metadata.uuid
 | 
			
		||||
            );
 | 
			
		||||
            if let Err(error) = self.run().await {
 | 
			
		||||
                warn!(
 | 
			
		||||
                    "network backend for krata guest {} failed: {}",
 | 
			
		||||
                    self.metadata.uuid, error
 | 
			
		||||
                );
 | 
			
		||||
            }
 | 
			
		||||
        }))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for NetworkBackend {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        info!(
 | 
			
		||||
            "destroyed network backend for krata guest {}",
 | 
			
		||||
            self.metadata.uuid
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										89
									
								
								crates/kratanet/src/chandev.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								crates/kratanet/src/chandev.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,89 @@
 | 
			
		||||
// Referenced https://github.com/vi/wgslirpy/blob/master/crates/libwgslirpy/src/channelized_smoltcp_device.rs
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use log::{debug, warn};
 | 
			
		||||
use smoltcp::phy::{Checksum, Device, Medium};
 | 
			
		||||
use tokio::sync::mpsc::Sender;
 | 
			
		||||
 | 
			
		||||
const TEAR_OFF_BUFFER_SIZE: usize = 65536;
 | 
			
		||||
 | 
			
		||||
pub struct ChannelDevice {
 | 
			
		||||
    pub mtu: usize,
 | 
			
		||||
    pub medium: Medium,
 | 
			
		||||
    pub tx: Sender<BytesMut>,
 | 
			
		||||
    pub rx: Option<BytesMut>,
 | 
			
		||||
    tear_off_buffer: BytesMut,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ChannelDevice {
 | 
			
		||||
    pub fn new(mtu: usize, medium: Medium, tx: Sender<BytesMut>) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            mtu,
 | 
			
		||||
            medium,
 | 
			
		||||
            tx,
 | 
			
		||||
            rx: None,
 | 
			
		||||
            tear_off_buffer: BytesMut::with_capacity(TEAR_OFF_BUFFER_SIZE),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct RxToken(pub BytesMut);
 | 
			
		||||
 | 
			
		||||
impl Device for ChannelDevice {
 | 
			
		||||
    type RxToken<'a> = RxToken where Self: 'a;
 | 
			
		||||
    type TxToken<'a> = &'a mut ChannelDevice where Self: 'a;
 | 
			
		||||
 | 
			
		||||
    fn receive(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        _timestamp: smoltcp::time::Instant,
 | 
			
		||||
    ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
 | 
			
		||||
        self.rx.take().map(|x| (RxToken(x), self))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
 | 
			
		||||
        if self.tx.capacity() == 0 {
 | 
			
		||||
            debug!("ran out of transmission capacity");
 | 
			
		||||
            return None;
 | 
			
		||||
        }
 | 
			
		||||
        Some(self)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities {
 | 
			
		||||
        let mut capabilities = smoltcp::phy::DeviceCapabilities::default();
 | 
			
		||||
        capabilities.medium = self.medium;
 | 
			
		||||
        capabilities.max_transmission_unit = self.mtu;
 | 
			
		||||
        capabilities.checksum = smoltcp::phy::ChecksumCapabilities::ignored();
 | 
			
		||||
        capabilities.checksum.tcp = Checksum::Tx;
 | 
			
		||||
        capabilities.checksum.ipv4 = Checksum::Tx;
 | 
			
		||||
        capabilities.checksum.icmpv4 = Checksum::Tx;
 | 
			
		||||
        capabilities.checksum.icmpv6 = Checksum::Tx;
 | 
			
		||||
        capabilities
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl smoltcp::phy::RxToken for RxToken {
 | 
			
		||||
    fn consume<R, F>(mut self, f: F) -> R
 | 
			
		||||
    where
 | 
			
		||||
        F: FnOnce(&mut [u8]) -> R,
 | 
			
		||||
    {
 | 
			
		||||
        f(&mut self.0[..])
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> smoltcp::phy::TxToken for &'a mut ChannelDevice {
 | 
			
		||||
    fn consume<R, F>(self, len: usize, f: F) -> R
 | 
			
		||||
    where
 | 
			
		||||
        F: FnOnce(&mut [u8]) -> R,
 | 
			
		||||
    {
 | 
			
		||||
        self.tear_off_buffer.resize(len, 0);
 | 
			
		||||
        let result = f(&mut self.tear_off_buffer[..]);
 | 
			
		||||
        let chunk = self.tear_off_buffer.split();
 | 
			
		||||
        if let Err(error) = self.tx.try_send(chunk) {
 | 
			
		||||
            warn!("failed to transmit packet: {}", error);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if self.tear_off_buffer.capacity() < self.mtu {
 | 
			
		||||
            self.tear_off_buffer = BytesMut::with_capacity(TEAR_OFF_BUFFER_SIZE);
 | 
			
		||||
        }
 | 
			
		||||
        result
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										145
									
								
								crates/kratanet/src/hbridge.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								crates/kratanet/src/hbridge.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,145 @@
 | 
			
		||||
use std::net::{IpAddr, Ipv4Addr};
 | 
			
		||||
 | 
			
		||||
use advmac::MacAddr6;
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use futures::TryStreamExt;
 | 
			
		||||
use log::error;
 | 
			
		||||
use smoltcp::wire::EthernetAddress;
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncReadExt, AsyncWriteExt},
 | 
			
		||||
    select,
 | 
			
		||||
    sync::mpsc::channel,
 | 
			
		||||
    task::JoinHandle,
 | 
			
		||||
};
 | 
			
		||||
use tokio_tun::Tun;
 | 
			
		||||
 | 
			
		||||
use crate::vbridge::{BridgeJoinHandle, VirtualBridge};
 | 
			
		||||
 | 
			
		||||
const RX_BUFFER_QUEUE_LEN: usize = 100;
 | 
			
		||||
const HOST_IPV4_ADDR: Ipv4Addr = Ipv4Addr::new(10, 75, 0, 1);
 | 
			
		||||
 | 
			
		||||
pub struct HostBridge {
 | 
			
		||||
    task: JoinHandle<()>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl HostBridge {
 | 
			
		||||
    pub async fn new(mtu: usize, interface: String, bridge: &VirtualBridge) -> Result<HostBridge> {
 | 
			
		||||
        let tun = Tun::builder()
 | 
			
		||||
            .name(&interface)
 | 
			
		||||
            .tap(true)
 | 
			
		||||
            .mtu(mtu as i32)
 | 
			
		||||
            .packet_info(false)
 | 
			
		||||
            .try_build()?;
 | 
			
		||||
 | 
			
		||||
        let (connection, handle, _) = rtnetlink::new_connection()?;
 | 
			
		||||
        tokio::spawn(connection);
 | 
			
		||||
 | 
			
		||||
        let mut mac = MacAddr6::random();
 | 
			
		||||
        mac.set_local(true);
 | 
			
		||||
        mac.set_multicast(false);
 | 
			
		||||
 | 
			
		||||
        let mut links = handle.link().get().match_name(interface.clone()).execute();
 | 
			
		||||
        let link = links.try_next().await?;
 | 
			
		||||
        if link.is_none() {
 | 
			
		||||
            return Err(anyhow!(
 | 
			
		||||
                "unable to find network interface named {}",
 | 
			
		||||
                interface
 | 
			
		||||
            ));
 | 
			
		||||
        }
 | 
			
		||||
        let link = link.unwrap();
 | 
			
		||||
 | 
			
		||||
        handle
 | 
			
		||||
            .address()
 | 
			
		||||
            .add(link.header.index, IpAddr::V4(HOST_IPV4_ADDR), 16)
 | 
			
		||||
            .execute()
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
        handle
 | 
			
		||||
            .address()
 | 
			
		||||
            .add(link.header.index, IpAddr::V6(mac.to_link_local_ipv6()), 10)
 | 
			
		||||
            .execute()
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
        handle
 | 
			
		||||
            .link()
 | 
			
		||||
            .set(link.header.index)
 | 
			
		||||
            .address(mac.to_array().to_vec())
 | 
			
		||||
            .up()
 | 
			
		||||
            .execute()
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
        let mac = EthernetAddress(mac.to_array());
 | 
			
		||||
        let bridge_handle = bridge.join(mac).await?;
 | 
			
		||||
 | 
			
		||||
        let task = tokio::task::spawn(async move {
 | 
			
		||||
            if let Err(error) = HostBridge::process(mtu, tun, bridge_handle).await {
 | 
			
		||||
                error!("failed to process host bridge: {}", error);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        Ok(HostBridge { task })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(mtu: usize, tun: Tun, mut bridge_handle: BridgeJoinHandle) -> Result<()> {
 | 
			
		||||
        let (rx_sender, mut rx_receiver) = channel::<BytesMut>(RX_BUFFER_QUEUE_LEN);
 | 
			
		||||
        let (mut read, mut write) = tokio::io::split(tun);
 | 
			
		||||
        tokio::task::spawn(async move {
 | 
			
		||||
            let mut buffer = vec![0u8; mtu];
 | 
			
		||||
            loop {
 | 
			
		||||
                let size = match read.read(&mut buffer).await {
 | 
			
		||||
                    Ok(size) => size,
 | 
			
		||||
                    Err(error) => {
 | 
			
		||||
                        error!("failed to read tap device: {}", error);
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
                match rx_sender.send(buffer[0..size].into()).await {
 | 
			
		||||
                    Ok(_) => {}
 | 
			
		||||
                    Err(error) => {
 | 
			
		||||
                        error!(
 | 
			
		||||
                            "failed to send data from tap device to processor: {}",
 | 
			
		||||
                            error
 | 
			
		||||
                        );
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
        loop {
 | 
			
		||||
            select! {
 | 
			
		||||
                x = bridge_handle.from_bridge_receiver.recv() => match x {
 | 
			
		||||
                    Some(bytes) => {
 | 
			
		||||
                        write.write_all(&bytes).await?;
 | 
			
		||||
                    },
 | 
			
		||||
                    None => {
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                },
 | 
			
		||||
                x = bridge_handle.from_broadcast_receiver.recv() => match x {
 | 
			
		||||
                    Ok(bytes) => {
 | 
			
		||||
                        write.write_all(&bytes).await?;
 | 
			
		||||
                    },
 | 
			
		||||
                    Err(error) => {
 | 
			
		||||
                        return Err(error.into());
 | 
			
		||||
                    }
 | 
			
		||||
                },
 | 
			
		||||
                x = rx_receiver.recv() => match x {
 | 
			
		||||
                    Some(bytes) => {
 | 
			
		||||
                        bridge_handle.to_bridge_sender.send(bytes).await?;
 | 
			
		||||
                    },
 | 
			
		||||
                    None => {
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for HostBridge {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        self.task.abort();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										250
									
								
								crates/kratanet/src/icmp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										250
									
								
								crates/kratanet/src/icmp.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,250 @@
 | 
			
		||||
use crate::raw_socket::{RawSocketHandle, RawSocketProtocol};
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use etherparse::{
 | 
			
		||||
    IcmpEchoHeader, Icmpv4Header, Icmpv4Slice, Icmpv4Type, Icmpv6Header, Icmpv6Slice, Icmpv6Type,
 | 
			
		||||
    IpNumber, NetSlice, SlicedPacket,
 | 
			
		||||
};
 | 
			
		||||
use log::warn;
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::HashMap,
 | 
			
		||||
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
 | 
			
		||||
    os::fd::{FromRawFd, IntoRawFd},
 | 
			
		||||
    sync::Arc,
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    net::UdpSocket,
 | 
			
		||||
    sync::{oneshot, Mutex},
 | 
			
		||||
    task::JoinHandle,
 | 
			
		||||
    time::timeout,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum IcmpProtocol {
 | 
			
		||||
    Icmpv4,
 | 
			
		||||
    Icmpv6,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IcmpProtocol {
 | 
			
		||||
    pub fn to_socket_protocol(&self) -> RawSocketProtocol {
 | 
			
		||||
        match self {
 | 
			
		||||
            IcmpProtocol::Icmpv4 => RawSocketProtocol::Icmpv4,
 | 
			
		||||
            IcmpProtocol::Icmpv6 => RawSocketProtocol::Icmpv6,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 | 
			
		||||
struct IcmpHandlerToken(IpAddr, Option<u16>, u16);
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum IcmpReply {
 | 
			
		||||
    Icmpv4 {
 | 
			
		||||
        header: Icmpv4Header,
 | 
			
		||||
        echo: IcmpEchoHeader,
 | 
			
		||||
        payload: Vec<u8>,
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    Icmpv6 {
 | 
			
		||||
        header: Icmpv6Header,
 | 
			
		||||
        echo: IcmpEchoHeader,
 | 
			
		||||
        payload: Vec<u8>,
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>;
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct IcmpClient {
 | 
			
		||||
    socket: Arc<UdpSocket>,
 | 
			
		||||
    handlers: IcmpHandlerMap,
 | 
			
		||||
    task: Arc<JoinHandle<Result<()>>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IcmpClient {
 | 
			
		||||
    pub fn new(protocol: IcmpProtocol) -> Result<IcmpClient> {
 | 
			
		||||
        let handle = RawSocketHandle::new(protocol.to_socket_protocol())?;
 | 
			
		||||
        let socket = unsafe { std::net::UdpSocket::from_raw_fd(handle.into_raw_fd()) };
 | 
			
		||||
        let socket: Arc<UdpSocket> = Arc::new(socket.try_into()?);
 | 
			
		||||
        let handlers = Arc::new(Mutex::new(HashMap::new()));
 | 
			
		||||
        let task = Arc::new(tokio::task::spawn(IcmpClient::process(
 | 
			
		||||
            protocol,
 | 
			
		||||
            socket.clone(),
 | 
			
		||||
            handlers.clone(),
 | 
			
		||||
        )));
 | 
			
		||||
        Ok(IcmpClient {
 | 
			
		||||
            socket,
 | 
			
		||||
            handlers,
 | 
			
		||||
            task,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(
 | 
			
		||||
        protocol: IcmpProtocol,
 | 
			
		||||
        socket: Arc<UdpSocket>,
 | 
			
		||||
        handlers: IcmpHandlerMap,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let mut buffer = vec![0u8; 2048];
 | 
			
		||||
        loop {
 | 
			
		||||
            let (size, addr) = socket.recv_from(&mut buffer).await?;
 | 
			
		||||
            let packet = &buffer[0..size];
 | 
			
		||||
 | 
			
		||||
            let (token, reply) = match protocol {
 | 
			
		||||
                IcmpProtocol::Icmpv4 => {
 | 
			
		||||
                    let sliced = match SlicedPacket::from_ip(packet) {
 | 
			
		||||
                        Ok(sliced) => sliced,
 | 
			
		||||
                        Err(error) => {
 | 
			
		||||
                            warn!("received icmp packet but failed to parse it: {}", error);
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let Some(NetSlice::Ipv4(ipv4)) = sliced.net else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    if ipv4.header().protocol() != IpNumber::ICMP {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    let Ok(icmpv4) = Icmpv4Slice::from_slice(ipv4.payload().payload) else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let Icmpv4Type::EchoReply(echo) = icmpv4.header().icmp_type else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let token = IcmpHandlerToken(
 | 
			
		||||
                        IpAddr::V4(ipv4.header().source_addr()),
 | 
			
		||||
                        Some(echo.id),
 | 
			
		||||
                        echo.seq,
 | 
			
		||||
                    );
 | 
			
		||||
                    let reply = IcmpReply::Icmpv4 {
 | 
			
		||||
                        header: icmpv4.header(),
 | 
			
		||||
                        echo,
 | 
			
		||||
                        payload: icmpv4.payload().to_vec(),
 | 
			
		||||
                    };
 | 
			
		||||
                    (token, reply)
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                IcmpProtocol::Icmpv6 => {
 | 
			
		||||
                    let Ok(icmpv6) = Icmpv6Slice::from_slice(packet) else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let Icmpv6Type::EchoReply(echo) = icmpv6.header().icmp_type else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let SocketAddr::V6(addr) = addr else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq);
 | 
			
		||||
 | 
			
		||||
                    let reply = IcmpReply::Icmpv6 {
 | 
			
		||||
                        header: icmpv6.header(),
 | 
			
		||||
                        echo,
 | 
			
		||||
                        payload: icmpv6.payload().to_vec(),
 | 
			
		||||
                    };
 | 
			
		||||
                    (token, reply)
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            if let Some(sender) = handlers.lock().await.remove(&token) {
 | 
			
		||||
                let _ = sender.send(reply);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn add_handler(&self, token: IcmpHandlerToken) -> Result<oneshot::Receiver<IcmpReply>> {
 | 
			
		||||
        let (tx, rx) = oneshot::channel();
 | 
			
		||||
        if self
 | 
			
		||||
            .handlers
 | 
			
		||||
            .lock()
 | 
			
		||||
            .await
 | 
			
		||||
            .insert(token.clone(), tx)
 | 
			
		||||
            .is_some()
 | 
			
		||||
        {
 | 
			
		||||
            return Err(anyhow!("duplicate icmp request: {:?}", token));
 | 
			
		||||
        }
 | 
			
		||||
        Ok(rx)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn remove_handler(&self, token: IcmpHandlerToken) -> Result<()> {
 | 
			
		||||
        self.handlers.lock().await.remove(&token);
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn ping4(
 | 
			
		||||
        &self,
 | 
			
		||||
        addr: Ipv4Addr,
 | 
			
		||||
        id: u16,
 | 
			
		||||
        seq: u16,
 | 
			
		||||
        payload: &[u8],
 | 
			
		||||
        deadline: Duration,
 | 
			
		||||
    ) -> Result<Option<IcmpReply>> {
 | 
			
		||||
        let token = IcmpHandlerToken(IpAddr::V4(addr), Some(id), seq);
 | 
			
		||||
        let rx = self.add_handler(token.clone()).await?;
 | 
			
		||||
 | 
			
		||||
        let echo = IcmpEchoHeader { id, seq };
 | 
			
		||||
        let mut header = Icmpv4Header::new(Icmpv4Type::EchoRequest(echo));
 | 
			
		||||
        header.update_checksum(payload);
 | 
			
		||||
        let mut buffer: Vec<u8> = Vec::new();
 | 
			
		||||
        header.write(&mut buffer)?;
 | 
			
		||||
        buffer.extend_from_slice(payload);
 | 
			
		||||
 | 
			
		||||
        self.socket
 | 
			
		||||
            .send_to(&buffer, SocketAddr::V4(SocketAddrV4::new(addr, 0)))
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
        let result = timeout(deadline, rx).await;
 | 
			
		||||
        self.remove_handler(token).await?;
 | 
			
		||||
        let reply = match result {
 | 
			
		||||
            Ok(Ok(packet)) => Some(packet),
 | 
			
		||||
            Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
 | 
			
		||||
            Err(_) => None,
 | 
			
		||||
        };
 | 
			
		||||
        Ok(reply)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn ping6(
 | 
			
		||||
        &self,
 | 
			
		||||
        addr: Ipv6Addr,
 | 
			
		||||
        id: u16,
 | 
			
		||||
        seq: u16,
 | 
			
		||||
        payload: &[u8],
 | 
			
		||||
        deadline: Duration,
 | 
			
		||||
    ) -> Result<Option<IcmpReply>> {
 | 
			
		||||
        let token = IcmpHandlerToken(IpAddr::V6(addr), Some(id), seq);
 | 
			
		||||
        let rx = self.add_handler(token.clone()).await?;
 | 
			
		||||
 | 
			
		||||
        let echo = IcmpEchoHeader { id, seq };
 | 
			
		||||
        let header = Icmpv6Header::new(Icmpv6Type::EchoRequest(echo));
 | 
			
		||||
        let mut buffer: Vec<u8> = Vec::new();
 | 
			
		||||
        header.write(&mut buffer)?;
 | 
			
		||||
        buffer.extend_from_slice(payload);
 | 
			
		||||
 | 
			
		||||
        self.socket
 | 
			
		||||
            .send_to(&buffer, SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)))
 | 
			
		||||
            .await?;
 | 
			
		||||
 | 
			
		||||
        let result = timeout(deadline, rx).await;
 | 
			
		||||
        self.remove_handler(token).await?;
 | 
			
		||||
        let reply = match result {
 | 
			
		||||
            Ok(Ok(packet)) => Some(packet),
 | 
			
		||||
            Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
 | 
			
		||||
            Err(_) => None,
 | 
			
		||||
        };
 | 
			
		||||
        Ok(reply)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for IcmpClient {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        if Arc::strong_count(&self.task) <= 1 {
 | 
			
		||||
            self.task.abort();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										117
									
								
								crates/kratanet/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								crates/kratanet/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,117 @@
 | 
			
		||||
use std::{collections::HashMap, time::Duration};
 | 
			
		||||
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use autonet::{AutoNetworkChangeset, AutoNetworkCollector, NetworkMetadata};
 | 
			
		||||
use futures::{future::join_all, TryFutureExt};
 | 
			
		||||
use hbridge::HostBridge;
 | 
			
		||||
use log::warn;
 | 
			
		||||
use tokio::{task::JoinHandle, time::sleep};
 | 
			
		||||
use uuid::Uuid;
 | 
			
		||||
use vbridge::VirtualBridge;
 | 
			
		||||
 | 
			
		||||
use crate::backend::NetworkBackend;
 | 
			
		||||
 | 
			
		||||
pub mod autonet;
 | 
			
		||||
pub mod backend;
 | 
			
		||||
pub mod chandev;
 | 
			
		||||
pub mod hbridge;
 | 
			
		||||
pub mod icmp;
 | 
			
		||||
pub mod nat;
 | 
			
		||||
pub mod pkt;
 | 
			
		||||
pub mod proxynat;
 | 
			
		||||
pub mod raw_socket;
 | 
			
		||||
pub mod vbridge;
 | 
			
		||||
 | 
			
		||||
const HOST_BRIDGE_MTU: usize = 1500;
 | 
			
		||||
pub const EXTRA_MTU: usize = 20;
 | 
			
		||||
 | 
			
		||||
pub struct NetworkService {
 | 
			
		||||
    pub backends: HashMap<Uuid, JoinHandle<()>>,
 | 
			
		||||
    pub bridge: VirtualBridge,
 | 
			
		||||
    pub hbridge: HostBridge,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkService {
 | 
			
		||||
    pub async fn new() -> Result<NetworkService> {
 | 
			
		||||
        let bridge = VirtualBridge::new()?;
 | 
			
		||||
        let hbridge =
 | 
			
		||||
            HostBridge::new(HOST_BRIDGE_MTU + EXTRA_MTU, "krata0".to_string(), &bridge).await?;
 | 
			
		||||
        Ok(NetworkService {
 | 
			
		||||
            backends: HashMap::new(),
 | 
			
		||||
            bridge,
 | 
			
		||||
            hbridge,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkService {
 | 
			
		||||
    pub async fn watch(&mut self) -> Result<()> {
 | 
			
		||||
        let mut collector = AutoNetworkCollector::new().await?;
 | 
			
		||||
        loop {
 | 
			
		||||
            let changeset = collector.read_changes().await?;
 | 
			
		||||
            self.process_network_changeset(&mut collector, changeset)
 | 
			
		||||
                .await?;
 | 
			
		||||
            sleep(Duration::from_secs(2)).await;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process_network_changeset(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        collector: &mut AutoNetworkCollector,
 | 
			
		||||
        changeset: AutoNetworkChangeset,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        for removal in &changeset.removed {
 | 
			
		||||
            if let Some(handle) = self.backends.remove(&removal.uuid) {
 | 
			
		||||
                handle.abort();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let futures = changeset
 | 
			
		||||
            .added
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|metadata| {
 | 
			
		||||
                self.add_network_backend(metadata)
 | 
			
		||||
                    .map_err(|x| (metadata.clone(), x))
 | 
			
		||||
            })
 | 
			
		||||
            .collect::<Vec<_>>();
 | 
			
		||||
 | 
			
		||||
        sleep(Duration::from_secs(1)).await;
 | 
			
		||||
        let mut failed: Vec<Uuid> = Vec::new();
 | 
			
		||||
        let mut launched: Vec<(Uuid, JoinHandle<()>)> = Vec::new();
 | 
			
		||||
        let results = join_all(futures).await;
 | 
			
		||||
        for result in results {
 | 
			
		||||
            match result {
 | 
			
		||||
                Ok(launch) => {
 | 
			
		||||
                    launched.push(launch);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                Err((metadata, error)) => {
 | 
			
		||||
                    warn!(
 | 
			
		||||
                        "failed to launch network backend for krata guest {}: {}",
 | 
			
		||||
                        metadata.uuid, error
 | 
			
		||||
                    );
 | 
			
		||||
                    failed.push(metadata.uuid);
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for (uuid, handle) in launched {
 | 
			
		||||
            self.backends.insert(uuid, handle);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for uuid in failed {
 | 
			
		||||
            collector.mark_unknown(uuid)?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn add_network_backend(
 | 
			
		||||
        &self,
 | 
			
		||||
        metadata: &NetworkMetadata,
 | 
			
		||||
    ) -> Result<(Uuid, JoinHandle<()>)> {
 | 
			
		||||
        let mut network = NetworkBackend::new(metadata.clone(), self.bridge.clone())?;
 | 
			
		||||
        network.init().await?;
 | 
			
		||||
        Ok((metadata.uuid, network.launch().await?))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										36
									
								
								crates/kratanet/src/nat/handler.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								crates/kratanet/src/nat/handler.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,36 @@
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use tokio::sync::mpsc::Sender;
 | 
			
		||||
 | 
			
		||||
use super::key::NatKey;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct NatHandlerContext {
 | 
			
		||||
    pub mtu: usize,
 | 
			
		||||
    pub key: NatKey,
 | 
			
		||||
    pub transmit_sender: Sender<BytesMut>,
 | 
			
		||||
    pub reclaim_sender: Sender<NatKey>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NatHandlerContext {
 | 
			
		||||
    pub fn try_transmit(&self, buffer: BytesMut) -> Result<()> {
 | 
			
		||||
        self.transmit_sender.try_send(buffer)?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn reclaim(&self) -> Result<()> {
 | 
			
		||||
        self.reclaim_sender.try_send(self.key)?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait NatHandler: Send {
 | 
			
		||||
    async fn receive(&self, packet: &[u8]) -> Result<bool>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
pub trait NatHandlerFactory: Send {
 | 
			
		||||
    async fn nat(&self, context: NatHandlerContext) -> Option<Box<dyn NatHandler>>;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										29
									
								
								crates/kratanet/src/nat/key.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								crates/kratanet/src/nat/key.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,29 @@
 | 
			
		||||
use std::fmt::Display;
 | 
			
		||||
 | 
			
		||||
use smoltcp::wire::{EthernetAddress, IpEndpoint};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
 | 
			
		||||
pub enum NatKeyProtocol {
 | 
			
		||||
    Tcp,
 | 
			
		||||
    Udp,
 | 
			
		||||
    Icmp,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
 | 
			
		||||
pub struct NatKey {
 | 
			
		||||
    pub protocol: NatKeyProtocol,
 | 
			
		||||
    pub client_mac: EthernetAddress,
 | 
			
		||||
    pub local_mac: EthernetAddress,
 | 
			
		||||
    pub client_ip: IpEndpoint,
 | 
			
		||||
    pub external_ip: IpEndpoint,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Display for NatKey {
 | 
			
		||||
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 | 
			
		||||
        write!(
 | 
			
		||||
            f,
 | 
			
		||||
            "{} -> {} {:?} {} -> {}",
 | 
			
		||||
            self.client_mac, self.local_mac, self.protocol, self.client_ip, self.external_ip
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										42
									
								
								crates/kratanet/src/nat/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								crates/kratanet/src/nat/mod.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use tokio::sync::mpsc::Sender;
 | 
			
		||||
 | 
			
		||||
use self::handler::NatHandlerFactory;
 | 
			
		||||
use self::processor::NatProcessor;
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use smoltcp::wire::EthernetAddress;
 | 
			
		||||
use smoltcp::wire::IpCidr;
 | 
			
		||||
use tokio::task::JoinHandle;
 | 
			
		||||
 | 
			
		||||
pub mod handler;
 | 
			
		||||
pub mod key;
 | 
			
		||||
pub mod processor;
 | 
			
		||||
pub mod table;
 | 
			
		||||
 | 
			
		||||
pub struct Nat {
 | 
			
		||||
    pub receive_sender: Sender<BytesMut>,
 | 
			
		||||
    task: JoinHandle<()>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Nat {
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        mtu: usize,
 | 
			
		||||
        factory: Box<dyn NatHandlerFactory>,
 | 
			
		||||
        local_mac: EthernetAddress,
 | 
			
		||||
        local_cidrs: Vec<IpCidr>,
 | 
			
		||||
        transmit_sender: Sender<BytesMut>,
 | 
			
		||||
    ) -> Result<Self> {
 | 
			
		||||
        let (receive_sender, task) =
 | 
			
		||||
            NatProcessor::launch(mtu, factory, local_mac, local_cidrs, transmit_sender)?;
 | 
			
		||||
        Ok(Self {
 | 
			
		||||
            receive_sender,
 | 
			
		||||
            task,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for Nat {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        self.task.abort();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										330
									
								
								crates/kratanet/src/nat/processor.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										330
									
								
								crates/kratanet/src/nat/processor.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,330 @@
 | 
			
		||||
use crate::pkt::RecvPacket;
 | 
			
		||||
use crate::pkt::RecvPacketIp;
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use etherparse::Icmpv4Header;
 | 
			
		||||
use etherparse::Icmpv4Type;
 | 
			
		||||
use etherparse::Icmpv6Header;
 | 
			
		||||
use etherparse::Icmpv6Type;
 | 
			
		||||
use etherparse::IpNumber;
 | 
			
		||||
use etherparse::IpPayloadSlice;
 | 
			
		||||
use etherparse::Ipv4Slice;
 | 
			
		||||
use etherparse::Ipv6Slice;
 | 
			
		||||
use etherparse::SlicedPacket;
 | 
			
		||||
use etherparse::TcpHeaderSlice;
 | 
			
		||||
use etherparse::UdpHeaderSlice;
 | 
			
		||||
use log::warn;
 | 
			
		||||
use log::{debug, trace};
 | 
			
		||||
use smoltcp::wire::EthernetAddress;
 | 
			
		||||
use smoltcp::wire::IpAddress;
 | 
			
		||||
use smoltcp::wire::IpCidr;
 | 
			
		||||
use smoltcp::wire::IpEndpoint;
 | 
			
		||||
use std::collections::hash_map::Entry;
 | 
			
		||||
use tokio::select;
 | 
			
		||||
use tokio::sync::mpsc::channel;
 | 
			
		||||
use tokio::sync::mpsc::Receiver;
 | 
			
		||||
use tokio::sync::mpsc::Sender;
 | 
			
		||||
use tokio::task::JoinHandle;
 | 
			
		||||
 | 
			
		||||
use super::handler::NatHandler;
 | 
			
		||||
use super::handler::NatHandlerContext;
 | 
			
		||||
use super::handler::NatHandlerFactory;
 | 
			
		||||
use super::key::NatKey;
 | 
			
		||||
use super::key::NatKeyProtocol;
 | 
			
		||||
use super::table::NatTable;
 | 
			
		||||
 | 
			
		||||
const RECEIVE_CHANNEL_QUEUE_LEN: usize = 3000;
 | 
			
		||||
const RECLAIM_CHANNEL_QUEUE_LEN: usize = 30;
 | 
			
		||||
 | 
			
		||||
pub struct NatProcessor {
 | 
			
		||||
    mtu: usize,
 | 
			
		||||
    local_mac: EthernetAddress,
 | 
			
		||||
    local_cidrs: Vec<IpCidr>,
 | 
			
		||||
    table: NatTable,
 | 
			
		||||
    factory: Box<dyn NatHandlerFactory>,
 | 
			
		||||
    transmit_sender: Sender<BytesMut>,
 | 
			
		||||
    reclaim_sender: Sender<NatKey>,
 | 
			
		||||
    reclaim_receiver: Receiver<NatKey>,
 | 
			
		||||
    receive_receiver: Receiver<BytesMut>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum NatProcessorSelect {
 | 
			
		||||
    Reclaim(Option<NatKey>),
 | 
			
		||||
    ReceivedPacket(Option<BytesMut>),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NatProcessor {
 | 
			
		||||
    pub fn launch(
 | 
			
		||||
        mtu: usize,
 | 
			
		||||
        factory: Box<dyn NatHandlerFactory>,
 | 
			
		||||
        local_mac: EthernetAddress,
 | 
			
		||||
        local_cidrs: Vec<IpCidr>,
 | 
			
		||||
        transmit_sender: Sender<BytesMut>,
 | 
			
		||||
    ) -> Result<(Sender<BytesMut>, JoinHandle<()>)> {
 | 
			
		||||
        let (reclaim_sender, reclaim_receiver) = channel(RECLAIM_CHANNEL_QUEUE_LEN);
 | 
			
		||||
        let (receive_sender, receive_receiver) = channel(RECEIVE_CHANNEL_QUEUE_LEN);
 | 
			
		||||
        let mut processor = Self {
 | 
			
		||||
            mtu,
 | 
			
		||||
            local_mac,
 | 
			
		||||
            local_cidrs,
 | 
			
		||||
            factory,
 | 
			
		||||
            table: NatTable::new(),
 | 
			
		||||
            transmit_sender,
 | 
			
		||||
            reclaim_sender,
 | 
			
		||||
            receive_receiver,
 | 
			
		||||
            reclaim_receiver,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let handle = tokio::task::spawn(async move {
 | 
			
		||||
            if let Err(error) = processor.process().await {
 | 
			
		||||
                warn!("nat processing failed: {}", error);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        Ok((receive_sender, handle))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn process(&mut self) -> Result<()> {
 | 
			
		||||
        loop {
 | 
			
		||||
            let selection = select! {
 | 
			
		||||
                x = self.reclaim_receiver.recv() => NatProcessorSelect::Reclaim(x),
 | 
			
		||||
                x = self.receive_receiver.recv() => NatProcessorSelect::ReceivedPacket(x),
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            match selection {
 | 
			
		||||
                NatProcessorSelect::Reclaim(Some(key)) => {
 | 
			
		||||
                    if self.table.inner.remove(&key).is_some() {
 | 
			
		||||
                        debug!("reclaimed nat key: {}", key);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                NatProcessorSelect::ReceivedPacket(Some(packet)) => {
 | 
			
		||||
                    if let Ok(slice) = SlicedPacket::from_ethernet(&packet) {
 | 
			
		||||
                        let Ok(packet) = RecvPacket::new(&packet, &slice) else {
 | 
			
		||||
                            continue;
 | 
			
		||||
                        };
 | 
			
		||||
 | 
			
		||||
                        self.process_packet(&packet).await?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                NatProcessorSelect::ReceivedPacket(None) | NatProcessorSelect::Reclaim(None) => {
 | 
			
		||||
                    break
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn process_reclaim(&mut self) -> Result<Option<NatKey>> {
 | 
			
		||||
        Ok(if let Some(key) = self.reclaim_receiver.recv().await {
 | 
			
		||||
            if self.table.inner.remove(&key).is_some() {
 | 
			
		||||
                debug!("reclaimed nat key: {}", key);
 | 
			
		||||
                Some(key)
 | 
			
		||||
            } else {
 | 
			
		||||
                None
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            None
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn process_packet<'a>(&mut self, packet: &RecvPacket<'a>) -> Result<()> {
 | 
			
		||||
        let Some(ether) = packet.ether else {
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let mac = EthernetAddress(ether.destination());
 | 
			
		||||
        if mac != self.local_mac {
 | 
			
		||||
            trace!(
 | 
			
		||||
                "received packet with destination {} which is not the local mac {}",
 | 
			
		||||
                mac,
 | 
			
		||||
                self.local_mac
 | 
			
		||||
            );
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let key = match packet.ip {
 | 
			
		||||
            Some(RecvPacketIp::Ipv4(ipv4)) => self.extract_key_ipv4(packet, ipv4)?,
 | 
			
		||||
            Some(RecvPacketIp::Ipv6(ipv6)) => self.extract_key_ipv6(packet, ipv6)?,
 | 
			
		||||
            _ => None,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let Some(key) = key else {
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        for cidr in &self.local_cidrs {
 | 
			
		||||
            if cidr.contains_addr(&key.external_ip.addr) {
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let context = NatHandlerContext {
 | 
			
		||||
            mtu: self.mtu,
 | 
			
		||||
            key,
 | 
			
		||||
            transmit_sender: self.transmit_sender.clone(),
 | 
			
		||||
            reclaim_sender: self.reclaim_sender.clone(),
 | 
			
		||||
        };
 | 
			
		||||
        let handler: Option<&mut Box<dyn NatHandler>> = match self.table.inner.entry(key) {
 | 
			
		||||
            Entry::Occupied(entry) => Some(entry.into_mut()),
 | 
			
		||||
            Entry::Vacant(entry) => {
 | 
			
		||||
                if let Some(handler) = self.factory.nat(context).await {
 | 
			
		||||
                    debug!("creating nat entry for key: {}", key);
 | 
			
		||||
                    Some(entry.insert(handler))
 | 
			
		||||
                } else {
 | 
			
		||||
                    None
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        if let Some(handler) = handler {
 | 
			
		||||
            if !handler.receive(packet.raw).await? {
 | 
			
		||||
                self.reclaim_sender.try_send(key)?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn extract_key_ipv4<'a>(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        packet: &RecvPacket<'a>,
 | 
			
		||||
        ipv4: &Ipv4Slice<'a>,
 | 
			
		||||
    ) -> Result<Option<NatKey>> {
 | 
			
		||||
        let source_addr = IpAddress::Ipv4(ipv4.header().source_addr().into());
 | 
			
		||||
        let dest_addr = IpAddress::Ipv4(ipv4.header().destination_addr().into());
 | 
			
		||||
        Ok(match ipv4.header().protocol() {
 | 
			
		||||
            IpNumber::TCP => {
 | 
			
		||||
                self.extract_key_tcp(packet, source_addr, dest_addr, ipv4.payload())?
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpNumber::UDP => {
 | 
			
		||||
                self.extract_key_udp(packet, source_addr, dest_addr, ipv4.payload())?
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpNumber::ICMP => {
 | 
			
		||||
                self.extract_key_icmpv4(packet, source_addr, dest_addr, ipv4.payload())?
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            _ => None,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn extract_key_ipv6<'a>(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        packet: &RecvPacket<'a>,
 | 
			
		||||
        ipv6: &Ipv6Slice<'a>,
 | 
			
		||||
    ) -> Result<Option<NatKey>> {
 | 
			
		||||
        let source_addr = IpAddress::Ipv6(ipv6.header().source_addr().into());
 | 
			
		||||
        let dest_addr = IpAddress::Ipv6(ipv6.header().destination_addr().into());
 | 
			
		||||
        Ok(match ipv6.header().next_header() {
 | 
			
		||||
            IpNumber::TCP => {
 | 
			
		||||
                self.extract_key_tcp(packet, source_addr, dest_addr, ipv6.payload())?
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpNumber::UDP => {
 | 
			
		||||
                self.extract_key_udp(packet, source_addr, dest_addr, ipv6.payload())?
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpNumber::IPV6_ICMP => {
 | 
			
		||||
                self.extract_key_icmpv6(packet, source_addr, dest_addr, ipv6.payload())?
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            _ => None,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn extract_key_udp<'a>(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        packet: &RecvPacket<'a>,
 | 
			
		||||
        source_addr: IpAddress,
 | 
			
		||||
        dest_addr: IpAddress,
 | 
			
		||||
        payload: &IpPayloadSlice<'a>,
 | 
			
		||||
    ) -> Result<Option<NatKey>> {
 | 
			
		||||
        let Some(ether) = packet.ether else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        let header = UdpHeaderSlice::from_slice(payload.payload)?;
 | 
			
		||||
        let source = IpEndpoint::new(source_addr, header.source_port());
 | 
			
		||||
        let dest = IpEndpoint::new(dest_addr, header.destination_port());
 | 
			
		||||
        Ok(Some(NatKey {
 | 
			
		||||
            protocol: NatKeyProtocol::Udp,
 | 
			
		||||
            client_mac: EthernetAddress(ether.source()),
 | 
			
		||||
            local_mac: EthernetAddress(ether.destination()),
 | 
			
		||||
            client_ip: source,
 | 
			
		||||
            external_ip: dest,
 | 
			
		||||
        }))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn extract_key_icmpv4<'a>(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        packet: &RecvPacket<'a>,
 | 
			
		||||
        source_addr: IpAddress,
 | 
			
		||||
        dest_addr: IpAddress,
 | 
			
		||||
        payload: &IpPayloadSlice<'a>,
 | 
			
		||||
    ) -> Result<Option<NatKey>> {
 | 
			
		||||
        let Some(ether) = packet.ether else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        let (header, _) = Icmpv4Header::from_slice(payload.payload)?;
 | 
			
		||||
        let Icmpv4Type::EchoRequest(_) = header.icmp_type else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        let source = IpEndpoint::new(source_addr, 0);
 | 
			
		||||
        let dest = IpEndpoint::new(dest_addr, 0);
 | 
			
		||||
        Ok(Some(NatKey {
 | 
			
		||||
            protocol: NatKeyProtocol::Icmp,
 | 
			
		||||
            client_mac: EthernetAddress(ether.source()),
 | 
			
		||||
            local_mac: EthernetAddress(ether.destination()),
 | 
			
		||||
            client_ip: source,
 | 
			
		||||
            external_ip: dest,
 | 
			
		||||
        }))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn extract_key_icmpv6<'a>(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        packet: &RecvPacket<'a>,
 | 
			
		||||
        source_addr: IpAddress,
 | 
			
		||||
        dest_addr: IpAddress,
 | 
			
		||||
        payload: &IpPayloadSlice<'a>,
 | 
			
		||||
    ) -> Result<Option<NatKey>> {
 | 
			
		||||
        let Some(ether) = packet.ether else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        let (header, _) = Icmpv6Header::from_slice(payload.payload)?;
 | 
			
		||||
        let Icmpv6Type::EchoRequest(_) = header.icmp_type else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        let source = IpEndpoint::new(source_addr, 0);
 | 
			
		||||
        let dest = IpEndpoint::new(dest_addr, 0);
 | 
			
		||||
        Ok(Some(NatKey {
 | 
			
		||||
            protocol: NatKeyProtocol::Icmp,
 | 
			
		||||
            client_mac: EthernetAddress(ether.source()),
 | 
			
		||||
            local_mac: EthernetAddress(ether.destination()),
 | 
			
		||||
            client_ip: source,
 | 
			
		||||
            external_ip: dest,
 | 
			
		||||
        }))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn extract_key_tcp<'a>(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        packet: &RecvPacket<'a>,
 | 
			
		||||
        source_addr: IpAddress,
 | 
			
		||||
        dest_addr: IpAddress,
 | 
			
		||||
        payload: &IpPayloadSlice<'a>,
 | 
			
		||||
    ) -> Result<Option<NatKey>> {
 | 
			
		||||
        let Some(ether) = packet.ether else {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        };
 | 
			
		||||
        let header = TcpHeaderSlice::from_slice(payload.payload)?;
 | 
			
		||||
        let source = IpEndpoint::new(source_addr, header.source_port());
 | 
			
		||||
        let dest = IpEndpoint::new(dest_addr, header.destination_port());
 | 
			
		||||
        Ok(Some(NatKey {
 | 
			
		||||
            protocol: NatKeyProtocol::Tcp,
 | 
			
		||||
            client_mac: EthernetAddress(ether.source()),
 | 
			
		||||
            local_mac: EthernetAddress(ether.destination()),
 | 
			
		||||
            client_ip: source,
 | 
			
		||||
            external_ip: dest,
 | 
			
		||||
        }))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										21
									
								
								crates/kratanet/src/nat/table.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								crates/kratanet/src/nat/table.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
			
		||||
use std::collections::HashMap;
 | 
			
		||||
 | 
			
		||||
use super::{handler::NatHandler, key::NatKey};
 | 
			
		||||
 | 
			
		||||
pub struct NatTable {
 | 
			
		||||
    pub inner: HashMap<NatKey, Box<dyn NatHandler>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for NatTable {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self::new()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NatTable {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            inner: HashMap::new(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										37
									
								
								crates/kratanet/src/pkt.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								crates/kratanet/src/pkt.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use etherparse::{Ethernet2Slice, Ipv4Slice, Ipv6Slice, LinkSlice, NetSlice, SlicedPacket};
 | 
			
		||||
 | 
			
		||||
pub enum RecvPacketIp<'a> {
 | 
			
		||||
    Ipv4(&'a Ipv4Slice<'a>),
 | 
			
		||||
    Ipv6(&'a Ipv6Slice<'a>),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct RecvPacket<'a> {
 | 
			
		||||
    pub raw: &'a [u8],
 | 
			
		||||
    pub slice: &'a SlicedPacket<'a>,
 | 
			
		||||
    pub ether: Option<&'a Ethernet2Slice<'a>>,
 | 
			
		||||
    pub ip: Option<RecvPacketIp<'a>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl RecvPacket<'_> {
 | 
			
		||||
    pub fn new<'a>(raw: &'a [u8], slice: &'a SlicedPacket<'a>) -> Result<RecvPacket<'a>> {
 | 
			
		||||
        let ether = match slice.link {
 | 
			
		||||
            Some(LinkSlice::Ethernet2(ref ether)) => Some(ether),
 | 
			
		||||
            _ => None,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let ip = match slice.net {
 | 
			
		||||
            Some(NetSlice::Ipv4(ref ipv4)) => Some(RecvPacketIp::Ipv4(ipv4)),
 | 
			
		||||
            Some(NetSlice::Ipv6(ref ipv6)) => Some(RecvPacketIp::Ipv6(ipv6)),
 | 
			
		||||
            _ => None,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let packet = RecvPacket {
 | 
			
		||||
            raw,
 | 
			
		||||
            slice,
 | 
			
		||||
            ether,
 | 
			
		||||
            ip,
 | 
			
		||||
        };
 | 
			
		||||
        Ok(packet)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										276
									
								
								crates/kratanet/src/proxynat/icmp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										276
									
								
								crates/kratanet/src/proxynat/icmp.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,276 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use bytes::{BufMut, BytesMut};
 | 
			
		||||
use etherparse::{
 | 
			
		||||
    IcmpEchoHeader, Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type, IpNumber, Ipv4Slice,
 | 
			
		||||
    Ipv6Slice, NetSlice, PacketBuilder, SlicedPacket,
 | 
			
		||||
};
 | 
			
		||||
use log::{debug, trace, warn};
 | 
			
		||||
use smoltcp::wire::IpAddress;
 | 
			
		||||
use tokio::{
 | 
			
		||||
    select,
 | 
			
		||||
    sync::mpsc::{Receiver, Sender},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    icmp::{IcmpClient, IcmpProtocol, IcmpReply},
 | 
			
		||||
    nat::handler::{NatHandler, NatHandlerContext},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
const ICMP_PING_TIMEOUT_SECS: u64 = 20;
 | 
			
		||||
const ICMP_TIMEOUT_SECS: u64 = 30;
 | 
			
		||||
 | 
			
		||||
pub struct ProxyIcmpHandler {
 | 
			
		||||
    rx_sender: Sender<BytesMut>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl NatHandler for ProxyIcmpHandler {
 | 
			
		||||
    async fn receive(&self, data: &[u8]) -> Result<bool> {
 | 
			
		||||
        if self.rx_sender.is_closed() {
 | 
			
		||||
            Ok(true)
 | 
			
		||||
        } else {
 | 
			
		||||
            self.rx_sender.try_send(data.into())?;
 | 
			
		||||
            Ok(true)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum ProxyIcmpSelect {
 | 
			
		||||
    Internal(BytesMut),
 | 
			
		||||
    Close,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ProxyIcmpHandler {
 | 
			
		||||
    pub fn new(rx_sender: Sender<BytesMut>) -> Self {
 | 
			
		||||
        ProxyIcmpHandler { rx_sender }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn spawn(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        rx_receiver: Receiver<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let client = IcmpClient::new(match context.key.external_ip.addr {
 | 
			
		||||
            IpAddress::Ipv4(_) => IcmpProtocol::Icmpv4,
 | 
			
		||||
            IpAddress::Ipv6(_) => IcmpProtocol::Icmpv6,
 | 
			
		||||
        })?;
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            if let Err(error) = ProxyIcmpHandler::process(client, rx_receiver, context).await {
 | 
			
		||||
                warn!("processing of icmp proxy failed: {}", error);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(
 | 
			
		||||
        client: IcmpClient,
 | 
			
		||||
        mut rx_receiver: Receiver<BytesMut>,
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        loop {
 | 
			
		||||
            let deadline = tokio::time::sleep(Duration::from_secs(ICMP_TIMEOUT_SECS));
 | 
			
		||||
            let selection = select! {
 | 
			
		||||
                x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                    ProxyIcmpSelect::Internal(data)
 | 
			
		||||
                } else {
 | 
			
		||||
                    ProxyIcmpSelect::Close
 | 
			
		||||
                },
 | 
			
		||||
                _ =  deadline => ProxyIcmpSelect::Close,
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            match selection {
 | 
			
		||||
                ProxyIcmpSelect::Internal(data) => {
 | 
			
		||||
                    let packet = SlicedPacket::from_ethernet(&data)?;
 | 
			
		||||
                    let Some(ref net) = packet.net else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    match net {
 | 
			
		||||
                        NetSlice::Ipv4(ipv4) => {
 | 
			
		||||
                            ProxyIcmpHandler::process_ipv4(&context, ipv4, &client).await?
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        NetSlice::Ipv6(ipv6) => {
 | 
			
		||||
                            ProxyIcmpHandler::process_ipv6(&context, ipv6, &client).await?
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyIcmpSelect::Close => {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        context.reclaim().await?;
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process_ipv4(
 | 
			
		||||
        context: &NatHandlerContext,
 | 
			
		||||
        ipv4: &Ipv4Slice<'_>,
 | 
			
		||||
        client: &IcmpClient,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        if ipv4.header().protocol() != IpNumber::ICMP {
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let (header, payload) = Icmpv4Header::from_slice(ipv4.payload().payload)?;
 | 
			
		||||
        if let Icmpv4Type::EchoRequest(echo) = header.icmp_type {
 | 
			
		||||
            let IpAddr::V4(external_ipv4) = context.key.external_ip.addr.into() else {
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let context = context.clone();
 | 
			
		||||
            let client = client.clone();
 | 
			
		||||
            let payload = payload.to_vec();
 | 
			
		||||
            tokio::task::spawn(async move {
 | 
			
		||||
                if let Err(error) = ProxyIcmpHandler::process_echo_ipv4(
 | 
			
		||||
                    context,
 | 
			
		||||
                    client,
 | 
			
		||||
                    external_ipv4,
 | 
			
		||||
                    echo,
 | 
			
		||||
                    payload,
 | 
			
		||||
                )
 | 
			
		||||
                .await
 | 
			
		||||
                {
 | 
			
		||||
                    trace!("icmp4 echo failed: {}", error);
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process_ipv6(
 | 
			
		||||
        context: &NatHandlerContext,
 | 
			
		||||
        ipv6: &Ipv6Slice<'_>,
 | 
			
		||||
        client: &IcmpClient,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        if ipv6.header().next_header() != IpNumber::IPV6_ICMP {
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let (header, payload) = Icmpv6Header::from_slice(ipv6.payload().payload)?;
 | 
			
		||||
        if let Icmpv6Type::EchoRequest(echo) = header.icmp_type {
 | 
			
		||||
            let IpAddr::V6(external_ipv6) = context.key.external_ip.addr.into() else {
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let context = context.clone();
 | 
			
		||||
            let client = client.clone();
 | 
			
		||||
            let payload = payload.to_vec();
 | 
			
		||||
            tokio::task::spawn(async move {
 | 
			
		||||
                if let Err(error) = ProxyIcmpHandler::process_echo_ipv6(
 | 
			
		||||
                    context,
 | 
			
		||||
                    client,
 | 
			
		||||
                    external_ipv6,
 | 
			
		||||
                    echo,
 | 
			
		||||
                    payload,
 | 
			
		||||
                )
 | 
			
		||||
                .await
 | 
			
		||||
                {
 | 
			
		||||
                    trace!("icmp6 echo failed: {}", error);
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process_echo_ipv4(
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        client: IcmpClient,
 | 
			
		||||
        external_ipv4: Ipv4Addr,
 | 
			
		||||
        echo: IcmpEchoHeader,
 | 
			
		||||
        payload: Vec<u8>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let reply = client
 | 
			
		||||
            .ping4(
 | 
			
		||||
                external_ipv4,
 | 
			
		||||
                echo.id,
 | 
			
		||||
                echo.seq,
 | 
			
		||||
                &payload,
 | 
			
		||||
                Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
 | 
			
		||||
            )
 | 
			
		||||
            .await?;
 | 
			
		||||
        let Some(IcmpReply::Icmpv4 {
 | 
			
		||||
            header: _,
 | 
			
		||||
            echo,
 | 
			
		||||
            payload,
 | 
			
		||||
        }) = reply
 | 
			
		||||
        else {
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let packet = PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
 | 
			
		||||
        let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
 | 
			
		||||
            (IpAddress::Ipv4(external_addr), IpAddress::Ipv4(client_addr)) => {
 | 
			
		||||
                packet.ipv4(external_addr.0, client_addr.0, 20)
 | 
			
		||||
            }
 | 
			
		||||
            _ => {
 | 
			
		||||
                return Err(anyhow!("IP endpoint mismatch"));
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        let packet = packet.icmpv4_echo_reply(echo.id, echo.seq);
 | 
			
		||||
        let buffer = BytesMut::with_capacity(packet.size(payload.len()));
 | 
			
		||||
        let mut writer = buffer.writer();
 | 
			
		||||
        packet.write(&mut writer, &payload)?;
 | 
			
		||||
        let buffer = writer.into_inner();
 | 
			
		||||
        if let Err(error) = context.try_transmit(buffer) {
 | 
			
		||||
            debug!("failed to transmit icmp packet: {}", error);
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process_echo_ipv6(
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        client: IcmpClient,
 | 
			
		||||
        external_ipv6: Ipv6Addr,
 | 
			
		||||
        echo: IcmpEchoHeader,
 | 
			
		||||
        payload: Vec<u8>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let reply = client
 | 
			
		||||
            .ping6(
 | 
			
		||||
                external_ipv6,
 | 
			
		||||
                echo.id,
 | 
			
		||||
                echo.seq,
 | 
			
		||||
                &payload,
 | 
			
		||||
                Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
 | 
			
		||||
            )
 | 
			
		||||
            .await?;
 | 
			
		||||
        let Some(IcmpReply::Icmpv6 {
 | 
			
		||||
            header: _,
 | 
			
		||||
            echo,
 | 
			
		||||
            payload,
 | 
			
		||||
        }) = reply
 | 
			
		||||
        else {
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let packet = PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
 | 
			
		||||
        let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
 | 
			
		||||
            (IpAddress::Ipv6(external_addr), IpAddress::Ipv6(client_addr)) => {
 | 
			
		||||
                packet.ipv6(external_addr.0, client_addr.0, 20)
 | 
			
		||||
            }
 | 
			
		||||
            _ => {
 | 
			
		||||
                return Err(anyhow!("IP endpoint mismatch"));
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        let packet = packet.icmpv6_echo_reply(echo.id, echo.seq);
 | 
			
		||||
        let buffer = BytesMut::with_capacity(packet.size(payload.len()));
 | 
			
		||||
        let mut writer = buffer.writer();
 | 
			
		||||
        packet.write(&mut writer, &payload)?;
 | 
			
		||||
        let buffer = writer.into_inner();
 | 
			
		||||
        if let Err(error) = context.try_transmit(buffer) {
 | 
			
		||||
            debug!("failed to transmit icmp packet: {}", error);
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										77
									
								
								crates/kratanet/src/proxynat/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								crates/kratanet/src/proxynat/mod.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,77 @@
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use log::warn;
 | 
			
		||||
 | 
			
		||||
use tokio::sync::mpsc::channel;
 | 
			
		||||
 | 
			
		||||
use crate::proxynat::udp::ProxyUdpHandler;
 | 
			
		||||
 | 
			
		||||
use crate::nat::handler::{NatHandler, NatHandlerContext, NatHandlerFactory};
 | 
			
		||||
use crate::nat::key::NatKeyProtocol;
 | 
			
		||||
 | 
			
		||||
use self::icmp::ProxyIcmpHandler;
 | 
			
		||||
use self::tcp::ProxyTcpHandler;
 | 
			
		||||
 | 
			
		||||
mod icmp;
 | 
			
		||||
mod tcp;
 | 
			
		||||
mod udp;
 | 
			
		||||
 | 
			
		||||
const RX_CHANNEL_QUEUE_LEN: usize = 1000;
 | 
			
		||||
 | 
			
		||||
pub struct ProxyNatHandlerFactory {}
 | 
			
		||||
 | 
			
		||||
impl Default for ProxyNatHandlerFactory {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        Self::new()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ProxyNatHandlerFactory {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {}
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl NatHandlerFactory for ProxyNatHandlerFactory {
 | 
			
		||||
    async fn nat(&self, context: NatHandlerContext) -> Option<Box<dyn NatHandler>> {
 | 
			
		||||
        match context.key.protocol {
 | 
			
		||||
            NatKeyProtocol::Udp => {
 | 
			
		||||
                let (rx_sender, rx_receiver) = channel::<BytesMut>(RX_CHANNEL_QUEUE_LEN);
 | 
			
		||||
                let mut handler = ProxyUdpHandler::new(rx_sender);
 | 
			
		||||
 | 
			
		||||
                if let Err(error) = handler.spawn(context, rx_receiver).await {
 | 
			
		||||
                    warn!("unable to spawn udp proxy handler: {}", error);
 | 
			
		||||
                    None
 | 
			
		||||
                } else {
 | 
			
		||||
                    Some(Box::new(handler))
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            NatKeyProtocol::Icmp => {
 | 
			
		||||
                let (rx_sender, rx_receiver) = channel::<BytesMut>(RX_CHANNEL_QUEUE_LEN);
 | 
			
		||||
                let mut handler = ProxyIcmpHandler::new(rx_sender);
 | 
			
		||||
 | 
			
		||||
                if let Err(error) = handler.spawn(context, rx_receiver).await {
 | 
			
		||||
                    warn!("unable to spawn icmp proxy handler: {}", error);
 | 
			
		||||
                    None
 | 
			
		||||
                } else {
 | 
			
		||||
                    Some(Box::new(handler))
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            NatKeyProtocol::Tcp => {
 | 
			
		||||
                let (rx_sender, rx_receiver) = channel::<BytesMut>(RX_CHANNEL_QUEUE_LEN);
 | 
			
		||||
                let mut handler = ProxyTcpHandler::new(rx_sender);
 | 
			
		||||
 | 
			
		||||
                if let Err(error) = handler.spawn(context, rx_receiver).await {
 | 
			
		||||
                    warn!("unable to spawn tcp proxy handler: {}", error);
 | 
			
		||||
                    None
 | 
			
		||||
                } else {
 | 
			
		||||
                    Some(Box::new(handler))
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										466
									
								
								crates/kratanet/src/proxynat/tcp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										466
									
								
								crates/kratanet/src/proxynat/tcp.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,466 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    net::{IpAddr, SocketAddr},
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use etherparse::{EtherType, Ethernet2Header};
 | 
			
		||||
use log::{debug, warn};
 | 
			
		||||
use smoltcp::{
 | 
			
		||||
    iface::{Config, Interface, SocketSet, SocketStorage},
 | 
			
		||||
    phy::Medium,
 | 
			
		||||
    socket::tcp::{self, SocketBuffer, State},
 | 
			
		||||
    time::Instant,
 | 
			
		||||
    wire::{HardwareAddress, IpAddress, IpCidr},
 | 
			
		||||
};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncReadExt, AsyncWriteExt},
 | 
			
		||||
    net::TcpStream,
 | 
			
		||||
    select,
 | 
			
		||||
    sync::mpsc::channel,
 | 
			
		||||
};
 | 
			
		||||
use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    chandev::ChannelDevice,
 | 
			
		||||
    nat::handler::{NatHandler, NatHandlerContext},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
const TCP_BUFFER_SIZE: usize = 65535;
 | 
			
		||||
const TCP_IP_BUFFER_QUEUE_LEN: usize = 3000;
 | 
			
		||||
const TCP_ACCEPT_TIMEOUT_SECS: u64 = 120;
 | 
			
		||||
const TCP_DANGLE_TIMEOUT_SECS: u64 = 10;
 | 
			
		||||
 | 
			
		||||
pub struct ProxyTcpHandler {
 | 
			
		||||
    rx_sender: Sender<BytesMut>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl NatHandler for ProxyTcpHandler {
 | 
			
		||||
    async fn receive(&self, data: &[u8]) -> Result<bool> {
 | 
			
		||||
        if self.rx_sender.is_closed() {
 | 
			
		||||
            Ok(false)
 | 
			
		||||
        } else {
 | 
			
		||||
            self.rx_sender.try_send(data.into())?;
 | 
			
		||||
            Ok(true)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
enum ProxyTcpAcceptSelect {
 | 
			
		||||
    Internal(BytesMut),
 | 
			
		||||
    TxIpPacket(BytesMut),
 | 
			
		||||
    TimePassed,
 | 
			
		||||
    DoNothing,
 | 
			
		||||
    Close,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
enum ProxyTcpDataSelect {
 | 
			
		||||
    ExternalRecv(usize),
 | 
			
		||||
    ExternalSent(usize),
 | 
			
		||||
    InternalRecv(BytesMut),
 | 
			
		||||
    TxIpPacket(BytesMut),
 | 
			
		||||
    TimePassed,
 | 
			
		||||
    DoNothing,
 | 
			
		||||
    Close,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
enum ProxyTcpFinishSelect {
 | 
			
		||||
    InternalRecv(BytesMut),
 | 
			
		||||
    TxIpPacket(BytesMut),
 | 
			
		||||
    Close,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ProxyTcpHandler {
 | 
			
		||||
    pub fn new(rx_sender: Sender<BytesMut>) -> Self {
 | 
			
		||||
        ProxyTcpHandler { rx_sender }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn spawn(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        rx_receiver: Receiver<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let external_addr = match context.key.external_ip.addr {
 | 
			
		||||
            IpAddress::Ipv4(addr) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V4(addr.0.into()), context.key.external_ip.port)
 | 
			
		||||
            }
 | 
			
		||||
            IpAddress::Ipv6(addr) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V6(addr.0.into()), context.key.external_ip.port)
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let socket = TcpStream::connect(external_addr).await?;
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            if let Err(error) = ProxyTcpHandler::process(context, socket, rx_receiver).await {
 | 
			
		||||
                warn!("processing of tcp proxy failed: {}", error);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        mut external_socket: TcpStream,
 | 
			
		||||
        mut rx_receiver: Receiver<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let (ip_sender, mut ip_receiver) = channel::<BytesMut>(TCP_IP_BUFFER_QUEUE_LEN);
 | 
			
		||||
        let mut external_buffer = vec![0u8; TCP_BUFFER_SIZE];
 | 
			
		||||
 | 
			
		||||
        let mut device = ChannelDevice::new(
 | 
			
		||||
            context.mtu - Ethernet2Header::LEN,
 | 
			
		||||
            Medium::Ip,
 | 
			
		||||
            ip_sender.clone(),
 | 
			
		||||
        );
 | 
			
		||||
        let config = Config::new(HardwareAddress::Ip);
 | 
			
		||||
 | 
			
		||||
        let tcp_rx_buffer = SocketBuffer::new(vec![0; TCP_BUFFER_SIZE]);
 | 
			
		||||
        let tcp_tx_buffer = SocketBuffer::new(vec![0; TCP_BUFFER_SIZE]);
 | 
			
		||||
        let internal_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer);
 | 
			
		||||
        let mut iface = Interface::new(config, &mut device, Instant::now());
 | 
			
		||||
 | 
			
		||||
        iface.update_ip_addrs(|addrs| {
 | 
			
		||||
            let _ = addrs.push(IpCidr::new(context.key.external_ip.addr, 0));
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        let mut sockets = SocketSet::new([SocketStorage::EMPTY]);
 | 
			
		||||
        let internal_socket_handle = sockets.add(internal_socket);
 | 
			
		||||
        let (mut external_r, mut external_w) = external_socket.split();
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let socket = sockets.get_mut::<tcp::Socket>(internal_socket_handle);
 | 
			
		||||
            socket.connect(
 | 
			
		||||
                iface.context(),
 | 
			
		||||
                context.key.client_ip,
 | 
			
		||||
                context.key.external_ip,
 | 
			
		||||
            )?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        iface.poll(Instant::now(), &mut device, &mut sockets);
 | 
			
		||||
 | 
			
		||||
        let mut sleeper: Option<tokio::time::Sleep> = None;
 | 
			
		||||
        loop {
 | 
			
		||||
            let socket = sockets.get_mut::<tcp::Socket>(internal_socket_handle);
 | 
			
		||||
            if socket.is_active() && socket.state() != State::SynSent {
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if socket.state() == State::Closed {
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let deadline = tokio::time::sleep(Duration::from_secs(TCP_ACCEPT_TIMEOUT_SECS));
 | 
			
		||||
            let selection = if let Some(sleep) = sleeper.take() {
 | 
			
		||||
                select! {
 | 
			
		||||
                    biased;
 | 
			
		||||
                    x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpAcceptSelect::Internal(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpAcceptSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpAcceptSelect::TxIpPacket(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpAcceptSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    _ = sleep => ProxyTcpAcceptSelect::TimePassed,
 | 
			
		||||
                    _ = deadline => ProxyTcpAcceptSelect::Close,
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                select! {
 | 
			
		||||
                    biased;
 | 
			
		||||
                    x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpAcceptSelect::Internal(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpAcceptSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpAcceptSelect::TxIpPacket(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpAcceptSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    _ = std::future::ready(()) => ProxyTcpAcceptSelect::DoNothing,
 | 
			
		||||
                    _ = deadline => ProxyTcpAcceptSelect::Close,
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
            match selection {
 | 
			
		||||
                ProxyTcpAcceptSelect::TimePassed => {
 | 
			
		||||
                    iface.poll(Instant::now(), &mut device, &mut sockets);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpAcceptSelect::DoNothing => {
 | 
			
		||||
                    sleeper = Some(tokio::time::sleep(Duration::from_micros(100)));
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpAcceptSelect::Internal(data) => {
 | 
			
		||||
                    let (_, payload) = Ethernet2Header::from_slice(&data)?;
 | 
			
		||||
                    device.rx = Some(payload.into());
 | 
			
		||||
                    iface.poll(Instant::now(), &mut device, &mut sockets);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpAcceptSelect::TxIpPacket(payload) => {
 | 
			
		||||
                    let mut buffer = BytesMut::with_capacity(Ethernet2Header::LEN + payload.len());
 | 
			
		||||
                    let header = Ethernet2Header {
 | 
			
		||||
                        source: context.key.local_mac.0,
 | 
			
		||||
                        destination: context.key.client_mac.0,
 | 
			
		||||
                        ether_type: match context.key.external_ip.addr {
 | 
			
		||||
                            IpAddress::Ipv4(_) => EtherType::IPV4,
 | 
			
		||||
                            IpAddress::Ipv6(_) => EtherType::IPV6,
 | 
			
		||||
                        },
 | 
			
		||||
                    };
 | 
			
		||||
                    buffer.extend_from_slice(&header.to_bytes());
 | 
			
		||||
                    buffer.extend_from_slice(&payload);
 | 
			
		||||
                    if let Err(error) = context.try_transmit(buffer) {
 | 
			
		||||
                        debug!("failed to transmit tcp packet: {}", error);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpAcceptSelect::Close => {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let accepted = if sockets
 | 
			
		||||
            .get_mut::<tcp::Socket>(internal_socket_handle)
 | 
			
		||||
            .is_active()
 | 
			
		||||
        {
 | 
			
		||||
            true
 | 
			
		||||
        } else {
 | 
			
		||||
            debug!("failed to accept tcp connection from client");
 | 
			
		||||
            false
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let mut already_shutdown = false;
 | 
			
		||||
        let mut sleeper: Option<tokio::time::Sleep> = None;
 | 
			
		||||
        loop {
 | 
			
		||||
            if !accepted {
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let socket = sockets.get_mut::<tcp::Socket>(internal_socket_handle);
 | 
			
		||||
 | 
			
		||||
            match socket.state() {
 | 
			
		||||
                State::Closed
 | 
			
		||||
                | State::Listen
 | 
			
		||||
                | State::Closing
 | 
			
		||||
                | State::LastAck
 | 
			
		||||
                | State::TimeWait => {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
                State::FinWait1
 | 
			
		||||
                | State::SynSent
 | 
			
		||||
                | State::CloseWait
 | 
			
		||||
                | State::FinWait2
 | 
			
		||||
                | State::SynReceived
 | 
			
		||||
                | State::Established => {}
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let bytes_to_client = if socket.can_send() {
 | 
			
		||||
                socket.send_capacity() - socket.send_queue()
 | 
			
		||||
            } else {
 | 
			
		||||
                0
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            let (bytes_to_external, do_shutdown) = if socket.may_recv() {
 | 
			
		||||
                if let Ok(data) = socket.peek(TCP_BUFFER_SIZE) {
 | 
			
		||||
                    if data.is_empty() {
 | 
			
		||||
                        (None, false)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        (Some(data), false)
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    (None, false)
 | 
			
		||||
                }
 | 
			
		||||
            } else if !already_shutdown && matches!(socket.state(), State::CloseWait) {
 | 
			
		||||
                (None, true)
 | 
			
		||||
            } else {
 | 
			
		||||
                (None, false)
 | 
			
		||||
            };
 | 
			
		||||
            let selection = if let Some(sleep) = sleeper.take() {
 | 
			
		||||
                if !do_shutdown {
 | 
			
		||||
                    select! {
 | 
			
		||||
                        biased;
 | 
			
		||||
                        x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                            ProxyTcpDataSelect::TxIpPacket(data)
 | 
			
		||||
                        } else {
 | 
			
		||||
                            ProxyTcpDataSelect::Close
 | 
			
		||||
                        },
 | 
			
		||||
                        x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                            ProxyTcpDataSelect::InternalRecv(data)
 | 
			
		||||
                        } else {
 | 
			
		||||
                            ProxyTcpDataSelect::Close
 | 
			
		||||
                        },
 | 
			
		||||
                        x = external_w.write(bytes_to_external.unwrap_or(b"")), if bytes_to_external.is_some() => ProxyTcpDataSelect::ExternalSent(x?),
 | 
			
		||||
                        x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
 | 
			
		||||
                        _ = sleep => ProxyTcpDataSelect::TimePassed,
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    select! {
 | 
			
		||||
                        biased;
 | 
			
		||||
                        x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                            ProxyTcpDataSelect::TxIpPacket(data)
 | 
			
		||||
                        } else {
 | 
			
		||||
                            ProxyTcpDataSelect::Close
 | 
			
		||||
                        },
 | 
			
		||||
                        x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                            ProxyTcpDataSelect::InternalRecv(data)
 | 
			
		||||
                        } else {
 | 
			
		||||
                            ProxyTcpDataSelect::Close
 | 
			
		||||
                        },
 | 
			
		||||
                        _ = external_w.shutdown() => ProxyTcpDataSelect::ExternalSent(0),
 | 
			
		||||
                        x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
 | 
			
		||||
                        _ = sleep => ProxyTcpDataSelect::TimePassed,
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            } else if !do_shutdown {
 | 
			
		||||
                select! {
 | 
			
		||||
                    biased;
 | 
			
		||||
                    x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpDataSelect::TxIpPacket(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpDataSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpDataSelect::InternalRecv(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpDataSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    x = external_w.write(bytes_to_external.unwrap_or(b"")), if bytes_to_external.is_some() => ProxyTcpDataSelect::ExternalSent(x?),
 | 
			
		||||
                    x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
 | 
			
		||||
                    _ = std::future::ready(()) => ProxyTcpDataSelect::DoNothing,
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                select! {
 | 
			
		||||
                    biased;
 | 
			
		||||
                    x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpDataSelect::TxIpPacket(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpDataSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                        ProxyTcpDataSelect::InternalRecv(data)
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ProxyTcpDataSelect::Close
 | 
			
		||||
                    },
 | 
			
		||||
                    _ = external_w.shutdown() => ProxyTcpDataSelect::ExternalSent(0),
 | 
			
		||||
                    x = external_r.read(&mut external_buffer[..bytes_to_client]), if bytes_to_client > 0 => ProxyTcpDataSelect::ExternalRecv(x?),
 | 
			
		||||
                    _ = std::future::ready(()) => ProxyTcpDataSelect::DoNothing,
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
            match selection {
 | 
			
		||||
                ProxyTcpDataSelect::ExternalRecv(size) => {
 | 
			
		||||
                    if size == 0 {
 | 
			
		||||
                        socket.close();
 | 
			
		||||
                    } else {
 | 
			
		||||
                        socket.send_slice(&external_buffer[..size])?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpDataSelect::ExternalSent(size) => {
 | 
			
		||||
                    if size == 0 {
 | 
			
		||||
                        already_shutdown = true;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        socket.recv(|_| (size, ()))?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpDataSelect::InternalRecv(data) => {
 | 
			
		||||
                    let (_, payload) = Ethernet2Header::from_slice(&data)?;
 | 
			
		||||
                    device.rx = Some(payload.into());
 | 
			
		||||
                    iface.poll(Instant::now(), &mut device, &mut sockets);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpDataSelect::TxIpPacket(payload) => {
 | 
			
		||||
                    let mut buffer = BytesMut::with_capacity(Ethernet2Header::LEN + payload.len());
 | 
			
		||||
                    let header = Ethernet2Header {
 | 
			
		||||
                        source: context.key.local_mac.0,
 | 
			
		||||
                        destination: context.key.client_mac.0,
 | 
			
		||||
                        ether_type: match context.key.external_ip.addr {
 | 
			
		||||
                            IpAddress::Ipv4(_) => EtherType::IPV4,
 | 
			
		||||
                            IpAddress::Ipv6(_) => EtherType::IPV6,
 | 
			
		||||
                        },
 | 
			
		||||
                    };
 | 
			
		||||
                    buffer.extend_from_slice(&header.to_bytes());
 | 
			
		||||
                    buffer.extend_from_slice(&payload);
 | 
			
		||||
                    if let Err(error) = context.try_transmit(buffer) {
 | 
			
		||||
                        debug!("failed to transmit tcp packet: {}", error);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpDataSelect::TimePassed => {
 | 
			
		||||
                    iface.poll(Instant::now(), &mut device, &mut sockets);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpDataSelect::DoNothing => {
 | 
			
		||||
                    sleeper = Some(tokio::time::sleep(Duration::from_micros(100)));
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpDataSelect::Close => {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let _ = external_socket.shutdown().await;
 | 
			
		||||
        drop(external_socket);
 | 
			
		||||
 | 
			
		||||
        loop {
 | 
			
		||||
            let deadline = tokio::time::sleep(Duration::from_secs(TCP_DANGLE_TIMEOUT_SECS));
 | 
			
		||||
            tokio::pin!(deadline);
 | 
			
		||||
 | 
			
		||||
            let selection = select! {
 | 
			
		||||
                biased;
 | 
			
		||||
                x = ip_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                    ProxyTcpFinishSelect::TxIpPacket(data)
 | 
			
		||||
                } else {
 | 
			
		||||
                    ProxyTcpFinishSelect::Close
 | 
			
		||||
                },
 | 
			
		||||
                x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                    ProxyTcpFinishSelect::InternalRecv(data)
 | 
			
		||||
                } else {
 | 
			
		||||
                    ProxyTcpFinishSelect::Close
 | 
			
		||||
                },
 | 
			
		||||
                _ = deadline => ProxyTcpFinishSelect::Close,
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            match selection {
 | 
			
		||||
                ProxyTcpFinishSelect::InternalRecv(data) => {
 | 
			
		||||
                    let (_, payload) = Ethernet2Header::from_slice(&data)?;
 | 
			
		||||
                    device.rx = Some(payload.into());
 | 
			
		||||
                    iface.poll(Instant::now(), &mut device, &mut sockets);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpFinishSelect::TxIpPacket(payload) => {
 | 
			
		||||
                    let mut buffer = BytesMut::with_capacity(Ethernet2Header::LEN + payload.len());
 | 
			
		||||
                    let header = Ethernet2Header {
 | 
			
		||||
                        source: context.key.local_mac.0,
 | 
			
		||||
                        destination: context.key.client_mac.0,
 | 
			
		||||
                        ether_type: match context.key.external_ip.addr {
 | 
			
		||||
                            IpAddress::Ipv4(_) => EtherType::IPV4,
 | 
			
		||||
                            IpAddress::Ipv6(_) => EtherType::IPV6,
 | 
			
		||||
                        },
 | 
			
		||||
                    };
 | 
			
		||||
                    buffer.extend_from_slice(&header.to_bytes());
 | 
			
		||||
                    buffer.extend_from_slice(&payload);
 | 
			
		||||
                    if let Err(error) = context.try_transmit(buffer) {
 | 
			
		||||
                        debug!("failed to transmit tcp packet: {}", error);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                ProxyTcpFinishSelect::Close => {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        context.reclaim().await?;
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										142
									
								
								crates/kratanet/src/proxynat/udp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								crates/kratanet/src/proxynat/udp.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,142 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    net::{IpAddr, SocketAddr},
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use async_trait::async_trait;
 | 
			
		||||
use bytes::{BufMut, BytesMut};
 | 
			
		||||
use etherparse::{PacketBuilder, SlicedPacket, UdpSlice};
 | 
			
		||||
use log::{debug, warn};
 | 
			
		||||
use smoltcp::wire::IpAddress;
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncReadExt, AsyncWriteExt},
 | 
			
		||||
    select,
 | 
			
		||||
};
 | 
			
		||||
use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender};
 | 
			
		||||
use udp_stream::UdpStream;
 | 
			
		||||
 | 
			
		||||
use crate::nat::handler::{NatHandler, NatHandlerContext};
 | 
			
		||||
 | 
			
		||||
const UDP_TIMEOUT_SECS: u64 = 60;
 | 
			
		||||
 | 
			
		||||
pub struct ProxyUdpHandler {
 | 
			
		||||
    rx_sender: Sender<BytesMut>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[async_trait]
 | 
			
		||||
impl NatHandler for ProxyUdpHandler {
 | 
			
		||||
    async fn receive(&self, data: &[u8]) -> Result<bool> {
 | 
			
		||||
        if self.rx_sender.is_closed() {
 | 
			
		||||
            Ok(true)
 | 
			
		||||
        } else {
 | 
			
		||||
            self.rx_sender.try_send(data.into())?;
 | 
			
		||||
            Ok(true)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum ProxyUdpSelect {
 | 
			
		||||
    External(usize),
 | 
			
		||||
    Internal(BytesMut),
 | 
			
		||||
    Close,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ProxyUdpHandler {
 | 
			
		||||
    pub fn new(rx_sender: Sender<BytesMut>) -> Self {
 | 
			
		||||
        ProxyUdpHandler { rx_sender }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn spawn(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        rx_receiver: Receiver<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let external_addr = match context.key.external_ip.addr {
 | 
			
		||||
            IpAddress::Ipv4(addr) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V4(addr.0.into()), context.key.external_ip.port)
 | 
			
		||||
            }
 | 
			
		||||
            IpAddress::Ipv6(addr) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V6(addr.0.into()), context.key.external_ip.port)
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let socket = UdpStream::connect(external_addr).await?;
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            if let Err(error) = ProxyUdpHandler::process(context, socket, rx_receiver).await {
 | 
			
		||||
                warn!("processing of udp proxy failed: {}", error);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(
 | 
			
		||||
        context: NatHandlerContext,
 | 
			
		||||
        mut socket: UdpStream,
 | 
			
		||||
        mut rx_receiver: Receiver<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let mut external_buffer = vec![0u8; 2048];
 | 
			
		||||
 | 
			
		||||
        loop {
 | 
			
		||||
            let deadline = tokio::time::sleep(Duration::from_secs(UDP_TIMEOUT_SECS));
 | 
			
		||||
            let selection = select! {
 | 
			
		||||
                x = rx_receiver.recv() => if let Some(data) = x {
 | 
			
		||||
                    ProxyUdpSelect::Internal(data)
 | 
			
		||||
                } else {
 | 
			
		||||
                    ProxyUdpSelect::Close
 | 
			
		||||
                },
 | 
			
		||||
                x = socket.read(&mut external_buffer) => ProxyUdpSelect::External(x?),
 | 
			
		||||
                _ = deadline => ProxyUdpSelect::Close,
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            match selection {
 | 
			
		||||
                ProxyUdpSelect::External(size) => {
 | 
			
		||||
                    let data = &external_buffer[0..size];
 | 
			
		||||
                    let packet =
 | 
			
		||||
                        PacketBuilder::ethernet2(context.key.local_mac.0, context.key.client_mac.0);
 | 
			
		||||
                    let packet = match (context.key.external_ip.addr, context.key.client_ip.addr) {
 | 
			
		||||
                        (IpAddress::Ipv4(external_addr), IpAddress::Ipv4(client_addr)) => {
 | 
			
		||||
                            packet.ipv4(external_addr.0, client_addr.0, 20)
 | 
			
		||||
                        }
 | 
			
		||||
                        (IpAddress::Ipv6(external_addr), IpAddress::Ipv6(client_addr)) => {
 | 
			
		||||
                            packet.ipv6(external_addr.0, client_addr.0, 20)
 | 
			
		||||
                        }
 | 
			
		||||
                        _ => {
 | 
			
		||||
                            return Err(anyhow!("IP endpoint mismatch"));
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
                    let packet =
 | 
			
		||||
                        packet.udp(context.key.external_ip.port, context.key.client_ip.port);
 | 
			
		||||
                    let buffer = BytesMut::with_capacity(packet.size(data.len()));
 | 
			
		||||
                    let mut writer = buffer.writer();
 | 
			
		||||
                    packet.write(&mut writer, data)?;
 | 
			
		||||
                    let buffer = writer.into_inner();
 | 
			
		||||
                    if let Err(error) = context.try_transmit(buffer) {
 | 
			
		||||
                        debug!("failed to transmit udp packet: {}", error);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                ProxyUdpSelect::Internal(data) => {
 | 
			
		||||
                    let packet = SlicedPacket::from_ethernet(&data)?;
 | 
			
		||||
                    let Some(ref net) = packet.net else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let Some(ip) = net.ip_payload_ref() else {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    let udp = UdpSlice::from_slice(ip.payload)?;
 | 
			
		||||
                    socket.write_all(udp.payload()).await?;
 | 
			
		||||
                }
 | 
			
		||||
                ProxyUdpSelect::Close => {
 | 
			
		||||
                    drop(socket);
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        context.reclaim().await?;
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										299
									
								
								crates/kratanet/src/raw_socket.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										299
									
								
								crates/kratanet/src/raw_socket.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,299 @@
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use log::{debug, warn};
 | 
			
		||||
use std::io::ErrorKind;
 | 
			
		||||
use std::os::fd::{FromRawFd, IntoRawFd};
 | 
			
		||||
use std::os::unix::io::{AsRawFd, RawFd};
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use std::{io, mem};
 | 
			
		||||
use tokio::net::UdpSocket;
 | 
			
		||||
use tokio::select;
 | 
			
		||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
 | 
			
		||||
use tokio::task::JoinHandle;
 | 
			
		||||
 | 
			
		||||
const RAW_SOCKET_TRANSMIT_QUEUE_LEN: usize = 3000;
 | 
			
		||||
const RAW_SOCKET_RECEIVE_QUEUE_LEN: usize = 3000;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum RawSocketProtocol {
 | 
			
		||||
    Icmpv4,
 | 
			
		||||
    Icmpv6,
 | 
			
		||||
    Ethernet,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl RawSocketProtocol {
 | 
			
		||||
    pub fn to_socket_domain(&self) -> i32 {
 | 
			
		||||
        match self {
 | 
			
		||||
            RawSocketProtocol::Icmpv4 => libc::AF_INET,
 | 
			
		||||
            RawSocketProtocol::Icmpv6 => libc::AF_INET6,
 | 
			
		||||
            RawSocketProtocol::Ethernet => libc::AF_PACKET,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn to_socket_protocol(&self) -> u16 {
 | 
			
		||||
        match self {
 | 
			
		||||
            RawSocketProtocol::Icmpv4 => libc::IPPROTO_ICMP as u16,
 | 
			
		||||
            RawSocketProtocol::Icmpv6 => libc::IPPROTO_ICMPV6 as u16,
 | 
			
		||||
            RawSocketProtocol::Ethernet => (libc::ETH_P_ALL as u16).to_be(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn to_socket_type(&self) -> i32 {
 | 
			
		||||
        libc::SOCK_RAW
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const SIOCGIFINDEX: libc::c_ulong = 0x8933;
 | 
			
		||||
const SIOCGIFMTU: libc::c_ulong = 0x8921;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct RawSocketHandle {
 | 
			
		||||
    protocol: RawSocketProtocol,
 | 
			
		||||
    lower: libc::c_int,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsRawFd for RawSocketHandle {
 | 
			
		||||
    fn as_raw_fd(&self) -> RawFd {
 | 
			
		||||
        self.lower
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IntoRawFd for RawSocketHandle {
 | 
			
		||||
    fn into_raw_fd(self) -> RawFd {
 | 
			
		||||
        let fd = self.lower;
 | 
			
		||||
        mem::forget(self);
 | 
			
		||||
        fd
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl RawSocketHandle {
 | 
			
		||||
    pub fn new(protocol: RawSocketProtocol) -> io::Result<RawSocketHandle> {
 | 
			
		||||
        let lower = unsafe {
 | 
			
		||||
            let lower = libc::socket(
 | 
			
		||||
                protocol.to_socket_domain(),
 | 
			
		||||
                protocol.to_socket_type() | libc::SOCK_NONBLOCK,
 | 
			
		||||
                protocol.to_socket_protocol() as i32,
 | 
			
		||||
            );
 | 
			
		||||
            if lower == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
            lower
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        Ok(RawSocketHandle { protocol, lower })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
 | 
			
		||||
        let mut socket = RawSocketHandle::new(protocol)?;
 | 
			
		||||
        socket.bind_to_interface(interface)?;
 | 
			
		||||
        Ok(socket)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn bind_to_interface(&mut self, interface: &str) -> io::Result<()> {
 | 
			
		||||
        let mut ifreq = ifreq_for(interface);
 | 
			
		||||
        let sockaddr = libc::sockaddr_ll {
 | 
			
		||||
            sll_family: libc::AF_PACKET as u16,
 | 
			
		||||
            sll_protocol: self.protocol.to_socket_protocol(),
 | 
			
		||||
            sll_ifindex: ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFINDEX)?,
 | 
			
		||||
            sll_hatype: 1,
 | 
			
		||||
            sll_pkttype: 0,
 | 
			
		||||
            sll_halen: 6,
 | 
			
		||||
            sll_addr: [0; 8],
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        unsafe {
 | 
			
		||||
            let res = libc::bind(
 | 
			
		||||
                self.lower,
 | 
			
		||||
                &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr,
 | 
			
		||||
                mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
 | 
			
		||||
            );
 | 
			
		||||
            if res == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn mtu_of_interface(&mut self, interface: &str) -> io::Result<usize> {
 | 
			
		||||
        let mut ifreq = ifreq_for(interface);
 | 
			
		||||
        ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFMTU).map(|mtu| mtu as usize)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
 | 
			
		||||
        unsafe {
 | 
			
		||||
            let len = libc::recv(
 | 
			
		||||
                self.lower,
 | 
			
		||||
                buffer.as_mut_ptr() as *mut libc::c_void,
 | 
			
		||||
                buffer.len(),
 | 
			
		||||
                0,
 | 
			
		||||
            );
 | 
			
		||||
            if len == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
            Ok(len as usize)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn send(&self, buffer: &[u8]) -> io::Result<usize> {
 | 
			
		||||
        unsafe {
 | 
			
		||||
            let len = libc::send(
 | 
			
		||||
                self.lower,
 | 
			
		||||
                buffer.as_ptr() as *const libc::c_void,
 | 
			
		||||
                buffer.len(),
 | 
			
		||||
                0,
 | 
			
		||||
            );
 | 
			
		||||
            if len == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
            Ok(len as usize)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for RawSocketHandle {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        unsafe {
 | 
			
		||||
            libc::close(self.lower);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[repr(C)]
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct Ifreq {
 | 
			
		||||
    ifr_name: [libc::c_char; libc::IF_NAMESIZE],
 | 
			
		||||
    ifr_data: libc::c_int,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn ifreq_for(name: &str) -> Ifreq {
 | 
			
		||||
    let mut ifreq = Ifreq {
 | 
			
		||||
        ifr_name: [0; libc::IF_NAMESIZE],
 | 
			
		||||
        ifr_data: 0,
 | 
			
		||||
    };
 | 
			
		||||
    for (i, byte) in name.as_bytes().iter().enumerate() {
 | 
			
		||||
        ifreq.ifr_name[i] = *byte as libc::c_char
 | 
			
		||||
    }
 | 
			
		||||
    ifreq
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn ifreq_ioctl(
 | 
			
		||||
    lower: libc::c_int,
 | 
			
		||||
    ifreq: &mut Ifreq,
 | 
			
		||||
    cmd: libc::c_ulong,
 | 
			
		||||
) -> io::Result<libc::c_int> {
 | 
			
		||||
    unsafe {
 | 
			
		||||
        let res = libc::ioctl(lower, cmd as _, ifreq as *mut Ifreq);
 | 
			
		||||
        if res == -1 {
 | 
			
		||||
            return Err(io::Error::last_os_error());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(ifreq.ifr_data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct AsyncRawSocketChannel {
 | 
			
		||||
    pub sender: Sender<BytesMut>,
 | 
			
		||||
    pub receiver: Receiver<BytesMut>,
 | 
			
		||||
    _task: Arc<JoinHandle<()>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum AsyncRawSocketChannelSelect {
 | 
			
		||||
    TransmitPacket(Option<BytesMut>),
 | 
			
		||||
    Readable(()),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncRawSocketChannel {
 | 
			
		||||
    pub fn new(mtu: usize, socket: RawSocketHandle) -> Result<AsyncRawSocketChannel> {
 | 
			
		||||
        let (transmit_sender, transmit_receiver) = channel(RAW_SOCKET_TRANSMIT_QUEUE_LEN);
 | 
			
		||||
        let (receive_sender, receive_receiver) = channel(RAW_SOCKET_RECEIVE_QUEUE_LEN);
 | 
			
		||||
        let task = AsyncRawSocketChannel::launch(mtu, socket, transmit_receiver, receive_sender)?;
 | 
			
		||||
        Ok(AsyncRawSocketChannel {
 | 
			
		||||
            sender: transmit_sender,
 | 
			
		||||
            receiver: receive_receiver,
 | 
			
		||||
            _task: Arc::new(task),
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn launch(
 | 
			
		||||
        mtu: usize,
 | 
			
		||||
        socket: RawSocketHandle,
 | 
			
		||||
        transmit_receiver: Receiver<BytesMut>,
 | 
			
		||||
        receive_sender: Sender<BytesMut>,
 | 
			
		||||
    ) -> Result<JoinHandle<()>> {
 | 
			
		||||
        Ok(tokio::task::spawn(async move {
 | 
			
		||||
            if let Err(error) =
 | 
			
		||||
                AsyncRawSocketChannel::process(mtu, socket, transmit_receiver, receive_sender).await
 | 
			
		||||
            {
 | 
			
		||||
                warn!("failed to process raw socket: {}", error);
 | 
			
		||||
            }
 | 
			
		||||
        }))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(
 | 
			
		||||
        mtu: usize,
 | 
			
		||||
        socket: RawSocketHandle,
 | 
			
		||||
        mut transmit_receiver: Receiver<BytesMut>,
 | 
			
		||||
        receive_sender: Sender<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) };
 | 
			
		||||
        let socket = UdpSocket::from_std(socket)?;
 | 
			
		||||
 | 
			
		||||
        let mut buffer = vec![0; mtu];
 | 
			
		||||
        loop {
 | 
			
		||||
            let selection = select! {
 | 
			
		||||
                x = transmit_receiver.recv() => AsyncRawSocketChannelSelect::TransmitPacket(x),
 | 
			
		||||
                x = socket.readable() => AsyncRawSocketChannelSelect::Readable(x?),
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            match selection {
 | 
			
		||||
                AsyncRawSocketChannelSelect::Readable(_) => {
 | 
			
		||||
                    match socket.try_recv(&mut buffer) {
 | 
			
		||||
                        Ok(len) => {
 | 
			
		||||
                            if len == 0 {
 | 
			
		||||
                                continue;
 | 
			
		||||
                            }
 | 
			
		||||
                            let buffer = (&buffer[0..len]).into();
 | 
			
		||||
                            if let Err(error) = receive_sender.try_send(buffer) {
 | 
			
		||||
                                debug!(
 | 
			
		||||
                                    "failed to process received packet from raw socket: {}",
 | 
			
		||||
                                    error
 | 
			
		||||
                                );
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        Err(ref error) => {
 | 
			
		||||
                            if error.kind() == ErrorKind::WouldBlock {
 | 
			
		||||
                                continue;
 | 
			
		||||
                            }
 | 
			
		||||
                            return Err(anyhow!("failed to read from raw socket: {}", error));
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                AsyncRawSocketChannelSelect::TransmitPacket(Some(packet)) => {
 | 
			
		||||
                    match socket.try_send(&packet) {
 | 
			
		||||
                        Ok(_len) => {}
 | 
			
		||||
                        Err(ref error) => {
 | 
			
		||||
                            if error.kind() == ErrorKind::WouldBlock {
 | 
			
		||||
                                debug!("failed to transmit: would block");
 | 
			
		||||
                                continue;
 | 
			
		||||
                            }
 | 
			
		||||
                            return Err(anyhow!(
 | 
			
		||||
                                "failed to write {} bytes to raw socket: {}",
 | 
			
		||||
                                packet.len(),
 | 
			
		||||
                                error
 | 
			
		||||
                            ));
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                AsyncRawSocketChannelSelect::TransmitPacket(None) => {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										212
									
								
								crates/kratanet/src/vbridge.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								crates/kratanet/src/vbridge.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,212 @@
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use bytes::BytesMut;
 | 
			
		||||
use etherparse::{EtherType, Ethernet2Header, IpNumber, Ipv4Header, Ipv6Header, TcpHeader};
 | 
			
		||||
use log::{debug, trace, warn};
 | 
			
		||||
use smoltcp::wire::EthernetAddress;
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::{hash_map::Entry, HashMap},
 | 
			
		||||
    sync::Arc,
 | 
			
		||||
};
 | 
			
		||||
use tokio::sync::broadcast::{
 | 
			
		||||
    channel as broadcast_channel, Receiver as BroadcastReceiver, Sender as BroadcastSender,
 | 
			
		||||
};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    select,
 | 
			
		||||
    sync::{
 | 
			
		||||
        mpsc::{channel, Receiver, Sender},
 | 
			
		||||
        Mutex,
 | 
			
		||||
    },
 | 
			
		||||
    task::JoinHandle,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
const TO_BRIDGE_QUEUE_LEN: usize = 3000;
 | 
			
		||||
const FROM_BRIDGE_QUEUE_LEN: usize = 3000;
 | 
			
		||||
const BROADCAST_QUEUE_LEN: usize = 3000;
 | 
			
		||||
const MEMBER_LEAVE_QUEUE_LEN: usize = 30;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct BridgeMember {
 | 
			
		||||
    pub from_bridge_sender: Sender<BytesMut>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct BridgeJoinHandle {
 | 
			
		||||
    mac: EthernetAddress,
 | 
			
		||||
    pub to_bridge_sender: Sender<BytesMut>,
 | 
			
		||||
    pub from_bridge_receiver: Receiver<BytesMut>,
 | 
			
		||||
    pub from_broadcast_receiver: BroadcastReceiver<BytesMut>,
 | 
			
		||||
    member_leave_sender: Sender<EthernetAddress>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for BridgeJoinHandle {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        if let Err(error) = self.member_leave_sender.try_send(self.mac) {
 | 
			
		||||
            warn!(
 | 
			
		||||
                "virtual bridge member {} failed to leave: {}",
 | 
			
		||||
                self.mac, error
 | 
			
		||||
            );
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type VirtualBridgeMemberMap = Arc<Mutex<HashMap<EthernetAddress, BridgeMember>>>;
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct VirtualBridge {
 | 
			
		||||
    to_bridge_sender: Sender<BytesMut>,
 | 
			
		||||
    from_broadcast_sender: BroadcastSender<BytesMut>,
 | 
			
		||||
    member_leave_sender: Sender<EthernetAddress>,
 | 
			
		||||
    members: VirtualBridgeMemberMap,
 | 
			
		||||
    _task: Arc<JoinHandle<()>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
enum VirtualBridgeSelect {
 | 
			
		||||
    BroadcastSent(Option<BytesMut>),
 | 
			
		||||
    PacketReceived(Option<BytesMut>),
 | 
			
		||||
    MemberLeave(Option<EthernetAddress>),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl VirtualBridge {
 | 
			
		||||
    pub fn new() -> Result<VirtualBridge> {
 | 
			
		||||
        let (to_bridge_sender, to_bridge_receiver) = channel::<BytesMut>(TO_BRIDGE_QUEUE_LEN);
 | 
			
		||||
        let (member_leave_sender, member_leave_reciever) =
 | 
			
		||||
            channel::<EthernetAddress>(MEMBER_LEAVE_QUEUE_LEN);
 | 
			
		||||
        let (from_broadcast_sender, from_broadcast_receiver) =
 | 
			
		||||
            broadcast_channel(BROADCAST_QUEUE_LEN);
 | 
			
		||||
 | 
			
		||||
        let members = Arc::new(Mutex::new(HashMap::new()));
 | 
			
		||||
        let handle = {
 | 
			
		||||
            let members = members.clone();
 | 
			
		||||
            let broadcast_rx_sender = from_broadcast_sender.clone();
 | 
			
		||||
            tokio::task::spawn(async move {
 | 
			
		||||
                if let Err(error) = VirtualBridge::process(
 | 
			
		||||
                    members,
 | 
			
		||||
                    member_leave_reciever,
 | 
			
		||||
                    to_bridge_receiver,
 | 
			
		||||
                    broadcast_rx_sender,
 | 
			
		||||
                    from_broadcast_receiver,
 | 
			
		||||
                )
 | 
			
		||||
                .await
 | 
			
		||||
                {
 | 
			
		||||
                    warn!("virtual bridge processing task failed: {}", error);
 | 
			
		||||
                }
 | 
			
		||||
            })
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        Ok(VirtualBridge {
 | 
			
		||||
            to_bridge_sender,
 | 
			
		||||
            from_broadcast_sender,
 | 
			
		||||
            member_leave_sender,
 | 
			
		||||
            members,
 | 
			
		||||
            _task: Arc::new(handle),
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn join(&self, mac: EthernetAddress) -> Result<BridgeJoinHandle> {
 | 
			
		||||
        let (from_bridge_sender, from_bridge_receiver) = channel::<BytesMut>(FROM_BRIDGE_QUEUE_LEN);
 | 
			
		||||
        let member = BridgeMember { from_bridge_sender };
 | 
			
		||||
 | 
			
		||||
        match self.members.lock().await.entry(mac) {
 | 
			
		||||
            Entry::Occupied(_) => {
 | 
			
		||||
                return Err(anyhow!("virtual bridge member {} already exists", mac));
 | 
			
		||||
            }
 | 
			
		||||
            Entry::Vacant(entry) => {
 | 
			
		||||
                entry.insert(member);
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        debug!("virtual bridge member {} has joined", mac);
 | 
			
		||||
        Ok(BridgeJoinHandle {
 | 
			
		||||
            mac,
 | 
			
		||||
            member_leave_sender: self.member_leave_sender.clone(),
 | 
			
		||||
            from_bridge_receiver,
 | 
			
		||||
            from_broadcast_receiver: self.from_broadcast_sender.subscribe(),
 | 
			
		||||
            to_bridge_sender: self.to_bridge_sender.clone(),
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn process(
 | 
			
		||||
        members: VirtualBridgeMemberMap,
 | 
			
		||||
        mut member_leave_reciever: Receiver<EthernetAddress>,
 | 
			
		||||
        mut to_bridge_receiver: Receiver<BytesMut>,
 | 
			
		||||
        broadcast_rx_sender: BroadcastSender<BytesMut>,
 | 
			
		||||
        mut from_broadcast_receiver: BroadcastReceiver<BytesMut>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        loop {
 | 
			
		||||
            let selection = select! {
 | 
			
		||||
                biased;
 | 
			
		||||
                x = from_broadcast_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()),
 | 
			
		||||
                x = to_bridge_receiver.recv() => VirtualBridgeSelect::PacketReceived(x),
 | 
			
		||||
                x = member_leave_reciever.recv() => VirtualBridgeSelect::MemberLeave(x),
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            match selection {
 | 
			
		||||
                VirtualBridgeSelect::PacketReceived(Some(mut packet)) => {
 | 
			
		||||
                    let (header, payload) = match Ethernet2Header::from_slice(&packet) {
 | 
			
		||||
                        Ok(data) => data,
 | 
			
		||||
                        Err(error) => {
 | 
			
		||||
                            debug!("virtual bridge failed to parse ethernet header: {}", error);
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
                    // recalculate TCP checksums when routing packets.
 | 
			
		||||
                    // the xen network backend / frontend drivers for linux
 | 
			
		||||
                    // use checksum offloading but since we bypass some layers
 | 
			
		||||
                    // of the kernel we have to do it ourselves.
 | 
			
		||||
                    if header.ether_type == EtherType::IPV4 {
 | 
			
		||||
                        let (ipv4, payload) = Ipv4Header::from_slice(payload)?;
 | 
			
		||||
                        if ipv4.protocol == IpNumber::TCP {
 | 
			
		||||
                            let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
 | 
			
		||||
                            tcp.checksum = tcp.calc_checksum_ipv4(&ipv4, payload)?;
 | 
			
		||||
                            let tcp_header_offset = Ethernet2Header::LEN + ipv4.header_len();
 | 
			
		||||
                            let tcp_header_bytes = tcp.to_bytes();
 | 
			
		||||
                            for (i, b) in tcp_header_bytes.iter().enumerate() {
 | 
			
		||||
                                packet[tcp_header_offset + i] = *b;
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    } else if header.ether_type == EtherType::IPV6 {
 | 
			
		||||
                        let (ipv6, payload) = Ipv6Header::from_slice(payload)?;
 | 
			
		||||
                        if ipv6.next_header == IpNumber::TCP {
 | 
			
		||||
                            let (mut tcp, payload) = TcpHeader::from_slice(payload)?;
 | 
			
		||||
                            tcp.checksum = tcp.calc_checksum_ipv6(&ipv6, payload)?;
 | 
			
		||||
                            let tcp_header_offset = Ethernet2Header::LEN + ipv6.header_len();
 | 
			
		||||
                            let tcp_header_bytes = tcp.to_bytes();
 | 
			
		||||
                            for (i, b) in tcp_header_bytes.iter().enumerate() {
 | 
			
		||||
                                packet[tcp_header_offset + i] = *b;
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    let destination = EthernetAddress(header.destination);
 | 
			
		||||
                    if destination.is_multicast() {
 | 
			
		||||
                        broadcast_rx_sender.send(packet)?;
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
                    match members.lock().await.get(&destination) {
 | 
			
		||||
                        Some(member) => {
 | 
			
		||||
                            member.from_bridge_sender.try_send(packet)?;
 | 
			
		||||
                            trace!(
 | 
			
		||||
                                "sending bridged packet from {} to {}",
 | 
			
		||||
                                EthernetAddress(header.source),
 | 
			
		||||
                                EthernetAddress(header.destination)
 | 
			
		||||
                            );
 | 
			
		||||
                        }
 | 
			
		||||
                        None => {
 | 
			
		||||
                            trace!("no bridge member with address: {}", destination);
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                VirtualBridgeSelect::MemberLeave(Some(mac)) => {
 | 
			
		||||
                    if members.lock().await.remove(&mac).is_some() {
 | 
			
		||||
                        debug!("virtual bridge member {} has left", mac);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                VirtualBridgeSelect::PacketReceived(None) => break,
 | 
			
		||||
                VirtualBridgeSelect::MemberLeave(None) => {}
 | 
			
		||||
                VirtualBridgeSelect::BroadcastSent(_) => {}
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user