network: implement icmp nat support

This commit is contained in:
Alex Zenla
2024-02-10 21:13:47 +00:00
parent 4f0e505e2b
commit efe425b346
10 changed files with 470 additions and 105 deletions

View File

@ -53,7 +53,6 @@ udp-stream = "0.0.11"
smoltcp = "0.11.0"
etherparse = "0.14.2"
async-trait = "0.1.77"
async-ping = "0.2.1"
[workspace.dependencies.uuid]
version = "1.6.1"
@ -74,7 +73,3 @@ features = ["macros", "rt", "rt-multi-thread"]
[workspace.dependencies.serde]
version = "1.0.196"
features = ["derive"]
[workspace.dependencies.icmp-client]
version = "0.2"
features = ["impl_tokio"]

View File

@ -18,8 +18,6 @@ udp-stream = { workspace = true }
smoltcp = { workspace = true }
etherparse = { workspace = true }
async-trait = { workspace = true }
async-ping = { workspace = true }
icmp-client = { workspace = true }
[dependencies.advmac]
path = "../libs/advmac"
@ -30,3 +28,7 @@ path = "src/lib.rs"
[[bin]]
name = "hyphanet"
path = "bin/network.rs"
[[example]]
name = "ping"
path = "examples/ping.rs"

21
network/examples/ping.rs Normal file
View File

@ -0,0 +1,21 @@
use std::{net::Ipv6Addr, str::FromStr, time::Duration};
use anyhow::Result;
use hyphanet::icmp::{IcmpClient, IcmpProtocol};
#[tokio::main]
async fn main() -> Result<()> {
let client = IcmpClient::new(IcmpProtocol::Icmp6)?;
let payload: [u8; 4] = [12u8, 14u8, 16u8, 32u8];
let result = client
.ping6(
Ipv6Addr::from_str("2606:4700:4700::1111")?,
0,
1,
&payload,
Duration::from_secs(10),
)
.await?;
println!("reply: {:?}", result);
Ok(())
}

View File

