hypha: work in progress implementation of outbound internet access

This commit is contained in:
Alex Zenla
2024-02-06 14:35:55 +00:00
parent 44d3799dd3
commit cfe8887c6b
18 changed files with 2102 additions and 66 deletions

30
libs/ipstack/src/error.rs Normal file
View File

@ -0,0 +1,30 @@
#[allow(dead_code)]
#[derive(thiserror::Error, Debug)]
pub enum IpStackError {
#[error("The transport protocol is not supported")]
UnsupportedTransportProtocol,
#[error("The packet is invalid")]
InvalidPacket,
#[error("Write error: {0}")]
PacketWriteError(etherparse::WriteError),
#[error("Invalid Tcp packet")]
InvalidTcpPacket,
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("Accept Error")]
AcceptError,
#[error("Send Error {0}")]
SendError(#[from] tokio::sync::mpsc::error::SendError<crate::stream::IpStackStream>),
}
impl From<IpStackError> for std::io::Error {
fn from(e: IpStackError) -> Self {
match e {
IpStackError::IoError(e) => e,
_ => std::io::Error::new(std::io::ErrorKind::Other, e),
}
}
}
pub type Result<T, E = IpStackError> = std::result::Result<T, E>;

194
libs/ipstack/src/lib.rs Normal file
View File

@ -0,0 +1,194 @@
pub use error::{IpStackError, Result};
use packet::{NetworkPacket, NetworkTuple};
use std::{
collections::{
hash_map::Entry::{Occupied, Vacant},
HashMap,
},
time::Duration,
};
use stream::IpStackStream;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
select,
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
};
#[cfg(feature = "log")]
use tracing::{error, trace};
use crate::{
packet::IpStackPacketProtocol,
stream::{IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport},
};
mod error;
mod packet;
pub mod stream;
const DROP_TTL: u8 = 0;
#[cfg(unix)]
const TTL: u8 = 64;
#[cfg(windows)]
const TTL: u8 = 128;
#[cfg(unix)]
const TUN_FLAGS: [u8; 2] = [0x00, 0x00];
#[cfg(any(target_os = "linux", target_os = "android"))]
const TUN_PROTO_IP6: [u8; 2] = [0x86, 0xdd];
#[cfg(any(target_os = "linux", target_os = "android"))]
const TUN_PROTO_IP4: [u8; 2] = [0x08, 0x00];
#[cfg(any(target_os = "macos", target_os = "ios"))]
const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x0A];
#[cfg(any(target_os = "macos", target_os = "ios"))]
const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02];
pub struct IpStackConfig {
pub mtu: u16,
pub packet_information: bool,
pub tcp_timeout: Duration,
pub udp_timeout: Duration,
}
impl Default for IpStackConfig {
fn default() -> Self {
IpStackConfig {
mtu: u16::MAX,
packet_information: false,
tcp_timeout: Duration::from_secs(60),
udp_timeout: Duration::from_secs(30),
}
}
}
impl IpStackConfig {
pub fn tcp_timeout(&mut self, timeout: Duration) {
self.tcp_timeout = timeout;
}
pub fn udp_timeout(&mut self, timeout: Duration) {
self.udp_timeout = timeout;
}
pub fn mtu(&mut self, mtu: u16) {
self.mtu = mtu;
}
pub fn packet_information(&mut self, packet_information: bool) {
self.packet_information = packet_information;
}
}
pub struct IpStack {
accept_receiver: UnboundedReceiver<IpStackStream>,
}
impl IpStack {
pub fn new<D>(config: IpStackConfig, mut device: D) -> IpStack
where
D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
{
let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();
tokio::spawn(async move {
let mut streams: HashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = HashMap::new();
let mut buffer = [0u8; u16::MAX as usize];
let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
loop {
// dbg!(streams.len());
select! {
Ok(n) = device.read(&mut buffer) => {
let offset = if config.packet_information && cfg!(unix) {4} else {0};
let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else {
accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?;
continue;
};
if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
accept_sender.send(IpStackStream::UnknownTransport(IpStackUnknownTransport::new(packet.src_addr().ip(),packet.dst_addr().ip(),packet.payload,&packet.ip,config.mtu,pkt_sender.clone())))?;
continue;
}
match streams.entry(packet.network_tuple()){
Occupied(entry) =>{
// let t = packet.transport_protocol();
if let Err(_x) = entry.get().send(packet){
#[cfg(feature = "log")]
trace!("{}", _x);
// match t{
// IpStackPacketProtocol::Tcp(_t) => {
// // dbg!(t.flags());
// }
// IpStackPacketProtocol::Udp => {
// // dbg!("udp");
// }
// IpStackPacketProtocol::Unknown => {
// // dbg!("unknown");
// }
// }
}
}
Vacant(entry) => {
match packet.transport_protocol(){
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{
Ok(stream) => {
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Tcp(stream))?;
}
Err(_e) => {
#[cfg(feature = "log")]
error!("{}", _e);
}
}
}
IpStackPacketProtocol::Udp => {
let stream = IpStackUdpStream::new(packet.src_addr(),packet.dst_addr(),packet.payload, pkt_sender.clone(),config.mtu,config.udp_timeout);
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Udp(stream))?;
}
IpStackPacketProtocol::Unknown => {
unreachable!()
}
}
}
}
}
Some(packet) = pkt_receiver.recv() => {
if packet.ttl() == 0{
streams.remove(&packet.reverse_network_tuple());
continue;
}
#[allow(unused_mut)]
let Ok(mut packet_byte) = packet.to_bytes() else{
#[cfg(feature = "log")]
trace!("to_bytes error");
continue;
};
#[cfg(unix)]
if config.packet_information {
if packet.src_addr().is_ipv4(){
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
} else{
packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
}
}
device.write_all(&packet_byte).await?;
// device.flush().await.unwrap();
}
}
}
#[allow(unreachable_code)]
Ok::<(), IpStackError>(())
});
IpStack { accept_receiver }
}
pub async fn accept(&mut self) -> Result<IpStackStream, IpStackError> {
if let Some(s) = self.accept_receiver.recv().await {
Ok(s)
} else {
Err(IpStackError::AcceptError)
}
}
}