@ -1,7 +1,7 @@
use crate::chandev::ChannelDevice;
use crate::nat::NatRouter;
use crate::proxynat::ProxyNatHandlerFactory;
use crate::raw_socket::AsyncRawSocket;
use crate::raw_socket::{AsyncRawSocket, RawSocketProtocol};
use advmac::MacAddr6;
use anyhow::{anyhow, Result};
use futures::TryStreamExt;
@ -28,6 +28,7 @@ enum NetworkStackSelect<'a> {
}
struct NetworkStack<'a> {
mtu: usize,
tx: Receiver<Vec<u8>>,
kdev: AsyncRawSocket,
udev: ChannelDevice,
@ -101,7 +102,7 @@ impl NetworkBackend {
pub async fn run(&self) -> Result<()> {
let mut stack = self.create_network_stack()?;
let mut buffer = vec![0u8; 1500];
let mut buffer = vec![0u8; stack.mtu];
loop {
stack.poll(&mut buffer).await?;
}
@ -112,9 +113,11 @@ impl NetworkBackend {
let address = IpCidr::from_str(&self.network)
.map_err(|_| anyhow!("failed to parse cidr: {}", self.network))?;
let addresses: Vec<IpCidr> = vec![address];
let kdev = AsyncRawSocket::bind(&self.interface)?;
let mut kdev =
AsyncRawSocket::bound_to_interface(&self.interface, RawSocketProtocol::Ethernet)?;
let mtu = kdev.mtu_of_interface(&self.interface)?;
let (tx_sender, tx_receiver) = channel::<Vec<u8>>(4);
let mut udev = ChannelDevice::new(1500, tx_sender.clone());
let mut udev = ChannelDevice::new(mtu, tx_sender.clone());
let mac = MacAddr6::random();
let mac = smoltcp::wire::EthernetAddress(mac.to_array());
let nat = NatRouter::new(proxy, mac, addresses.clone(), tx_sender.clone());
@ -128,6 +131,7 @@ impl NetworkBackend {
});
let sockets = SocketSet::new(vec![]);
Ok(NetworkStack {
mtu,
tx: tx_receiver,
kdev,
udev,

245
network/src/icmp.rs Normal file
View File

@ -0,0 +1,245 @@
use crate::raw_socket::{RawSocketHandle, RawSocketProtocol};
use anyhow::{anyhow, Result};
use etherparse::{
IcmpEchoHeader, Icmpv4Header, Icmpv4Slice, Icmpv4Type, Icmpv6Header, Icmpv6Slice, Icmpv6Type,
IpNumber, NetSlice, SlicedPacket,
};
use log::warn;
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
os::fd::{FromRawFd, IntoRawFd},
sync::Arc,
time::Duration,
};
use tokio::{
net::UdpSocket,
sync::{oneshot, Mutex},
task::JoinHandle,
time::timeout,
};
#[derive(Debug)]
pub enum IcmpProtocol {
Icmp4,
Icmp6,
}
impl IcmpProtocol {
pub fn to_socket_protocol(&self) -> RawSocketProtocol {
match self {
IcmpProtocol::Icmp4 => RawSocketProtocol::Icmpv4,
IcmpProtocol::Icmp6 => RawSocketProtocol::Icmpv6,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct IcmpHandlerToken(IpAddr, Option<u16>, u16);
#[derive(Debug)]
pub enum IcmpReply {
Icmp4 {
header: Icmpv4Header,
echo: IcmpEchoHeader,
payload: Vec<u8>,
},
Icmp6 {
header: Icmpv6Header,
echo: IcmpEchoHeader,
payload: Vec<u8>,
},
}
type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>;
pub struct IcmpClient {
socket: Arc<UdpSocket>,
handlers: IcmpHandlerMap,
task: Arc<JoinHandle<Result<()>>>,
}
impl IcmpClient {
pub fn new(protocol: IcmpProtocol) -> Result<IcmpClient> {
let handle = RawSocketHandle::new(protocol.to_socket_protocol())?;
let socket = unsafe { std::net::UdpSocket::from_raw_fd(handle.into_raw_fd()) };
let socket: Arc<UdpSocket> = Arc::new(socket.try_into()?);
let handlers = Arc::new(Mutex::new(HashMap::new()));
let task = Arc::new(tokio::task::spawn(IcmpClient::process(
protocol,
socket.clone(),
handlers.clone(),
)));
Ok(IcmpClient {
socket,
handlers,
task,
})
}
async fn process(
protocol: IcmpProtocol,
socket: Arc<UdpSocket>,
handlers: IcmpHandlerMap,
) -> Result<()> {
let mut buffer = vec![0u8; 2048];
loop {
let (size, addr) = socket.recv_from(&mut buffer).await?;
let packet = &buffer[0..size];
let (token, reply) = match protocol {
IcmpProtocol::Icmp4 => {
let sliced = match SlicedPacket::from_ip(packet) {
Ok(sliced) => sliced,
Err(error) => {
warn!("received icmp packet but failed to parse it: {}", error);
continue;
}
};
let Some(NetSlice::Ipv4(ipv4)) = sliced.net else {
continue;
};
if ipv4.header().protocol() != IpNumber::ICMP {
continue;
}
let Ok(icmpv4) = Icmpv4Slice::from_slice(ipv4.payload().payload) else {
continue;
};
let Icmpv4Type::EchoReply(echo) = icmpv4.header().icmp_type else {
continue;
};
let token = IcmpHandlerToken(
IpAddr::V4(ipv4.header().source_addr()),
Some(echo.id),
echo.seq,
);
let reply = IcmpReply::Icmp4 {
header: icmpv4.header(),
echo,
payload: icmpv4.payload().to_vec(),
};
(token, reply)
}
IcmpProtocol::Icmp6 => {
let Ok(icmpv6) = Icmpv6Slice::from_slice(packet) else {
continue;
};
let Icmpv6Type::EchoReply(echo) = icmpv6.header().icmp_type else {
continue;
};
let SocketAddr::V6(addr) = addr else {
continue;
};
let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq);
let reply = IcmpReply::Icmp6 {
header: icmpv6.header(),
echo,
payload: icmpv6.payload().to_vec(),
};
(token, reply)
}
};
if let Some(sender) = handlers.lock().await.remove(&token) {
let _ = sender.send(reply);
}
}
}
async fn add_handler(&self, token: IcmpHandlerToken) -> Result<oneshot::Receiver<IcmpReply>> {
let (tx, rx) = oneshot::channel();
if self
.handlers
.lock()
.await
.insert(token.clone(), tx)
.is_some()
{
return Err(anyhow!("duplicate icmp request: {:?}", token));
}
Ok(rx)
}
async fn remove_handler(&self, token: IcmpHandlerToken) -> Result<()> {
self.handlers.lock().await.remove(&token);
Ok(())
}
pub async fn ping4(
&self,
addr: Ipv4Addr,
id: u16,
seq: u16,
payload: &[u8],
deadline: Duration,
) -> Result<Option<IcmpReply>> {
let token = IcmpHandlerToken(IpAddr::V4(addr), Some(id), seq);
let rx = self.add_handler(token.clone()).await?;
let echo = IcmpEchoHeader { id, seq };
let header = Icmpv4Header::new(Icmpv4Type::EchoRequest(echo));
let mut buffer: Vec<u8> = Vec::new();
header.write(&mut buffer)?;
buffer.extend_from_slice(payload);
self.socket
.send_to(&buffer, SocketAddr::V4(SocketAddrV4::new(addr, 0)))
.await?;
let result = timeout(deadline, rx).await;
self.remove_handler(token).await?;
let reply = match result {
Ok(Ok(packet)) => Some(packet),
Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
Err(_) => None,
};
Ok(reply)
}
pub async fn ping6(
&self,
addr: Ipv6Addr,
id: u16,
seq: u16,
payload: &[u8],
deadline: Duration,
) -> Result<Option<IcmpReply>> {
let token = IcmpHandlerToken(IpAddr::V6(addr), Some(id), seq);
let rx = self.add_handler(token.clone()).await?;
let echo = IcmpEchoHeader { id, seq };
let header = Icmpv6Header::new(Icmpv6Type::EchoRequest(echo));
let mut buffer: Vec<u8> = Vec::new();
header.write(&mut buffer)?;
buffer.extend_from_slice(payload);
self.socket
.send_to(&buffer, SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)))
.await?;
let result = timeout(deadline, rx).await;
self.remove_handler(token).await?;
let reply = match result {
Ok(Ok(packet)) => Some(packet),
Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
Err(_) => None,
};
Ok(reply)
}
}
impl Drop for IcmpClient {
fn drop(&mut self) {
if Arc::strong_count(&self.task) <= 1 {
self.task.abort();
}
}
}

View File

@ -8,11 +8,12 @@ use tokio::time::sleep;
use crate::backend::NetworkBackend;
mod backend;
mod chandev;
mod nat;
mod proxynat;
mod raw_socket;
pub mod backend;
pub mod chandev;
pub mod icmp;
pub mod nat;
pub mod proxynat;
pub mod raw_socket;
pub struct NetworkService {
pub network: String,

View File

@ -69,6 +69,12 @@ pub struct NatTable {
inner: HashMap<NatKey, Box<dyn NatHandler>>,
}
impl Default for NatTable {
fn default() -> Self {
Self::new()
}
}
impl NatTable {
pub fn new() -> Self {
Self {

View File

@ -1,13 +1,11 @@
use std::time::Duration;
use std::{net::IpAddr, time::Duration};
use anyhow::{anyhow, Result};
use async_ping::{
icmp_client::Config,
icmp_packet::{Icmp, Icmpv4},
PingClient,
};
use async_trait::async_trait;
use etherparse::{Icmpv4Header, Icmpv4Type, IpNumber, PacketBuilder, SlicedPacket};
use etherparse::{
Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type, IpNumber, NetSlice, PacketBuilder,
SlicedPacket,
};
use log::{debug, warn};
use smoltcp::wire::IpAddress;
use tokio::{
@ -15,7 +13,10 @@ use tokio::{
sync::mpsc::{Receiver, Sender},
};
use crate::nat::{NatHandler, NatKey};
use crate::{
icmp::{IcmpClient, IcmpProtocol, IcmpReply},
nat::{NatHandler, NatKey},
};
const ICMP_PING_TIMEOUT_SECS: u64 = 20;
const ICMP_TIMEOUT_SECS: u64 = 30;
@ -49,15 +50,7 @@ impl ProxyIcmpHandler {
tx_sender: Sender<Vec<u8>>,
reclaim_sender: Sender<NatKey>,
) -> Result<()> {
let client = PingClient::<icmp_client::impl_tokio::Client>::new(Some(Config::new()), None)?;
{
let client = client.clone();
tokio::spawn(async move {
client.handle_v4_recv_from().await;
});
}
let client = IcmpClient::new(IcmpProtocol::Icmp4)?;
let key = self.key;
tokio::spawn(async move {
if let Err(error) =
@ -70,7 +63,7 @@ impl ProxyIcmpHandler {
}
async fn process(
client: PingClient<icmp_client::impl_tokio::Client>,
client: IcmpClient,
key: NatKey,
mut rx_receiver: Receiver<Vec<u8>>,
tx_sender: Sender<Vec<u8>>,
@ -94,28 +87,36 @@ impl ProxyIcmpHandler {
continue;
};
let Some(ip) = net.ip_payload_ref() else {
continue;
};
if ip.ip_number != IpNumber::ICMP {
match net {
NetSlice::Ipv4(ipv4) => {
if ipv4.header().protocol() != IpNumber::ICMP {
continue;
}
let (header, payload) = Icmpv4Header::from_slice(ip.payload)?;
let (header, payload) =
Icmpv4Header::from_slice(ipv4.payload().payload)?;
if let Icmpv4Type::EchoRequest(echo) = header.icmp_type {
let result = client
.ping(
key.external_ip.addr.into(),
Some(echo.id),
Some(echo.seq),
let IpAddr::V4(external_ipv4) = key.external_ip.addr.into() else {
continue;
};
let Some(IcmpReply::Icmp4 {
header: _,
echo,
payload,
}) = client
.ping4(
external_ipv4,
echo.id,
echo.seq,
payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await;
match result {
Ok((icmp, _)) => match icmp {
Icmp::V4(Icmpv4::EchoReply(reply)) => {
.await?
else {
continue;
};
let packet =
PacketBuilder::ethernet2(key.local_mac.0, key.client_mac.0);
let packet = match (key.external_ip.addr, key.client_ip.addr) {
@ -123,6 +124,51 @@ impl ProxyIcmpHandler {
IpAddress::Ipv4(external_addr),
IpAddress::Ipv4(client_addr),
) => packet.ipv4(external_addr.0, client_addr.0, 20),
_ => {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet = packet.icmpv4_echo_reply(echo.id, echo.seq);
let mut buffer: Vec<u8> = Vec::new();
packet.write(&mut buffer, &payload)?;
if let Err(error) = tx_sender.try_send(buffer) {
debug!("failed to transmit icmp packet: {}", error);
}
}
}
NetSlice::Ipv6(ipv6) => {
if ipv6.header().next_header() != IpNumber::ICMP {
continue;
}
let (header, payload) =
Icmpv6Header::from_slice(ipv6.payload().payload)?;
if let Icmpv6Type::EchoRequest(echo) = header.icmp_type {
let IpAddr::V6(external_ipv6) = key.external_ip.addr.into() else {
continue;
};
let Some(IcmpReply::Icmp6 {
header: _,
echo,
payload,
}) = client
.ping6(
external_ipv6,
echo.id,
echo.seq,
payload,
Duration::from_secs(ICMP_PING_TIMEOUT_SECS),
)
.await?
else {
continue;
};
let packet =
PacketBuilder::ethernet2(key.local_mac.0, key.client_mac.0);
let packet = match (key.external_ip.addr, key.client_ip.addr) {
(
IpAddress::Ipv6(external_addr),
IpAddress::Ipv6(client_addr),
@ -131,25 +177,13 @@ impl ProxyIcmpHandler {
return Err(anyhow!("IP endpoint mismatch"));
}
};
let packet = packet.icmpv4_echo_reply(
reply.identifier.0,
reply.sequence_number.0,
);
let packet = packet.icmpv6_echo_reply(echo.id, echo.seq);
let mut buffer: Vec<u8> = Vec::new();
packet.write(&mut buffer, &reply.payload)?;
packet.write(&mut buffer, &payload)?;
if let Err(error) = tx_sender.try_send(buffer) {
debug!("failed to transmit icmp packet: {}", error);
}
}
Icmp::V4(Icmpv4::Other(_type, _code, _payload)) => {}
_ => {}
},
Err(error) => {
debug!("proxy for icmp failed to emulate ICMP ping: {}", error);
}
}
}
}

View File

@ -16,6 +16,12 @@ mod udp;
pub struct ProxyNatHandlerFactory {}
impl Default for ProxyNatHandlerFactory {
fn default() -> Self {
Self::new()
}
}
impl ProxyNatHandlerFactory {
pub fn new() -> Self {
Self {}

View File

@ -1,5 +1,6 @@
use anyhow::Result;
use futures::ready;
use std::os::fd::IntoRawFd;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll};
@ -7,14 +8,42 @@ use std::{io, mem};
use tokio::io::unix::AsyncFd;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Debug)]
pub enum RawSocketProtocol {
Icmpv4,
Icmpv6,
Ethernet,
}
impl RawSocketProtocol {
pub fn to_socket_domain(&self) -> i32 {
match self {
RawSocketProtocol::Icmpv4 => libc::AF_INET,
RawSocketProtocol::Icmpv6 => libc::AF_INET6,
RawSocketProtocol::Ethernet => libc::AF_PACKET,
}
}
pub fn to_socket_protocol(&self) -> u16 {
match self {
RawSocketProtocol::Icmpv4 => libc::IPPROTO_ICMP as u16,
RawSocketProtocol::Icmpv6 => libc::IPPROTO_ICMPV6 as u16,
RawSocketProtocol::Ethernet => (libc::ETH_P_ALL as u16).to_be(),
}
}
pub fn to_socket_type(&self) -> i32 {
libc::SOCK_RAW
}
}
const SIOCGIFINDEX: libc::c_ulong = 0x8933;
const SIOCGIFMTU: libc::c_ulong = 0x8921;
#[derive(Debug)]
pub struct RawSocketHandle {
pub mtu: usize,
protocol: libc::c_short,
protocol: RawSocketProtocol,
lower: libc::c_int,
ifreq: Ifreq,
}
impl AsRawFd for RawSocketHandle {
@ -23,14 +52,21 @@ impl AsRawFd for RawSocketHandle {
}
}
impl IntoRawFd for RawSocketHandle {
fn into_raw_fd(self) -> RawFd {
let fd = self.lower;
mem::forget(self);
fd
}
}
impl RawSocketHandle {
pub fn new(interface: &str) -> io::Result<RawSocketHandle> {
let protocol: libc::c_short = 0x0003;
pub fn new(protocol: RawSocketProtocol) -> io::Result<RawSocketHandle> {
let lower = unsafe {
let lower = libc::socket(
libc::AF_PACKET,
libc::SOCK_RAW | libc::SOCK_NONBLOCK,
protocol.to_be() as i32,
protocol.to_socket_domain(),
protocol.to_socket_type() | libc::SOCK_NONBLOCK,
protocol.to_socket_protocol() as i32,
);
if lower == -1 {
return Err(io::Error::last_os_error());
@ -38,25 +74,21 @@ impl RawSocketHandle {
lower
};
Ok(RawSocketHandle {
mtu: 1500,
protocol,
lower,
ifreq: ifreq_for(interface),
})
Ok(RawSocketHandle { protocol, lower })
}
pub fn bind(interface: &str) -> Result<Self> {
let mut socket = RawSocketHandle::new(interface)?;
socket.bind_interface()?;
pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
let mut socket = RawSocketHandle::new(protocol)?;
socket.bind_to_interface(interface)?;
Ok(socket)
}
pub fn bind_interface(&mut self) -> io::Result<()> {
pub fn bind_to_interface(&mut self, interface: &str) -> io::Result<()> {
let mut ifreq = ifreq_for(interface);
let sockaddr = libc::sockaddr_ll {
sll_family: libc::AF_PACKET as u16,
sll_protocol: self.protocol.to_be() as u16,
sll_ifindex: ifreq_ioctl(self.lower, &mut self.ifreq, SIOCGIFINDEX)?,
sll_protocol: self.protocol.to_socket_protocol(),
sll_ifindex: ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFINDEX)?,
sll_hatype: 1,
sll_pkttype: 0,
sll_halen: 6,
@ -77,6 +109,11 @@ impl RawSocketHandle {
Ok(())
}
pub fn mtu_of_interface(&mut self, interface: &str) -> io::Result<usize> {
let mut ifreq = ifreq_for(interface);
ifreq_ioctl(self.lower, &mut ifreq, SIOCGIFMTU).map(|mtu| mtu as usize)
}
pub fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
unsafe {
let len = libc::recv(
@ -120,7 +157,7 @@ impl Drop for RawSocketHandle {
#[derive(Debug)]
struct Ifreq {
ifr_name: [libc::c_char; libc::IF_NAMESIZE],
ifr_data: libc::c_int, /* ifr_ifindex or ifr_mtu */
ifr_data: libc::c_int,
}
fn ifreq_for(name: &str) -> Ifreq {
@ -160,10 +197,24 @@ impl AsyncRawSocket {
})
}
pub fn bind(interface: &str) -> Result<Self> {
let socket = RawSocketHandle::bind(interface)?;
pub fn bound_to_interface(interface: &str, protocol: RawSocketProtocol) -> Result<Self> {
let socket = RawSocketHandle::bound_to_interface(interface, protocol)?;
AsyncRawSocket::new(socket)
}
pub fn mtu_of_interface(&mut self, interface: &str) -> Result<usize> {
Ok(self.inner.get_mut().mtu_of_interface(interface)?)
}
}
impl TryFrom<RawSocketHandle> for AsyncRawSocket {
type Error = anyhow::Error;
fn try_from(value: RawSocketHandle) -> Result<Self, Self::Error> {
Ok(Self {
inner: AsyncFd::new(value)?,
})
}
}
impl AsyncRead for AsyncRawSocket {