217
libs/ipstack/src/packet.rs Normal file
View File

@ -0,0 +1,217 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use etherparse::{Ethernet2Header, IpHeader, PacketHeaders, TcpHeader, UdpHeader, WriteError};
use tracing::debug;
use crate::error::IpStackError;
#[derive(Eq, Hash, PartialEq, Debug)]
pub struct NetworkTuple {
pub src: SocketAddr,
pub dst: SocketAddr,
pub tcp: bool,
}
pub mod tcp_flags {
pub const CWR: u8 = 0b10000000;
pub const ECE: u8 = 0b01000000;
pub const URG: u8 = 0b00100000;
pub const ACK: u8 = 0b00010000;
pub const PSH: u8 = 0b00001000;
pub const RST: u8 = 0b00000100;
pub const SYN: u8 = 0b00000010;
pub const FIN: u8 = 0b00000001;
}
pub(crate) enum IpStackPacketProtocol {
Tcp(TcpPacket),
Unknown,
Udp,
}
pub(crate) enum TransportHeader {
Tcp(TcpHeader),
Udp(UdpHeader),
Unknown,
}
pub struct NetworkPacket {
pub(crate) ip: IpHeader,
pub(crate) transport: TransportHeader,
pub(crate) payload: Vec<u8>,
}
impl NetworkPacket {
pub fn parse(buf: &[u8]) -> Result<Self, IpStackError> {
debug!("read: {:?}", buf);
let p = PacketHeaders::from_ethernet_slice(buf).map_err(|_| IpStackError::InvalidPacket)?;
let ip = p.ip.ok_or(IpStackError::InvalidPacket)?;
let transport = match p.transport {
Some(etherparse::TransportHeader::Tcp(h)) => TransportHeader::Tcp(h),
Some(etherparse::TransportHeader::Udp(u)) => TransportHeader::Udp(u),
_ => TransportHeader::Unknown,
};
let payload = if let TransportHeader::Unknown = transport {
buf[ip.header_len()..].to_vec()
} else {
p.payload.to_vec()
};
Ok(NetworkPacket {
ip,
transport,
payload,
})
}
pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol {
match self.transport {
TransportHeader::Udp(_) => IpStackPacketProtocol::Udp,
TransportHeader::Tcp(ref h) => IpStackPacketProtocol::Tcp(h.into()),
_ => IpStackPacketProtocol::Unknown,
}
}
pub fn src_addr(&self) -> SocketAddr {
let port = match &self.transport {
TransportHeader::Udp(udp) => udp.source_port,
TransportHeader::Tcp(tcp) => tcp.source_port,
_ => 0,
};
match &self.ip {
IpHeader::Version4(ip, _) => {
SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.source)), port)
}
IpHeader::Version6(ip, _) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.source)), port)
}
}
}
pub fn dst_addr(&self) -> SocketAddr {
let port = match &self.transport {
TransportHeader::Udp(udp) => udp.destination_port,
TransportHeader::Tcp(tcp) => tcp.destination_port,
_ => 0,
};
match &self.ip {
IpHeader::Version4(ip, _) => {
SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.destination)), port)
}
IpHeader::Version6(ip, _) => {
SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.destination)), port)
}
}
}
pub fn network_tuple(&self) -> NetworkTuple {
NetworkTuple {
src: self.src_addr(),
dst: self.dst_addr(),
tcp: matches!(self.transport, TransportHeader::Tcp(_)),
}
}
pub fn reverse_network_tuple(&self) -> NetworkTuple {
NetworkTuple {
src: self.dst_addr(),
dst: self.src_addr(),
tcp: matches!(self.transport, TransportHeader::Tcp(_)),
}
}
pub fn to_bytes(&self) -> Result<Vec<u8>, IpStackError> {
let mut buf = Vec::new();
let header = Ethernet2Header {
source: [255; 6],
destination: [255; 6],
ether_type: 0x0800,
};
header.write(&mut buf).map_err(IpStackError::IoError)?;
self.ip
.write(&mut buf)
.map_err(IpStackError::PacketWriteError)?;
match self.transport {
TransportHeader::Tcp(ref h) => h
.write(&mut buf)
.map_err(WriteError::from)
.map_err(IpStackError::PacketWriteError)?,
TransportHeader::Udp(ref h) => {
h.write(&mut buf).map_err(IpStackError::PacketWriteError)?
}
_ => {}
};
// self.transport
// .write(&mut buf)
// .map_err(IpStackError::PacketWriteError)?;
buf.extend_from_slice(&self.payload);
debug!("write: {:?}", buf);
Ok(buf)
}
pub fn ttl(&self) -> u8 {
match &self.ip {
IpHeader::Version4(ip, _) => ip.time_to_live,
IpHeader::Version6(ip, _) => ip.hop_limit,
}
}
}
pub(super) struct TcpPacket {
header: TcpHeader,
}
impl TcpPacket {
pub fn inner(&self) -> &TcpHeader {
&self.header
}
pub fn flags(&self) -> u8 {
let inner = self.inner();
let mut flags = 0;
if inner.cwr {
flags |= tcp_flags::CWR;
}
if inner.ece {
flags |= tcp_flags::ECE;
}
if inner.urg {
flags |= tcp_flags::URG;
}
if inner.ack {
flags |= tcp_flags::ACK;
}
if inner.psh {
flags |= tcp_flags::PSH;
}
if inner.rst {
flags |= tcp_flags::RST;
}
if inner.syn {
flags |= tcp_flags::SYN;
}
if inner.fin {
flags |= tcp_flags::FIN;
}
flags
}
}
impl From<&TcpHeader> for TcpPacket {
fn from(header: &TcpHeader) -> Self {
TcpPacket {
header: header.clone(),
}
}
}
// pub struct UdpPacket {
// header: UdpHeader,
// }
// impl UdpPacket {
// pub fn inner(&self) -> &UdpHeader {
// &self.header
// }
// }
// impl From<&UdpHeader> for UdpPacket {
// fn from(header: &UdpHeader) -> Self {
// UdpPacket {
// header: header.clone(),
// }
// }
// }

View File

@ -0,0 +1,46 @@
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
pub use self::tcp::IpStackTcpStream;
pub use self::udp::IpStackUdpStream;
pub use self::unknown::IpStackUnknownTransport;
mod tcb;
mod tcp;
mod udp;
mod unknown;
pub enum IpStackStream {
Tcp(IpStackTcpStream),
Udp(IpStackUdpStream),
UnknownTransport(IpStackUnknownTransport),
UnknownNetwork(Vec<u8>),
}
impl IpStackStream {
pub fn local_addr(&self) -> SocketAddr {
match self {
IpStackStream::Tcp(tcp) => tcp.local_addr(),
IpStackStream::Udp(udp) => udp.local_addr(),
IpStackStream::UnknownNetwork(_) => {
SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0))
}
IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() {
std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
},
}
}
pub fn peer_addr(&self) -> SocketAddr {
match self {
IpStackStream::Tcp(tcp) => tcp.peer_addr(),
IpStackStream::Udp(udp) => udp.peer_addr(),
IpStackStream::UnknownNetwork(_) => {
SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0))
}
IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() {
std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
},
}
}
}

View File

@ -0,0 +1,234 @@
use std::{
collections::BTreeMap,
pin::Pin,
time::{Duration, SystemTime},
};
use tokio::time::Sleep;
use crate::packet::TcpPacket;
const MAX_UNACK: u32 = 1024 * 16; // 16KB
const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB
#[derive(Clone, Debug)]
pub enum TcpState {
SynReceived(bool), // bool means if syn/ack is sent
Established,
FinWait1,
FinWait2(bool), // bool means waiting for ack
Closed,
}
#[derive(Clone, Debug)]
pub(super) enum PacketStatus {
WindowUpdate,
Invalid,
RetransmissionRequest,
NewPacket,
Ack,
KeepAlive,
}
pub(super) struct Tcb {
pub(super) seq: u32,
pub(super) retransmission: Option<u32>,
pub(super) ack: u32,
pub(super) last_ack: u32,
pub(super) timeout: Pin<Box<Sleep>>,
tcp_timeout: Duration,
recv_window: u16,
pub(super) send_window: u16,
state: TcpState,
pub(super) avg_send_window: (u64, u64),
pub(super) inflight_packets: Vec<InflightPacket>,
pub(super) unordered_packets: BTreeMap<u32, UnorderedPacket>,
}
impl Tcb {
pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb {
let seq = 100;
Tcb {
seq,
retransmission: None,
ack,
last_ack: seq,
tcp_timeout,
timeout: Box::pin(tokio::time::sleep_until(
tokio::time::Instant::now() + tcp_timeout,
)),
send_window: u16::MAX,
recv_window: 0,
state: TcpState::SynReceived(false),
avg_send_window: (1, 1),
inflight_packets: Vec::new(),
unordered_packets: BTreeMap::new(),
}
}
pub(super) fn add_inflight_packet(&mut self, seq: u32, buf: &[u8]) {
self.inflight_packets
.push(InflightPacket::new(seq, buf.to_vec()));
self.seq = self.seq.wrapping_add(buf.len() as u32);
}
pub(super) fn add_unordered_packet(&mut self, seq: u32, buf: &[u8]) {
if seq < self.ack {
return;
}
self.unordered_packets
.insert(seq, UnorderedPacket::new(buf.to_vec()));
}
pub(super) fn get_available_read_buffer_size(&self) -> usize {
READ_BUFFER_SIZE.saturating_sub(
self.unordered_packets
.iter()
.fold(0, |acc, (_, p)| acc + p.payload.len()),
)
}
pub(super) fn get_unordered_packets(&mut self) -> Option<Vec<u8>> {
// dbg!(self.ack);
// for (seq,_) in self.unordered_packets.iter() {
// dbg!(seq);
// }
self.unordered_packets
.remove(&self.ack)
.map(|p| p.payload.clone())
}
pub(super) fn add_seq_one(&mut self) {
self.seq = self.seq.wrapping_add(1);
}
pub(super) fn get_seq(&self) -> u32 {
self.seq
}
pub(super) fn add_ack(&mut self, add: u32) {
self.ack = self.ack.wrapping_add(add);
}
pub(super) fn get_ack(&self) -> u32 {
self.ack
}
pub(super) fn change_state(&mut self, state: TcpState) {
self.state = state;
}
pub(super) fn get_state(&self) -> &TcpState {
&self.state
}
pub(super) fn change_send_window(&mut self, window: u16) {
let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64)
/ (self.avg_send_window.1 + 1);
self.avg_send_window.0 = avg_send_window;
self.avg_send_window.1 += 1;
self.send_window = window;
}
pub(super) fn get_send_window(&self) -> u16 {
self.send_window
}
pub(super) fn change_recv_window(&mut self, window: u16) {
self.recv_window = window;
}
pub(super) fn get_recv_window(&self) -> u16 {
self.recv_window
}
// #[inline(always)]
// pub(super) fn buffer_size(&self, payload_len: u16) -> u16 {
// match MAX_UNACK - self.inflight_packets.len() as u32 {
// // b if b.saturating_sub(payload_len as u32 + 64) != 0 => payload_len,
// // b if b < 128 && b >= 4 => (b / 2) as u16,
// // b if b < 4 => b as u16,
// // b => (b - 64) as u16,
// b if b >= payload_len as u32 * 2 && b > 0 => payload_len,
// b if b < 4 => b as u16,
// b => (b / 2) as u16,
// }
// }
pub(super) fn check_pkt_type(&self, incoming_packet: &TcpPacket, p: &[u8]) -> PacketStatus {
let received_ack_distance = self
.seq
.wrapping_sub(incoming_packet.inner().acknowledgment_number);
let current_ack_distance = self.seq.wrapping_sub(self.last_ack);
if received_ack_distance > current_ack_distance
|| (incoming_packet.inner().acknowledgment_number != self.seq
&& self
.seq
.saturating_sub(incoming_packet.inner().acknowledgment_number)
== 0)
{
PacketStatus::Invalid
} else if self.last_ack == incoming_packet.inner().acknowledgment_number {
if !p.is_empty() {
PacketStatus::NewPacket
} else if self.send_window == incoming_packet.inner().window_size
&& self.seq != self.last_ack
{
PacketStatus::RetransmissionRequest
} else if self.ack.wrapping_sub(1) == incoming_packet.inner().sequence_number {
PacketStatus::KeepAlive
} else {
PacketStatus::WindowUpdate
}
} else if self.last_ack < incoming_packet.inner().acknowledgment_number {
if !p.is_empty() {
PacketStatus::NewPacket
} else {
PacketStatus::Ack
}
} else {
PacketStatus::Invalid
}
}
pub(super) fn change_last_ack(&mut self, ack: u32) {
self.timeout
.as_mut()
.reset(tokio::time::Instant::now() + self.tcp_timeout);
let distance = ack.wrapping_sub(self.last_ack);
if matches!(self.state, TcpState::Established) {
if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) {
let mut inflight_packet = self.inflight_packets.remove(i);
let distance = ack.wrapping_sub(inflight_packet.seq);
if (distance as usize) < inflight_packet.payload.len() {
inflight_packet.payload.drain(0..distance as usize);
inflight_packet.seq = ack;
self.inflight_packets.push(inflight_packet);
}
}
}
self.last_ack = self.last_ack.wrapping_add(distance);
}
pub fn is_send_buffer_full(&self) -> bool {
self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK
}
}
pub struct InflightPacket {
pub seq: u32,
pub payload: Vec<u8>,
pub send_time: SystemTime,
}
impl InflightPacket {
fn new(seq: u32, payload: Vec<u8>) -> Self {
Self {
seq,
payload,
send_time: SystemTime::now(),
}
}
pub(crate) fn contains(&self, seq: u32) -> bool {
self.seq < seq && self.seq + self.payload.len() as u32 >= seq
}
}
pub struct UnorderedPacket {
pub payload: Vec<u8>,
pub recv_time: SystemTime,
}
impl UnorderedPacket {
pub(crate) fn new(payload: Vec<u8>) -> Self {
Self {
payload,
recv_time: SystemTime::now(),
}
}
}

View File

@ -0,0 +1,509 @@
use crate::{
error::IpStackError,
packet::{tcp_flags, IpStackPacketProtocol, TcpPacket, TransportHeader},
stream::tcb::{Tcb, TcpState},
DROP_TTL, TTL,
};
use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions};
use std::{
cmp,
future::Future,
io::{Error, ErrorKind},
net::SocketAddr,
pin::Pin,
task::Waker,
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Notify,
},
};
#[cfg(feature = "log")]
use tracing::{trace, warn};
use crate::packet::NetworkPacket;
use super::tcb::PacketStatus;
pub struct IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
packet_sender: UnboundedSender<NetworkPacket>,
packet_to_send: Option<NetworkPacket>,
tcb: Tcb,
mtu: u16,
shutdown: Option<Notify>,
write_notify: Option<Waker>,
}
impl IpStackTcpStream {
pub(crate) async fn new(
src_addr: SocketAddr,
dst_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
let mut stream = IpStackTcpStream {
src_addr,
dst_addr,
stream_sender,
stream_receiver,
packet_sender: pkt_sender.clone(),
packet_to_send: None,
tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout),
mtu,
shutdown: None,
write_notify: None,
};
if !tcp.inner().syn {
pkt_sender
.send(stream.create_rev_packet(
tcp_flags::RST | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?)
.map_err(|_| IpStackError::InvalidTcpPacket)?;
stream.tcb.change_state(TcpState::Closed);
}
Ok(stream)
}
pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}
fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 {
cmp::min(
self.tcb.get_send_window(),
self.mtu.saturating_sub(ip_header_size + tcp_header_size),
)
}
fn create_rev_packet(
&self,
flags: u8,
ttl: u8,
seq: Option<u32>,
mut payload: Vec<u8>,
) -> Result<NetworkPacket, Error> {
let mut tcp_header = etherparse::TcpHeader::new(
self.dst_addr.port(),
self.src_addr.port(),
seq.unwrap_or(self.tcb.get_seq()),
self.tcb.get_recv_window(),
);
tcp_header.acknowledgment_number = self.tcb.get_ack();
if flags & tcp_flags::SYN != 0 {
tcp_header.syn = true;
}
if flags & tcp_flags::ACK != 0 {
tcp_header.ack = true;
}
if flags & tcp_flags::RST != 0 {
tcp_header.rst = true;
}
if flags & tcp_flags::FIN != 0 {
tcp_header.fin = true;
}
if flags & tcp_flags::PSH != 0 {
tcp_header.psh = true;
}
let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) {
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
let mut ip_h = Ipv4Header::new(0, ttl, 6, dst.octets(), src.octets());
let payload_len =
self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len());
payload.truncate(payload_len as usize);
ip_h.payload_len = payload.len() as u16 + tcp_header.header_len();
ip_h.dont_fragment = true;
etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default())
}
(std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
let mut ip_h = etherparse::Ipv6Header {
traffic_class: 0,
flow_label: 0,
payload_length: 0,
next_header: 6,
hop_limit: ttl,
source: dst.octets(),
destination: src.octets(),
};
let payload_len =
self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len());
payload.truncate(payload_len as usize);
ip_h.payload_length = payload.len() as u16 + tcp_header.header_len();
etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default())
}
_ => unreachable!(),
};
match ip_header {
etherparse::IpHeader::Version4(ref ip_header, _) => {
tcp_header.checksum = tcp_header
.calc_checksum_ipv4(ip_header, &payload)
.map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
}
etherparse::IpHeader::Version6(ref ip_header, _) => {
tcp_header.checksum = tcp_header
.calc_checksum_ipv6(ip_header, &payload)
.map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
}
}
Ok(NetworkPacket {
ip: ip_header,
transport: TransportHeader::Tcp(tcp_header),
payload,
})
}
pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}
pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
}
impl AsyncRead for IpStackTcpStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
loop {
if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) {
self.packet_to_send =
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
return std::task::Poll::Ready(Ok(()));
}
let min = cmp::min(self.tcb.get_available_read_buffer_size() as u16, u16::MAX);
self.tcb.change_recv_window(min);
if matches!(
Pin::new(&mut self.tcb.timeout).poll(cx),
std::task::Poll::Ready(_)
) {
#[cfg(feature = "log")]
trace!("timeout reached for {:?}", self.dst_addr);
self.packet_sender
.send(self.create_rev_packet(
tcp_flags::RST | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?)
.map_err(|_| ErrorKind::UnexpectedEof)?;
return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut)));
}
if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) {
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::SYN | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::SynReceived(true));
}
if let Some(packet) = self.packet_to_send.take() {
self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
if matches!(self.tcb.get_state(), TcpState::Closed) {
if let Some(shutdown) = self.shutdown.take() {
shutdown.notify_one();
}
return std::task::Poll::Ready(Ok(()));
}
}
if let Some(b) = self.tcb.get_unordered_packets() {
self.tcb.add_ack(b.len() as u32);
buf.put_slice(&b);
self.packet_sender
.send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
return std::task::Poll::Ready(Ok(()));
}
if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) {
self.tcb.change_state(TcpState::FinWait1);
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::FIN | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
continue;
}
match self.stream_receiver.poll_recv(cx) {
std::task::Poll::Ready(Some(p)) => {
let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else {
unreachable!()
};
if t.flags() & tcp_flags::RST != 0 {
self.packet_to_send =
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
return std::task::Poll::Ready(Err(Error::from(
ErrorKind::ConnectionReset,
)));
}
if matches!(
self.tcb.check_pkt_type(&t, &p.payload),
PacketStatus::Invalid
) {
continue;
}
if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) {
if t.flags() == tcp_flags::ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.change_state(TcpState::Established);
}
} else if matches!(self.tcb.get_state(), TcpState::Established) {
if t.flags() == tcp_flags::ACK {
match self.tcb.check_pkt_type(&t, &p.payload) {
PacketStatus::WindowUpdate => {
self.tcb.change_send_window(t.inner().window_size);
if let Some(ref n) = self.write_notify {
n.wake_by_ref();
self.write_notify = None;
};
continue;
}
PacketStatus::Invalid => continue,
PacketStatus::KeepAlive => {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
continue;
}
PacketStatus::RetransmissionRequest => {
self.tcb.change_send_window(t.inner().window_size);
self.tcb.retransmission = Some(t.inner().acknowledgment_number);
if matches!(
self.as_mut().poll_flush(cx),
std::task::Poll::Pending
) {
return std::task::Poll::Pending;
}
continue;
}
PacketStatus::NewPacket => {
// if t.inner().sequence_number != self.tcb.get_ack() {
// dbg!(t.inner().sequence_number);
// self.packet_to_send = Some(self.create_rev_packet(
// tcp_flags::ACK,
// TTL,
// None,
// Vec::new(),
// )?);
// continue;
// }
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.add_unordered_packet(
t.inner().sequence_number,
&p.payload,
);
// buf.put_slice(&p.payload);
// self.tcb.add_ack(p.payload.len() as u32);
// self.packet_to_send = Some(self.create_rev_packet(
// tcp_flags::ACK,
// TTL,
// None,
// Vec::new(),
// )?);
self.tcb.change_send_window(t.inner().window_size);
if let Some(ref n) = self.write_notify {
n.wake_by_ref();
self.write_notify = None;
};
continue;
// return std::task::Poll::Ready(Ok(()));
}
PacketStatus::Ack => {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
if let Some(ref n) = self.write_notify {
n.wake_by_ref();
self.write_notify = None;
};
continue;
}
};
}
if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
self.tcb.add_ack(1);
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::FIN | tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(true));
continue;
}
if t.flags() == (tcp_flags::PSH | tcp_flags::ACK) {
if !matches!(
self.tcb.check_pkt_type(&t, &p.payload),
PacketStatus::NewPacket
) {
continue;
}
self.tcb.change_last_ack(t.inner().acknowledgment_number);
if p.payload.is_empty()
|| self.tcb.get_ack() != t.inner().sequence_number
{
continue;
}
// self.tcb.add_ack(p.payload.len() as u32);
self.tcb.change_send_window(t.inner().window_size);
// buf.put_slice(&p.payload);
// self.packet_to_send = Some(self.create_rev_packet(
// tcp_flags::ACK,
// TTL,
// None,
// Vec::new(),
// )?);
// return std::task::Poll::Ready(Ok(()));
self.tcb
.add_unordered_packet(t.inner().sequence_number, &p.payload);
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait1) {
if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(false));
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait2(true))
&& t.flags() == tcp_flags::ACK
{
self.tcb.change_state(TcpState::FinWait2(false));
}
}
std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
std::task::Poll::Pending => return std::task::Poll::Pending,
}
}
}
}
impl AsyncWrite for IpStackTcpStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2
|| self.tcb.is_send_buffer_full()
{
self.write_notify = Some(cx.waker().clone());
return std::task::Poll::Pending;
}
if self.tcb.retransmission.is_some() {
self.write_notify = Some(cx.waker().clone());
if matches!(self.as_mut().poll_flush(cx), std::task::Poll::Pending) {
return std::task::Poll::Pending;
}
}
let packet =
self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?;
let seq = self.tcb.seq;
let payload_len = packet.payload.len();
let payload = packet.payload.clone();
self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
self.tcb.add_inflight_packet(seq, &payload);
std::task::Poll::Ready(Ok(payload_len))
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
if let Some(i) = self
.tcb
.retransmission
.and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s))
.and_then(|p| self.tcb.inflight_packets.get(p))
{
let packet = self.create_rev_packet(
tcp_flags::PSH | tcp_flags::ACK,
TTL,
Some(i.seq),
i.payload.to_vec(),
)?;
self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
self.tcb.retransmission = None;
} else if let Some(_i) = self.tcb.retransmission {
#[cfg(feature = "log")]
{
warn!(_i);
warn!(self.tcb.seq);
warn!(self.tcb.last_ack);
warn!(self.tcb.ack);
for p in self.tcb.inflight_packets.iter() {
warn!(p.seq);
warn!("{}", p.payload.len());
}
}
panic!("Please report these values at: https://github.com/narrowlink/ipstack/");
}
std::task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let notified = self.shutdown.get_or_insert(Notify::new()).notified();
match Pin::new(&mut Box::pin(notified)).poll(cx) {
std::task::Poll::Ready(_) => std::task::Poll::Ready(Ok(())),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Ok(p) = self.create_rev_packet(0, DROP_TTL, None, Vec::new()) {
_ = self.packet_sender.send(p);
}
}
}

View File

@ -0,0 +1,181 @@
use core::task;
use std::{
future::Future,
io::{self, Error, ErrorKind},
net::SocketAddr,
pin::Pin,
task::Poll,
time::Duration,
};
use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
time::Sleep,
};
// use crate::packet::TransportHeader;
use crate::{
packet::{NetworkPacket, TransportHeader},
TTL,
};
pub struct IpStackUdpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
packet_sender: UnboundedSender<NetworkPacket>,
first_paload: Option<Vec<u8>>,
timeout: Pin<Box<Sleep>>,
udp_timeout: Duration,
mtu: u16,
}
impl IpStackUdpStream {
pub fn new(
src_addr: SocketAddr,
dst_addr: SocketAddr,
payload: Vec<u8>,
pkt_sender: UnboundedSender<NetworkPacket>,
mtu: u16,
udp_timeout: Duration,
) -> Self {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
IpStackUdpStream {
src_addr,
dst_addr,
stream_sender,
stream_receiver,
packet_sender: pkt_sender.clone(),
first_paload: Some(payload),
timeout: Box::pin(tokio::time::sleep_until(
tokio::time::Instant::now() + udp_timeout,
)),
udp_timeout,
mtu,
}
}
pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}
fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> Result<NetworkPacket, Error> {
match (self.dst_addr.ip(), self.src_addr.ip()) {
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
let mut ip_h = Ipv4Header::new(0, ttl, 17, dst.octets(), src.octets());
let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size
payload.truncate(line_buffer as usize);
ip_h.payload_len = payload.len() as u16 + 8; // 8 is udp header size
let udp_header = UdpHeader::with_ipv4_checksum(
self.dst_addr.port(),
self.src_addr.port(),
&ip_h,
&payload,
)
.map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
Ok(NetworkPacket {
ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()),
transport: TransportHeader::Udp(udp_header),
payload,
})
}
(std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
let mut ip_h = Ipv6Header {
traffic_class: 0,
flow_label: 0,
payload_length: 0,
next_header: 17,
hop_limit: ttl,
source: dst.octets(),
destination: src.octets(),
};
let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size
payload.truncate(line_buffer as usize);
ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size
let udp_header = UdpHeader::with_ipv6_checksum(
self.dst_addr.port(),
self.src_addr.port(),
&ip_h,
&payload,
)
.map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
Ok(NetworkPacket {
ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()),
transport: TransportHeader::Udp(udp_header),
payload,
})
}
_ => unreachable!(),
}
}
pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}
pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
}
impl AsyncRead for IpStackUdpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
if let Some(p) = self.first_paload.take() {
buf.put_slice(&p);
return Poll::Ready(Ok(()));
}
if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) {
return Poll::Ready(Ok(())); // todo: return timeout error
}
let udp_timeout = self.udp_timeout;
match self.stream_receiver.poll_recv(cx) {
Poll::Ready(Some(p)) => {
buf.put_slice(&p.payload);
self.timeout
.as_mut()
.reset(tokio::time::Instant::now() + udp_timeout);
Poll::Ready(Ok(()))
}
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
}
}
impl AsyncWrite for IpStackUdpStream {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<Result<usize, io::Error>> {
let udp_timeout = self.udp_timeout;
self.timeout
.as_mut()
.reset(tokio::time::Instant::now() + udp_timeout);
let packet = self.create_rev_packet(TTL, buf.to_vec())?;
let payload_len = packet.payload.len();
self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
std::task::Poll::Ready(Ok(payload_len))
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}

View File

@ -0,0 +1,111 @@
use std::{io::Error, mem, net::IpAddr};
use etherparse::{IpHeader, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header};
use tokio::sync::mpsc::UnboundedSender;
use crate::{
packet::{NetworkPacket, TransportHeader},
TTL,
};
pub struct IpStackUnknownTransport {
src_addr: IpAddr,
dst_addr: IpAddr,
payload: Vec<u8>,
protocol: u8,
mtu: u16,
packet_sender: UnboundedSender<NetworkPacket>,
}
impl IpStackUnknownTransport {
pub fn new(
src_addr: IpAddr,
dst_addr: IpAddr,
payload: Vec<u8>,
ip: &IpHeader,
mtu: u16,
packet_sender: UnboundedSender<NetworkPacket>,
) -> Self {
let protocol = match ip {
IpHeader::Version4(ip, _) => ip.protocol,
IpHeader::Version6(ip, _) => ip.next_header,
};
IpStackUnknownTransport {
src_addr,
dst_addr,
payload,
protocol,
mtu,
packet_sender,
}
}
pub fn src_addr(&self) -> IpAddr {
self.src_addr
}
pub fn dst_addr(&self) -> IpAddr {
self.dst_addr
}
pub fn payload(&self) -> &[u8] {
&self.payload
}
pub fn ip_protocol(&self) -> u8 {
self.protocol
}
pub async fn send(&self, mut payload: Vec<u8>) -> Result<(), Error> {
loop {
let packet = self.create_rev_packet(&mut payload)?;
self.packet_sender
.send(packet)
.map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?;
if payload.is_empty() {
return Ok(());
}
}
}
pub fn create_rev_packet(&self, payload: &mut Vec<u8>) -> Result<NetworkPacket, Error> {
match (self.dst_addr, self.src_addr) {
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets());
let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
let p = if payload.len() > line_buffer as usize {
payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
} else {
mem::take(payload)
};
ip_h.payload_len = p.len() as u16;
Ok(NetworkPacket {
ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()),
transport: TransportHeader::Unknown,
payload: p,
})
}
(std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
let mut ip_h = Ipv6Header {
traffic_class: 0,
flow_label: 0,
payload_length: 0,
next_header: 17,
hop_limit: TTL,
source: dst.octets(),
destination: src.octets(),
};
let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
payload.truncate(line_buffer as usize);
ip_h.payload_length = payload.len() as u16;
let p = if payload.len() > line_buffer as usize {
payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
} else {
mem::take(payload)
};
Ok(NetworkPacket {
ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()),
transport: TransportHeader::Unknown,
payload: p,
})
}
_ => unreachable!(),
}
}
}