diff --git a/Cargo.toml b/Cargo.toml index c9144de..d8bdbc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "libs/xen/xenclient", "libs/advmac", "libs/loopdev", - "libs/ipstack", "shared", "container", "network", @@ -52,6 +51,8 @@ futures = "0.3.30" ipnetwork = "0.20.0" udp-stream = "0.0.11" smoltcp = "0.11.0" +etherparse = "0.14.2" +async-trait = "0.1.77" [workspace.dependencies.uuid] version = "1.6.1" diff --git a/libs/ipstack/Cargo.toml b/libs/ipstack/Cargo.toml deleted file mode 100644 index e40df38..0000000 --- a/libs/ipstack/Cargo.toml +++ /dev/null @@ -1,47 +0,0 @@ -# This package is from https://github.com/narrowlink/ipstack -# Mycelium maintains an in-tree version because we need to work at the ethernet layer -# rather than the standard tun layer of IP. -[package] -authors = ['Narrowlink '] -description = 'Asynchronous lightweight implementation of TCP/IP stack for Tun device' -name = "ipstack" -version = "0.0.3" -edition = "2021" -license = "Apache-2.0" -repository = 'https://github.com/narrowlink/ipstack' -# homepage = 'https://github.com/narrowlink/ipstack' -readme = "README.md" - -[features] -default = [] -log = ["tracing/log"] - -[dependencies] -tokio = { version = "1.35", features = [ - "sync", - "rt", - "time", - "io-util", - "macros", -], default-features = false } -etherparse = { version = "0.13", default-features = false } -thiserror = { version = "1.0", default-features = false } -tracing = { version = "0.1", default-features = false, features = [ - "log", -], optional = true } - -[dev-dependencies] -clap = { version = "4.4", features = ["derive"] } -udp-stream = { version = "0.0", default-features = false } -tokio = { version = "1.35", features = [ - "rt-multi-thread", -], default-features = false } - -#tun2.rs example -tun2 = { version = "1.0", features = ["async"] } - -#tun_wintun.rs example -[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] -tun = { version = "0.6.1", features = ["async"], default-features = false } -[target.'cfg(target_os = "windows")'.dev-dependencies] -wintun = { version = "0.4", default-features = false } diff --git a/libs/ipstack/LICENSE b/libs/ipstack/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/libs/ipstack/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/libs/ipstack/src/error.rs b/libs/ipstack/src/error.rs deleted file mode 100644 index 2b7f242..0000000 --- a/libs/ipstack/src/error.rs +++ /dev/null @@ -1,30 +0,0 @@ -#[allow(dead_code)] -#[derive(thiserror::Error, Debug)] -pub enum IpStackError { - #[error("The transport protocol is not supported")] - UnsupportedTransportProtocol, - #[error("The packet is invalid")] - InvalidPacket, - #[error("Write error: {0}")] - PacketWriteError(etherparse::WriteError), - #[error("Invalid Tcp packet")] - InvalidTcpPacket, - #[error("IO error: {0}")] - IoError(#[from] std::io::Error), - #[error("Accept Error")] - AcceptError, - - #[error("Send Error {0}")] - SendError(#[from] tokio::sync::mpsc::error::SendError), -} - -impl From for std::io::Error { - fn from(e: IpStackError) -> Self { - match e { - IpStackError::IoError(e) => e, - _ => std::io::Error::new(std::io::ErrorKind::Other, e), - } - } -} - -pub type Result = std::result::Result; diff --git a/libs/ipstack/src/lib.rs b/libs/ipstack/src/lib.rs deleted file mode 100644 index 9139097..0000000 --- a/libs/ipstack/src/lib.rs +++ /dev/null @@ -1,194 +0,0 @@ -pub use error::{IpStackError, Result}; -use packet::{NetworkPacket, NetworkTuple}; -use std::{ - collections::{ - hash_map::Entry::{Occupied, Vacant}, - HashMap, - }, - time::Duration, -}; -use stream::IpStackStream; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, - select, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, -}; -#[cfg(feature = "log")] -use tracing::{error, trace}; - -use crate::{ - packet::IpStackPacketProtocol, - stream::{IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, -}; -mod error; -mod packet; -pub mod stream; - -const DROP_TTL: u8 = 0; - -#[cfg(unix)] -const TTL: u8 = 64; - -#[cfg(windows)] -const TTL: u8 = 128; - -#[cfg(unix)] -const TUN_FLAGS: [u8; 2] = [0x00, 0x00]; - -#[cfg(any(target_os = "linux", target_os = "android"))] -const TUN_PROTO_IP6: [u8; 2] = [0x86, 0xdd]; -#[cfg(any(target_os = "linux", target_os = "android"))] -const TUN_PROTO_IP4: [u8; 2] = [0x08, 0x00]; - -#[cfg(any(target_os = "macos", target_os = "ios"))] -const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x0A]; -#[cfg(any(target_os = "macos", target_os = "ios"))] -const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02]; - -pub struct IpStackConfig { - pub mtu: u16, - pub packet_information: bool, - pub tcp_timeout: Duration, - pub udp_timeout: Duration, -} - -impl Default for IpStackConfig { - fn default() -> Self { - IpStackConfig { - mtu: u16::MAX, - packet_information: false, - tcp_timeout: Duration::from_secs(60), - udp_timeout: Duration::from_secs(30), - } - } -} - -impl IpStackConfig { - pub fn tcp_timeout(&mut self, timeout: Duration) { - self.tcp_timeout = timeout; - } - pub fn udp_timeout(&mut self, timeout: Duration) { - self.udp_timeout = timeout; - } - pub fn mtu(&mut self, mtu: u16) { - self.mtu = mtu; - } - pub fn packet_information(&mut self, packet_information: bool) { - self.packet_information = packet_information; - } -} - -pub struct IpStack { - accept_receiver: UnboundedReceiver, -} - -impl IpStack { - pub fn new(config: IpStackConfig, mut device: D) -> IpStack - where - D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static, - { - let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - - tokio::spawn(async move { - let mut streams: HashMap> = HashMap::new(); - let mut buffer = [0u8; u16::MAX as usize]; - - let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); - loop { - // dbg!(streams.len()); - select! { - Ok(n) = device.read(&mut buffer) => { - let offset = if config.packet_information && cfg!(unix) {4} else {0}; - let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else { - accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?; - continue; - }; - if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - accept_sender.send(IpStackStream::UnknownTransport(IpStackUnknownTransport::new(packet.src_addr().ip(),packet.dst_addr().ip(),packet.payload,&packet.ip,config.mtu,pkt_sender.clone())))?; - continue; - } - - match streams.entry(packet.network_tuple()){ - Occupied(entry) =>{ - // let t = packet.transport_protocol(); - if let Err(_x) = entry.get().send(packet){ - #[cfg(feature = "log")] - trace!("{}", _x); - // match t{ - // IpStackPacketProtocol::Tcp(_t) => { - // // dbg!(t.flags()); - // } - // IpStackPacketProtocol::Udp => { - // // dbg!("udp"); - // } - // IpStackPacketProtocol::Unknown => { - // // dbg!("unknown"); - // } - // } - - } - } - Vacant(entry) => { - match packet.transport_protocol(){ - IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{ - Ok(stream) => { - entry.insert(stream.stream_sender()); - accept_sender.send(IpStackStream::Tcp(stream))?; - } - Err(_e) => { - #[cfg(feature = "log")] - error!("{}", _e); - } - } - } - IpStackPacketProtocol::Udp => { - let stream = IpStackUdpStream::new(packet.src_addr(),packet.dst_addr(),packet.payload, pkt_sender.clone(),config.mtu,config.udp_timeout); - entry.insert(stream.stream_sender()); - accept_sender.send(IpStackStream::Udp(stream))?; - } - IpStackPacketProtocol::Unknown => { - unreachable!() - } - } - } - } - } - Some(packet) = pkt_receiver.recv() => { - if packet.ttl() == 0{ - streams.remove(&packet.reverse_network_tuple()); - continue; - } - #[allow(unused_mut)] - let Ok(mut packet_byte) = packet.to_bytes() else{ - #[cfg(feature = "log")] - trace!("to_bytes error"); - continue; - }; - #[cfg(unix)] - if config.packet_information { - if packet.src_addr().is_ipv4(){ - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); - } else{ - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); - } - } - device.write_all(&packet_byte).await?; - // device.flush().await.unwrap(); - } - } - } - #[allow(unreachable_code)] - Ok::<(), IpStackError>(()) - }); - - IpStack { accept_receiver } - } - pub async fn accept(&mut self) -> Result { - if let Some(s) = self.accept_receiver.recv().await { - Ok(s) - } else { - Err(IpStackError::AcceptError) - } - } -} diff --git a/libs/ipstack/src/packet.rs b/libs/ipstack/src/packet.rs deleted file mode 100644 index d9cb7a0..0000000 --- a/libs/ipstack/src/packet.rs +++ /dev/null @@ -1,214 +0,0 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - -use etherparse::{Ethernet2Header, IpHeader, PacketHeaders, TcpHeader, UdpHeader, WriteError}; - -use crate::error::IpStackError; - -#[derive(Eq, Hash, PartialEq, Debug)] -pub struct NetworkTuple { - pub src: SocketAddr, - pub dst: SocketAddr, - pub tcp: bool, -} -pub mod tcp_flags { - pub const CWR: u8 = 0b10000000; - pub const ECE: u8 = 0b01000000; - pub const URG: u8 = 0b00100000; - pub const ACK: u8 = 0b00010000; - pub const PSH: u8 = 0b00001000; - pub const RST: u8 = 0b00000100; - pub const SYN: u8 = 0b00000010; - pub const FIN: u8 = 0b00000001; -} - -pub(crate) enum IpStackPacketProtocol { - Tcp(TcpPacket), - Unknown, - Udp, -} - -pub(crate) enum TransportHeader { - Tcp(TcpHeader), - Udp(UdpHeader), - Unknown, -} - -pub struct NetworkPacket { - pub(crate) ip: IpHeader, - pub(crate) transport: TransportHeader, - pub(crate) payload: Vec, -} - -impl NetworkPacket { - pub fn parse(buf: &[u8]) -> Result { - let p = PacketHeaders::from_ethernet_slice(buf).map_err(|_| IpStackError::InvalidPacket)?; - let ip = p.ip.ok_or(IpStackError::InvalidPacket)?; - let transport = match p.transport { - Some(etherparse::TransportHeader::Tcp(h)) => TransportHeader::Tcp(h), - Some(etherparse::TransportHeader::Udp(u)) => TransportHeader::Udp(u), - _ => TransportHeader::Unknown, - }; - - let payload = if let TransportHeader::Unknown = transport { - buf[ip.header_len()..].to_vec() - } else { - p.payload.to_vec() - }; - - Ok(NetworkPacket { - ip, - transport, - payload, - }) - } - pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol { - match self.transport { - TransportHeader::Udp(_) => IpStackPacketProtocol::Udp, - TransportHeader::Tcp(ref h) => IpStackPacketProtocol::Tcp(h.into()), - _ => IpStackPacketProtocol::Unknown, - } - } - pub fn src_addr(&self) -> SocketAddr { - let port = match &self.transport { - TransportHeader::Udp(udp) => udp.source_port, - TransportHeader::Tcp(tcp) => tcp.source_port, - _ => 0, - }; - match &self.ip { - IpHeader::Version4(ip, _) => { - SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.source)), port) - } - IpHeader::Version6(ip, _) => { - SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.source)), port) - } - } - } - pub fn dst_addr(&self) -> SocketAddr { - let port = match &self.transport { - TransportHeader::Udp(udp) => udp.destination_port, - TransportHeader::Tcp(tcp) => tcp.destination_port, - _ => 0, - }; - match &self.ip { - IpHeader::Version4(ip, _) => { - SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.destination)), port) - } - IpHeader::Version6(ip, _) => { - SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.destination)), port) - } - } - } - pub fn network_tuple(&self) -> NetworkTuple { - NetworkTuple { - src: self.src_addr(), - dst: self.dst_addr(), - tcp: matches!(self.transport, TransportHeader::Tcp(_)), - } - } - pub fn reverse_network_tuple(&self) -> NetworkTuple { - NetworkTuple { - src: self.dst_addr(), - dst: self.src_addr(), - tcp: matches!(self.transport, TransportHeader::Tcp(_)), - } - } - pub fn to_bytes(&self) -> Result, IpStackError> { - let mut buf = Vec::new(); - let header = Ethernet2Header { - source: [255; 6], - destination: [255; 6], - ether_type: 0x0800, - }; - header.write(&mut buf).map_err(IpStackError::IoError)?; - self.ip - .write(&mut buf) - .map_err(IpStackError::PacketWriteError)?; - match self.transport { - TransportHeader::Tcp(ref h) => h - .write(&mut buf) - .map_err(WriteError::from) - .map_err(IpStackError::PacketWriteError)?, - TransportHeader::Udp(ref h) => { - h.write(&mut buf).map_err(IpStackError::PacketWriteError)? - } - _ => {} - }; - // self.transport - // .write(&mut buf) - // .map_err(IpStackError::PacketWriteError)?; - buf.extend_from_slice(&self.payload); - Ok(buf) - } - pub fn ttl(&self) -> u8 { - match &self.ip { - IpHeader::Version4(ip, _) => ip.time_to_live, - IpHeader::Version6(ip, _) => ip.hop_limit, - } - } -} - -pub(super) struct TcpPacket { - header: TcpHeader, -} - -impl TcpPacket { - pub fn inner(&self) -> &TcpHeader { - &self.header - } - pub fn flags(&self) -> u8 { - let inner = self.inner(); - let mut flags = 0; - if inner.cwr { - flags |= tcp_flags::CWR; - } - if inner.ece { - flags |= tcp_flags::ECE; - } - if inner.urg { - flags |= tcp_flags::URG; - } - if inner.ack { - flags |= tcp_flags::ACK; - } - if inner.psh { - flags |= tcp_flags::PSH; - } - if inner.rst { - flags |= tcp_flags::RST; - } - if inner.syn { - flags |= tcp_flags::SYN; - } - if inner.fin { - flags |= tcp_flags::FIN; - } - - flags - } -} - -impl From<&TcpHeader> for TcpPacket { - fn from(header: &TcpHeader) -> Self { - TcpPacket { - header: header.clone(), - } - } -} - -// pub struct UdpPacket { -// header: UdpHeader, -// } - -// impl UdpPacket { -// pub fn inner(&self) -> &UdpHeader { -// &self.header -// } -// } - -// impl From<&UdpHeader> for UdpPacket { -// fn from(header: &UdpHeader) -> Self { -// UdpPacket { -// header: header.clone(), -// } -// } -// } diff --git a/libs/ipstack/src/stream/mod.rs b/libs/ipstack/src/stream/mod.rs deleted file mode 100644 index 9878f99..0000000 --- a/libs/ipstack/src/stream/mod.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; - -pub use self::tcp::IpStackTcpStream; -pub use self::udp::IpStackUdpStream; -pub use self::unknown::IpStackUnknownTransport; - -mod tcb; -mod tcp; -mod udp; -mod unknown; - -pub enum IpStackStream { - Tcp(IpStackTcpStream), - Udp(IpStackUdpStream), - UnknownTransport(IpStackUnknownTransport), - UnknownNetwork(Vec), -} - -impl IpStackStream { - pub fn local_addr(&self) -> SocketAddr { - match self { - IpStackStream::Tcp(tcp) => tcp.local_addr(), - IpStackStream::Udp(udp) => udp.local_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0)) - } - IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { - std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), - std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), - }, - } - } - pub fn peer_addr(&self) -> SocketAddr { - match self { - IpStackStream::Tcp(tcp) => tcp.peer_addr(), - IpStackStream::Udp(udp) => udp.peer_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0)) - } - IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { - std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), - std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), - }, - } - } -} diff --git a/libs/ipstack/src/stream/tcb.rs b/libs/ipstack/src/stream/tcb.rs deleted file mode 100644 index f3adbe2..0000000 --- a/libs/ipstack/src/stream/tcb.rs +++ /dev/null @@ -1,234 +0,0 @@ -use std::{ - collections::BTreeMap, - pin::Pin, - time::{Duration, SystemTime}, -}; - -use tokio::time::Sleep; - -use crate::packet::TcpPacket; - -const MAX_UNACK: u32 = 1024 * 16; // 16KB -const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB - -#[derive(Clone, Debug)] -pub enum TcpState { - SynReceived(bool), // bool means if syn/ack is sent - Established, - FinWait1, - FinWait2(bool), // bool means waiting for ack - Closed, -} -#[derive(Clone, Debug)] -pub(super) enum PacketStatus { - WindowUpdate, - Invalid, - RetransmissionRequest, - NewPacket, - Ack, - KeepAlive, -} - -pub(super) struct Tcb { - pub(super) seq: u32, - pub(super) retransmission: Option, - pub(super) ack: u32, - pub(super) last_ack: u32, - pub(super) timeout: Pin>, - tcp_timeout: Duration, - recv_window: u16, - pub(super) send_window: u16, - state: TcpState, - pub(super) avg_send_window: (u64, u64), - pub(super) inflight_packets: Vec, - pub(super) unordered_packets: BTreeMap, -} - -impl Tcb { - pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { - let seq = 100; - Tcb { - seq, - retransmission: None, - ack, - last_ack: seq, - tcp_timeout, - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + tcp_timeout, - )), - send_window: u16::MAX, - recv_window: 0, - state: TcpState::SynReceived(false), - avg_send_window: (1, 1), - inflight_packets: Vec::new(), - unordered_packets: BTreeMap::new(), - } - } - pub(super) fn add_inflight_packet(&mut self, seq: u32, buf: &[u8]) { - self.inflight_packets - .push(InflightPacket::new(seq, buf.to_vec())); - self.seq = self.seq.wrapping_add(buf.len() as u32); - } - pub(super) fn add_unordered_packet(&mut self, seq: u32, buf: &[u8]) { - if seq < self.ack { - return; - } - self.unordered_packets - .insert(seq, UnorderedPacket::new(buf.to_vec())); - } - pub(super) fn get_available_read_buffer_size(&self) -> usize { - READ_BUFFER_SIZE.saturating_sub( - self.unordered_packets - .iter() - .fold(0, |acc, (_, p)| acc + p.payload.len()), - ) - } - pub(super) fn get_unordered_packets(&mut self) -> Option> { - // dbg!(self.ack); - // for (seq,_) in self.unordered_packets.iter() { - // dbg!(seq); - // } - self.unordered_packets - .remove(&self.ack) - .map(|p| p.payload.clone()) - } - pub(super) fn add_seq_one(&mut self) { - self.seq = self.seq.wrapping_add(1); - } - pub(super) fn get_seq(&self) -> u32 { - self.seq - } - pub(super) fn add_ack(&mut self, add: u32) { - self.ack = self.ack.wrapping_add(add); - } - pub(super) fn get_ack(&self) -> u32 { - self.ack - } - pub(super) fn change_state(&mut self, state: TcpState) { - self.state = state; - } - pub(super) fn get_state(&self) -> &TcpState { - &self.state - } - pub(super) fn change_send_window(&mut self, window: u16) { - let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) - / (self.avg_send_window.1 + 1); - self.avg_send_window.0 = avg_send_window; - self.avg_send_window.1 += 1; - self.send_window = window; - } - pub(super) fn get_send_window(&self) -> u16 { - self.send_window - } - pub(super) fn change_recv_window(&mut self, window: u16) { - self.recv_window = window; - } - pub(super) fn get_recv_window(&self) -> u16 { - self.recv_window - } - // #[inline(always)] - // pub(super) fn buffer_size(&self, payload_len: u16) -> u16 { - // match MAX_UNACK - self.inflight_packets.len() as u32 { - // // b if b.saturating_sub(payload_len as u32 + 64) != 0 => payload_len, - // // b if b < 128 && b >= 4 => (b / 2) as u16, - // // b if b < 4 => b as u16, - // // b => (b - 64) as u16, - // b if b >= payload_len as u32 * 2 && b > 0 => payload_len, - // b if b < 4 => b as u16, - // b => (b / 2) as u16, - // } - // } - - pub(super) fn check_pkt_type(&self, incoming_packet: &TcpPacket, p: &[u8]) -> PacketStatus { - let received_ack_distance = self - .seq - .wrapping_sub(incoming_packet.inner().acknowledgment_number); - - let current_ack_distance = self.seq.wrapping_sub(self.last_ack); - if received_ack_distance > current_ack_distance - || (incoming_packet.inner().acknowledgment_number != self.seq - && self - .seq - .saturating_sub(incoming_packet.inner().acknowledgment_number) - == 0) - { - PacketStatus::Invalid - } else if self.last_ack == incoming_packet.inner().acknowledgment_number { - if !p.is_empty() { - PacketStatus::NewPacket - } else if self.send_window == incoming_packet.inner().window_size - && self.seq != self.last_ack - { - PacketStatus::RetransmissionRequest - } else if self.ack.wrapping_sub(1) == incoming_packet.inner().sequence_number { - PacketStatus::KeepAlive - } else { - PacketStatus::WindowUpdate - } - } else if self.last_ack < incoming_packet.inner().acknowledgment_number { - if !p.is_empty() { - PacketStatus::NewPacket - } else { - PacketStatus::Ack - } - } else { - PacketStatus::Invalid - } - } - pub(super) fn change_last_ack(&mut self, ack: u32) { - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + self.tcp_timeout); - let distance = ack.wrapping_sub(self.last_ack); - - if matches!(self.state, TcpState::Established) { - if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { - let mut inflight_packet = self.inflight_packets.remove(i); - let distance = ack.wrapping_sub(inflight_packet.seq); - if (distance as usize) < inflight_packet.payload.len() { - inflight_packet.payload.drain(0..distance as usize); - inflight_packet.seq = ack; - self.inflight_packets.push(inflight_packet); - } - } - } - - self.last_ack = self.last_ack.wrapping_add(distance); - } - pub fn is_send_buffer_full(&self) -> bool { - self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK - } -} - -pub struct InflightPacket { - pub seq: u32, - pub payload: Vec, - pub send_time: SystemTime, -} - -impl InflightPacket { - fn new(seq: u32, payload: Vec) -> Self { - Self { - seq, - payload, - send_time: SystemTime::now(), - } - } - pub(crate) fn contains(&self, seq: u32) -> bool { - self.seq < seq && self.seq + self.payload.len() as u32 >= seq - } -} - -pub struct UnorderedPacket { - pub payload: Vec, - pub recv_time: SystemTime, -} - -impl UnorderedPacket { - pub(crate) fn new(payload: Vec) -> Self { - Self { - payload, - recv_time: SystemTime::now(), - } - } -} diff --git a/libs/ipstack/src/stream/tcp.rs b/libs/ipstack/src/stream/tcp.rs deleted file mode 100644 index 9c739ca..0000000 --- a/libs/ipstack/src/stream/tcp.rs +++ /dev/null @@ -1,509 +0,0 @@ -use crate::{ - error::IpStackError, - packet::{tcp_flags, IpStackPacketProtocol, TcpPacket, TransportHeader}, - stream::tcb::{Tcb, TcpState}, - DROP_TTL, TTL, -}; -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions}; -use std::{ - cmp, - future::Future, - io::{Error, ErrorKind}, - net::SocketAddr, - pin::Pin, - task::Waker, - time::Duration, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::{ - mpsc::{self, UnboundedReceiver, UnboundedSender}, - Notify, - }, -}; -#[cfg(feature = "log")] -use tracing::{trace, warn}; - -use crate::packet::NetworkPacket; - -use super::tcb::PacketStatus; - -pub struct IpStackTcpStream { - src_addr: SocketAddr, - dst_addr: SocketAddr, - stream_sender: UnboundedSender, - stream_receiver: UnboundedReceiver, - packet_sender: UnboundedSender, - packet_to_send: Option, - tcb: Tcb, - mtu: u16, - shutdown: Option, - write_notify: Option, -} - -impl IpStackTcpStream { - pub(crate) async fn new( - src_addr: SocketAddr, - dst_addr: SocketAddr, - tcp: TcpPacket, - pkt_sender: UnboundedSender, - mtu: u16, - tcp_timeout: Duration, - ) -> Result { - let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - - let mut stream = IpStackTcpStream { - src_addr, - dst_addr, - stream_sender, - stream_receiver, - packet_sender: pkt_sender.clone(), - packet_to_send: None, - tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), - mtu, - shutdown: None, - write_notify: None, - }; - if !tcp.inner().syn { - pkt_sender - .send(stream.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) - .map_err(|_| IpStackError::InvalidTcpPacket)?; - stream.tcb.change_state(TcpState::Closed); - } - Ok(stream) - } - pub(crate) fn stream_sender(&self) -> UnboundedSender { - self.stream_sender.clone() - } - fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { - cmp::min( - self.tcb.get_send_window(), - self.mtu.saturating_sub(ip_header_size + tcp_header_size), - ) - } - fn create_rev_packet( - &self, - flags: u8, - ttl: u8, - seq: Option, - mut payload: Vec, - ) -> Result { - let mut tcp_header = etherparse::TcpHeader::new( - self.dst_addr.port(), - self.src_addr.port(), - seq.unwrap_or(self.tcb.get_seq()), - self.tcb.get_recv_window(), - ); - - tcp_header.acknowledgment_number = self.tcb.get_ack(); - if flags & tcp_flags::SYN != 0 { - tcp_header.syn = true; - } - if flags & tcp_flags::ACK != 0 { - tcp_header.ack = true; - } - if flags & tcp_flags::RST != 0 { - tcp_header.rst = true; - } - if flags & tcp_flags::FIN != 0 { - tcp_header.fin = true; - } - if flags & tcp_flags::PSH != 0 { - tcp_header.psh = true; - } - - let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { - (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 6, dst.octets(), src.octets()); - let payload_len = - self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len()); - payload.truncate(payload_len as usize); - ip_h.payload_len = payload.len() as u16 + tcp_header.header_len(); - ip_h.dont_fragment = true; - etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()) - } - (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { - let mut ip_h = etherparse::Ipv6Header { - traffic_class: 0, - flow_label: 0, - payload_length: 0, - next_header: 6, - hop_limit: ttl, - source: dst.octets(), - destination: src.octets(), - }; - let payload_len = - self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len()); - payload.truncate(payload_len as usize); - ip_h.payload_length = payload.len() as u16 + tcp_header.header_len(); - - etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()) - } - _ => unreachable!(), - }; - - match ip_header { - etherparse::IpHeader::Version4(ref ip_header, _) => { - tcp_header.checksum = tcp_header - .calc_checksum_ipv4(ip_header, &payload) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; - } - etherparse::IpHeader::Version6(ref ip_header, _) => { - tcp_header.checksum = tcp_header - .calc_checksum_ipv6(ip_header, &payload) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; - } - } - Ok(NetworkPacket { - ip: ip_header, - transport: TransportHeader::Tcp(tcp_header), - payload, - }) - } - pub fn local_addr(&self) -> SocketAddr { - self.src_addr - } - pub fn peer_addr(&self) -> SocketAddr { - self.dst_addr - } -} - -impl AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - loop { - if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) { - self.packet_to_send = - Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); - self.tcb.change_state(TcpState::Closed); - return std::task::Poll::Ready(Ok(())); - } - let min = cmp::min(self.tcb.get_available_read_buffer_size() as u16, u16::MAX); - self.tcb.change_recv_window(min); - if matches!( - Pin::new(&mut self.tcb.timeout).poll(cx), - std::task::Poll::Ready(_) - ) { - #[cfg(feature = "log")] - trace!("timeout reached for {:?}", self.dst_addr); - self.packet_sender - .send(self.create_rev_packet( - tcp_flags::RST | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?) - .map_err(|_| ErrorKind::UnexpectedEof)?; - return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); - } - - if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::SYN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); - self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::SynReceived(true)); - } - - if let Some(packet) = self.packet_to_send.take() { - self.packet_sender - .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; - if matches!(self.tcb.get_state(), TcpState::Closed) { - if let Some(shutdown) = self.shutdown.take() { - shutdown.notify_one(); - } - return std::task::Poll::Ready(Ok(())); - } - } - if let Some(b) = self.tcb.get_unordered_packets() { - self.tcb.add_ack(b.len() as u32); - buf.put_slice(&b); - self.packet_sender - .send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; - return std::task::Poll::Ready(Ok(())); - } - if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) { - self.tcb.change_state(TcpState::FinWait1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); - continue; - } - match self.stream_receiver.poll_recv(cx) { - std::task::Poll::Ready(Some(p)) => { - let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else { - unreachable!() - }; - if t.flags() & tcp_flags::RST != 0 { - self.packet_to_send = - Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); - self.tcb.change_state(TcpState::Closed); - return std::task::Poll::Ready(Err(Error::from( - ErrorKind::ConnectionReset, - ))); - } - if matches!( - self.tcb.check_pkt_type(&t, &p.payload), - PacketStatus::Invalid - ) { - continue; - } - - if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) { - if t.flags() == tcp_flags::ACK { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.change_send_window(t.inner().window_size); - self.tcb.change_state(TcpState::Established); - } - } else if matches!(self.tcb.get_state(), TcpState::Established) { - if t.flags() == tcp_flags::ACK { - match self.tcb.check_pkt_type(&t, &p.payload) { - PacketStatus::WindowUpdate => { - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; - continue; - } - PacketStatus::Invalid => continue, - PacketStatus::KeepAlive => { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.change_send_window(t.inner().window_size); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); - continue; - } - PacketStatus::RetransmissionRequest => { - self.tcb.change_send_window(t.inner().window_size); - self.tcb.retransmission = Some(t.inner().acknowledgment_number); - if matches!( - self.as_mut().poll_flush(cx), - std::task::Poll::Pending - ) { - return std::task::Poll::Pending; - } - continue; - } - PacketStatus::NewPacket => { - // if t.inner().sequence_number != self.tcb.get_ack() { - // dbg!(t.inner().sequence_number); - // self.packet_to_send = Some(self.create_rev_packet( - // tcp_flags::ACK, - // TTL, - // None, - // Vec::new(), - // )?); - // continue; - // } - - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.add_unordered_packet( - t.inner().sequence_number, - &p.payload, - ); - // buf.put_slice(&p.payload); - // self.tcb.add_ack(p.payload.len() as u32); - // self.packet_to_send = Some(self.create_rev_packet( - // tcp_flags::ACK, - // TTL, - // None, - // Vec::new(), - // )?); - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; - continue; - // return std::task::Poll::Ready(Ok(())); - } - PacketStatus::Ack => { - self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb.change_send_window(t.inner().window_size); - if let Some(ref n) = self.write_notify { - n.wake_by_ref(); - self.write_notify = None; - }; - continue; - } - }; - } - if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { - self.tcb.add_ack(1); - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::FIN | tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); - self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::FinWait2(true)); - continue; - } - if t.flags() == (tcp_flags::PSH | tcp_flags::ACK) { - if !matches!( - self.tcb.check_pkt_type(&t, &p.payload), - PacketStatus::NewPacket - ) { - continue; - } - self.tcb.change_last_ack(t.inner().acknowledgment_number); - - if p.payload.is_empty() - || self.tcb.get_ack() != t.inner().sequence_number - { - continue; - } - - // self.tcb.add_ack(p.payload.len() as u32); - self.tcb.change_send_window(t.inner().window_size); - // buf.put_slice(&p.payload); - // self.packet_to_send = Some(self.create_rev_packet( - // tcp_flags::ACK, - // TTL, - // None, - // Vec::new(), - // )?); - // return std::task::Poll::Ready(Ok(())); - self.tcb - .add_unordered_packet(t.inner().sequence_number, &p.payload); - continue; - } - } else if matches!(self.tcb.get_state(), TcpState::FinWait1) { - if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { - self.packet_to_send = Some(self.create_rev_packet( - tcp_flags::ACK, - TTL, - None, - Vec::new(), - )?); - self.tcb.change_send_window(t.inner().window_size); - self.tcb.add_seq_one(); - self.tcb.change_state(TcpState::FinWait2(false)); - continue; - } - } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) - && t.flags() == tcp_flags::ACK - { - self.tcb.change_state(TcpState::FinWait2(false)); - } - } - std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())), - std::task::Poll::Pending => return std::task::Poll::Pending, - } - } - } -} - -impl AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2 - || self.tcb.is_send_buffer_full() - { - self.write_notify = Some(cx.waker().clone()); - return std::task::Poll::Pending; - } - - if self.tcb.retransmission.is_some() { - self.write_notify = Some(cx.waker().clone()); - if matches!(self.as_mut().poll_flush(cx), std::task::Poll::Pending) { - return std::task::Poll::Pending; - } - } - - let packet = - self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?; - let seq = self.tcb.seq; - let payload_len = packet.payload.len(); - let payload = packet.payload.clone(); - - self.packet_sender - .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; - self.tcb.add_inflight_packet(seq, &payload); - - std::task::Poll::Ready(Ok(payload_len)) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if let Some(i) = self - .tcb - .retransmission - .and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s)) - .and_then(|p| self.tcb.inflight_packets.get(p)) - { - let packet = self.create_rev_packet( - tcp_flags::PSH | tcp_flags::ACK, - TTL, - Some(i.seq), - i.payload.to_vec(), - )?; - - self.packet_sender - .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; - self.tcb.retransmission = None; - } else if let Some(_i) = self.tcb.retransmission { - #[cfg(feature = "log")] - { - warn!(_i); - warn!(self.tcb.seq); - warn!(self.tcb.last_ack); - warn!(self.tcb.ack); - for p in self.tcb.inflight_packets.iter() { - warn!(p.seq); - warn!("{}", p.payload.len()); - } - } - panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); - } - std::task::Poll::Ready(Ok(())) - } - - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let notified = self.shutdown.get_or_insert(Notify::new()).notified(); - match Pin::new(&mut Box::pin(notified)).poll(cx) { - std::task::Poll::Ready(_) => std::task::Poll::Ready(Ok(())), - std::task::Poll::Pending => std::task::Poll::Pending, - } - } -} - -impl Drop for IpStackTcpStream { - fn drop(&mut self) { - if let Ok(p) = self.create_rev_packet(0, DROP_TTL, None, Vec::new()) { - _ = self.packet_sender.send(p); - } - } -} diff --git a/libs/ipstack/src/stream/udp.rs b/libs/ipstack/src/stream/udp.rs deleted file mode 100644 index bf7a14d..0000000 --- a/libs/ipstack/src/stream/udp.rs +++ /dev/null @@ -1,181 +0,0 @@ -use core::task; -use std::{ - future::Future, - io::{self, Error, ErrorKind}, - net::SocketAddr, - pin::Pin, - task::Poll, - time::Duration, -}; - -use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, - time::Sleep, -}; -// use crate::packet::TransportHeader; -use crate::{ - packet::{NetworkPacket, TransportHeader}, - TTL, -}; - -pub struct IpStackUdpStream { - src_addr: SocketAddr, - dst_addr: SocketAddr, - stream_sender: UnboundedSender, - stream_receiver: UnboundedReceiver, - packet_sender: UnboundedSender, - first_paload: Option>, - timeout: Pin>, - udp_timeout: Duration, - mtu: u16, -} - -impl IpStackUdpStream { - pub fn new( - src_addr: SocketAddr, - dst_addr: SocketAddr, - payload: Vec, - pkt_sender: UnboundedSender, - mtu: u16, - udp_timeout: Duration, - ) -> Self { - let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - IpStackUdpStream { - src_addr, - dst_addr, - stream_sender, - stream_receiver, - packet_sender: pkt_sender.clone(), - first_paload: Some(payload), - timeout: Box::pin(tokio::time::sleep_until( - tokio::time::Instant::now() + udp_timeout, - )), - udp_timeout, - mtu, - } - } - pub(crate) fn stream_sender(&self) -> UnboundedSender { - self.stream_sender.clone() - } - fn create_rev_packet(&self, ttl: u8, mut payload: Vec) -> Result { - match (self.dst_addr.ip(), self.src_addr.ip()) { - (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, 17, dst.octets(), src.octets()); - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size - payload.truncate(line_buffer as usize); - ip_h.payload_len = payload.len() as u16 + 8; // 8 is udp header size - let udp_header = UdpHeader::with_ipv4_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; - Ok(NetworkPacket { - ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()), - transport: TransportHeader::Udp(udp_header), - payload, - }) - } - (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { - let mut ip_h = Ipv6Header { - traffic_class: 0, - flow_label: 0, - payload_length: 0, - next_header: 17, - hop_limit: ttl, - source: dst.octets(), - destination: src.octets(), - }; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size - - payload.truncate(line_buffer as usize); - - ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size - let udp_header = UdpHeader::with_ipv6_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; - Ok(NetworkPacket { - ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()), - transport: TransportHeader::Udp(udp_header), - payload, - }) - } - _ => unreachable!(), - } - } - pub fn local_addr(&self) -> SocketAddr { - self.src_addr - } - pub fn peer_addr(&self) -> SocketAddr { - self.dst_addr - } -} - -impl AsyncRead for IpStackUdpStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> task::Poll> { - if let Some(p) = self.first_paload.take() { - buf.put_slice(&p); - return Poll::Ready(Ok(())); - } - if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) { - return Poll::Ready(Ok(())); // todo: return timeout error - } - - let udp_timeout = self.udp_timeout; - match self.stream_receiver.poll_recv(cx) { - Poll::Ready(Some(p)) => { - buf.put_slice(&p.payload); - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + udp_timeout); - Poll::Ready(Ok(())) - } - Poll::Ready(None) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, - } - } -} - -impl AsyncWrite for IpStackUdpStream { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - buf: &[u8], - ) -> task::Poll> { - let udp_timeout = self.udp_timeout; - self.timeout - .as_mut() - .reset(tokio::time::Instant::now() + udp_timeout); - let packet = self.create_rev_packet(TTL, buf.to_vec())?; - let payload_len = packet.payload.len(); - self.packet_sender - .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; - std::task::Poll::Ready(Ok(payload_len)) - } - - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> task::Poll> { - Poll::Ready(Ok(())) - } -} diff --git a/libs/ipstack/src/stream/unknown.rs b/libs/ipstack/src/stream/unknown.rs deleted file mode 100644 index 4f5fae3..0000000 --- a/libs/ipstack/src/stream/unknown.rs +++ /dev/null @@ -1,111 +0,0 @@ -use std::{io::Error, mem, net::IpAddr}; - -use etherparse::{IpHeader, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header}; -use tokio::sync::mpsc::UnboundedSender; - -use crate::{ - packet::{NetworkPacket, TransportHeader}, - TTL, -}; - -pub struct IpStackUnknownTransport { - src_addr: IpAddr, - dst_addr: IpAddr, - payload: Vec, - protocol: u8, - mtu: u16, - packet_sender: UnboundedSender, -} - -impl IpStackUnknownTransport { - pub fn new( - src_addr: IpAddr, - dst_addr: IpAddr, - payload: Vec, - ip: &IpHeader, - mtu: u16, - packet_sender: UnboundedSender, - ) -> Self { - let protocol = match ip { - IpHeader::Version4(ip, _) => ip.protocol, - IpHeader::Version6(ip, _) => ip.next_header, - }; - IpStackUnknownTransport { - src_addr, - dst_addr, - payload, - protocol, - mtu, - packet_sender, - } - } - pub fn src_addr(&self) -> IpAddr { - self.src_addr - } - pub fn dst_addr(&self) -> IpAddr { - self.dst_addr - } - pub fn payload(&self) -> &[u8] { - &self.payload - } - pub fn ip_protocol(&self) -> u8 { - self.protocol - } - pub async fn send(&self, mut payload: Vec) -> Result<(), Error> { - loop { - let packet = self.create_rev_packet(&mut payload)?; - self.packet_sender - .send(packet) - .map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?; - if payload.is_empty() { - return Ok(()); - } - } - } - - pub fn create_rev_packet(&self, payload: &mut Vec) -> Result { - match (self.dst_addr, self.src_addr) { - (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()); - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); - - let p = if payload.len() > line_buffer as usize { - payload.drain(0..line_buffer as usize).collect::>() - } else { - mem::take(payload) - }; - ip_h.payload_len = p.len() as u16; - Ok(NetworkPacket { - ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()), - transport: TransportHeader::Unknown, - payload: p, - }) - } - (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { - let mut ip_h = Ipv6Header { - traffic_class: 0, - flow_label: 0, - payload_length: 0, - next_header: 17, - hop_limit: TTL, - source: dst.octets(), - destination: src.octets(), - }; - let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); - payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16; - let p = if payload.len() > line_buffer as usize { - payload.drain(0..line_buffer as usize).collect::>() - } else { - mem::take(payload) - }; - Ok(NetworkPacket { - ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()), - transport: TransportHeader::Unknown, - payload: p, - }) - } - _ => unreachable!(), - } - } -} diff --git a/network/Cargo.toml b/network/Cargo.toml index 6528bea..b715386 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -16,14 +16,12 @@ futures = { workspace = true } libc = { workspace = true } udp-stream = { workspace = true } smoltcp = { workspace = true } +etherparse = { workspace = true } +async-trait = { workspace = true } [dependencies.advmac] path = "../libs/advmac" -[dependencies.ipstack] -path = "../libs/ipstack" -features = ["log"] - [lib] path = "src/lib.rs" diff --git a/network/src/backend.rs b/network/src/backend.rs index 3e17a4b..d41909b 100644 --- a/network/src/backend.rs +++ b/network/src/backend.rs @@ -1,36 +1,75 @@ -use crate::raw_socket::{AsyncRawSocket, RawSocket}; +use crate::chandev::ChannelDevice; +use crate::nat::NatRouter; +use crate::proxynat::ProxyNatHandlerFactory; +use crate::raw_socket::AsyncRawSocket; use advmac::MacAddr6; use anyhow::{anyhow, Result}; -use futures::channel::oneshot; -use futures::{try_join, TryStreamExt}; -use ipstack::stream::IpStackStream; -use log::{debug, warn}; +use futures::TryStreamExt; +use log::warn; use smoltcp::iface::{Config, Interface, SocketSet}; use smoltcp::time::Instant; use smoltcp::wire::{HardwareAddress, IpCidr}; -use std::os::fd::AsRawFd; use std::str::FromStr; -use std::thread; use std::time::Duration; -use tokio::net::TcpStream; -use udp_stream::UdpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::select; +use tokio::sync::mpsc::{channel, Receiver}; -pub trait NetworkSlice { - async fn run(&self) -> Result<()>; +#[derive(Clone)] +pub struct NetworkBackend { + network: String, + interface: String, } -pub struct NetworkBackend { - pub interface: String, - local: LocalNetworkSlice, - internet: InternetNetworkSlice, +enum NetworkStackSelect<'a> { + Receive(&'a [u8]), + Send(Option>), +} + +struct NetworkStack<'a> { + tx: Receiver>, + kdev: AsyncRawSocket, + udev: ChannelDevice, + interface: Interface, + sockets: SocketSet<'a>, + router: NatRouter, +} + +impl NetworkStack<'_> { + async fn poll(&mut self, receive_buffer: &mut [u8]) -> Result<()> { + let what = select! { + x = self.tx.recv() => NetworkStackSelect::Send(x), + x = self.kdev.read(receive_buffer) => NetworkStackSelect::Receive(&receive_buffer[0..x?]), + }; + + match what { + NetworkStackSelect::Send(packet) => { + if let Some(packet) = packet { + self.kdev.write_all(&packet).await? + } + } + + NetworkStackSelect::Receive(packet) => { + if let Err(error) = self.router.process(packet).await { + warn!("router failed to process packet: {}", error); + } + + self.udev.rx = Some(packet.to_vec()); + let timestamp = Instant::now(); + self.interface + .poll(timestamp, &mut self.udev, &mut self.sockets); + } + } + + Ok(()) + } } impl NetworkBackend { pub fn new(network: &str, interface: &str) -> Result { Ok(Self { + network: network.to_string(), interface: interface.to_string(), - local: LocalNetworkSlice::new(network, interface)?, - internet: InternetNetworkSlice::new(interface)?, }) } @@ -56,116 +95,41 @@ impl NetworkBackend { Ok(()) } - pub async fn run(&mut self) -> Result<()> { - try_join!(self.local.run(), self.internet.run()).map(|_| ()) - } -} - -#[derive(Clone)] -struct LocalNetworkSlice { - network: String, - interface: String, -} - -impl LocalNetworkSlice { - fn new(network: &str, interface: &str) -> Result { - Ok(Self { - network: network.to_string(), - interface: interface.to_string(), - }) + pub async fn run(&self) -> Result<()> { + let mut stack = self.create_network_stack()?; + let mut buffer = vec![0u8; 1500]; + loop { + stack.poll(&mut buffer).await?; + } } - fn run_blocking(&self) -> Result<()> { + fn create_network_stack(&self) -> Result { + let proxy = Box::new(ProxyNatHandlerFactory::new()); let address = IpCidr::from_str(&self.network) .map_err(|_| anyhow!("failed to parse cidr: {}", self.network))?; let addresses: Vec = vec![address]; - let mut socket = RawSocket::new(&self.interface)?; + let kdev = AsyncRawSocket::bind(&self.interface)?; + let (sender, receiver) = channel::>(4); + let mut udev = ChannelDevice::new(1500, sender); let mac = MacAddr6::random(); - let mac = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress(mac.to_array())); + let mac = smoltcp::wire::EthernetAddress(mac.to_array()); + let nat = NatRouter::new(proxy, mac); + let mac = HardwareAddress::Ethernet(mac); let config = Config::new(mac); - let mut iface = Interface::new(config, &mut socket, Instant::now()); + let mut iface = Interface::new(config, &mut udev, Instant::now()); iface.update_ip_addrs(|addrs| { addrs .extend_from_slice(&addresses) .expect("failed to set ip addresses"); }); - - let mut sockets = SocketSet::new(vec![]); - let fd = socket.as_raw_fd(); - loop { - let timestamp = Instant::now(); - iface.poll(timestamp, &mut socket, &mut sockets); - smoltcp::phy::wait(fd, iface.poll_delay(timestamp, &sockets))?; - } - } -} - -impl NetworkSlice for LocalNetworkSlice { - async fn run(&self) -> Result<()> { - let (tx, rx) = oneshot::channel(); - let me = self.clone(); - thread::spawn(move || { - let _ = tx.send(me.run_blocking()); - }); - rx.await? - } -} - -struct InternetNetworkSlice { - interface: String, -} - -impl InternetNetworkSlice { - pub fn new(interface: &str) -> Result { - Ok(Self { - interface: interface.to_string(), + let sockets = SocketSet::new(vec![]); + Ok(NetworkStack { + tx: receiver, + kdev, + udev, + interface: iface, + sockets, + router: nat, }) } - - async fn process_stream(stream: IpStackStream) { - match stream { - IpStackStream::Tcp(mut tcp) => { - debug!("tcp: {}", tcp.peer_addr()); - if let Ok(mut stream) = TcpStream::connect(tcp.peer_addr()).await { - let _ = tokio::io::copy_bidirectional(&mut tcp, &mut stream).await; - } else { - warn!("failed to connect to tcp address: {}", tcp.peer_addr()); - } - } - - IpStackStream::Udp(mut udp) => { - debug!("udp: {}", udp.peer_addr()); - if let Ok(mut stream) = UdpStream::connect(udp.peer_addr()).await { - let _ = tokio::io::copy_bidirectional(&mut stream, &mut udp).await; - } else { - warn!("failed to connect to udp address: {}", udp.peer_addr()); - } - } - - IpStackStream::UnknownTransport(u) => { - debug!("unknown transport: {}", u.dst_addr()); - } - - IpStackStream::UnknownNetwork(packet) => { - debug!("unknown network: {:?}", packet); - } - } - } -} - -impl NetworkSlice for InternetNetworkSlice { - async fn run(&self) -> Result<()> { - let mut config = ipstack::IpStackConfig::default(); - config.mtu(1500); - config.tcp_timeout(std::time::Duration::from_secs(60)); - config.udp_timeout(std::time::Duration::from_secs(10)); - - let socket = AsyncRawSocket::bind(&self.interface)?; - let mut stack = ipstack::IpStack::new(config, socket); - - while let Ok(stream) = stack.accept().await { - tokio::spawn(InternetNetworkSlice::process_stream(stream)); - } - Ok(()) - } } diff --git a/network/src/chandev.rs b/network/src/chandev.rs new file mode 100644 index 0000000..1aac5ab --- /dev/null +++ b/network/src/chandev.rs @@ -0,0 +1,73 @@ +use log::warn; +// Referenced https://github.com/vi/wgslirpy/blob/master/crates/libwgslirpy/src/channelized_smoltcp_device.rs +use smoltcp::phy::{Checksum, Device}; +use tokio::sync::mpsc::Sender; + +pub struct ChannelDevice { + pub mtu: usize, + pub tx: Sender>, + pub rx: Option>, +} + +impl ChannelDevice { + pub fn new(mtu: usize, tx: Sender>) -> Self { + Self { mtu, tx, rx: None } + } +} + +pub struct RxToken(pub Vec); + +impl Device for ChannelDevice { + type RxToken<'a> = RxToken where Self: 'a; + type TxToken<'a> = &'a mut ChannelDevice where Self: 'a; + + fn receive( + &mut self, + _timestamp: smoltcp::time::Instant, + ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + self.rx.take().map(|x| (RxToken(x), self)) + } + + fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { + if self.tx.capacity() == 0 { + warn!("ran out of transmission capacity"); + return None; + } + Some(self) + } + + fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities { + let mut capabilities = smoltcp::phy::DeviceCapabilities::default(); + capabilities.medium = smoltcp::phy::Medium::Ethernet; + capabilities.max_transmission_unit = self.mtu; + capabilities.checksum = smoltcp::phy::ChecksumCapabilities::ignored(); + capabilities.checksum.tcp = Checksum::Tx; + capabilities.checksum.ipv4 = Checksum::Tx; + capabilities.checksum.icmpv4 = Checksum::Tx; + capabilities.checksum.icmpv6 = Checksum::Tx; + capabilities + } +} + +impl smoltcp::phy::RxToken for RxToken { + fn consume(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(&mut self.0[..]) + } +} + +impl<'a> smoltcp::phy::TxToken for &'a mut ChannelDevice { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buffer = vec![0u8; len]; + let result = f(&mut buffer[..]); + if let Err(error) = self.tx.try_send(buffer) { + warn!("failed to transmit packet: {}", error); + } + result + } +} diff --git a/network/src/lib.rs b/network/src/lib.rs index ec33f35..00e9c9c 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -9,6 +9,9 @@ use tokio::time::sleep; use crate::backend::NetworkBackend; mod backend; +mod chandev; +mod nat; +mod proxynat; mod raw_socket; pub struct NetworkService { diff --git a/network/src/nat.rs b/network/src/nat.rs new file mode 100644 index 0000000..32be5fc --- /dev/null +++ b/network/src/nat.rs @@ -0,0 +1,189 @@ +// Referenced https://github.com/vi/wgslirpy/blob/master/crates/libwgslirpy/src/router.rs as a very interesting way to implement NAT. +// hypha will heavily change how the original code functions however. NatKey was a very useful example of what we need to store in a NAT map. + +use anyhow::Result; +use async_trait::async_trait; +use etherparse::IpNumber; +use etherparse::IpPayloadSlice; +use etherparse::Ipv4Slice; +use etherparse::LinkSlice; +use etherparse::NetSlice; +use etherparse::SlicedPacket; +use etherparse::TcpHeaderSlice; +use etherparse::UdpHeaderSlice; +use smoltcp::wire::EthernetAddress; +use smoltcp::wire::IpAddress; +use smoltcp::wire::IpEndpoint; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::fmt::Display; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub enum NatKey { + Tcp { + client: IpEndpoint, + external: IpEndpoint, + }, + + Udp { + client: IpEndpoint, + external: IpEndpoint, + }, + + Ping { + client: IpAddress, + external: IpAddress, + }, +} + +impl Display for NatKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NatKey::Tcp { client, external } => write!(f, "TCP {client} -> {external}"), + NatKey::Udp { client, external } => write!(f, "UDP {client} -> {external}"), + NatKey::Ping { client, external } => write!(f, "Ping {client} -> {external}"), + } + } +} + +#[async_trait] +pub trait NatHandler: Send { + async fn receive(&self, packet: &[u8]) -> Result<()>; +} + +pub struct NatTable { + inner: HashMap>, +} + +impl NatTable { + pub fn new() -> Self { + Self { + inner: HashMap::new(), + } + } +} + +#[async_trait] +pub trait NatHandlerFactory: Send { + async fn nat(&self, key: NatKey) -> Option>; +} + +pub struct NatRouter { + _mac: EthernetAddress, + factory: Box, + table: NatTable, +} + +impl NatRouter { + pub fn new(factory: Box, mac: EthernetAddress) -> Self { + Self { + _mac: mac, + factory, + table: NatTable::new(), + } + } + + pub async fn process(&mut self, data: &[u8]) -> Result<()> { + let packet = SlicedPacket::from_ethernet(data)?; + let Some(ref link) = packet.link else { + return Ok(()); + }; + + let LinkSlice::Ethernet2(ref ether) = link else { + return Ok(()); + }; + + let _mac = EthernetAddress(ether.destination()); + + let Some(ref net) = packet.net else { + return Ok(()); + }; + + match net { + NetSlice::Ipv4(ipv4) => { + self.process_ipv4(data, ipv4).await?; + } + _ => { + return Ok(()); + } + } + + Ok(()) + } + + pub async fn process_ipv4<'a>(&mut self, data: &[u8], ipv4: &Ipv4Slice<'a>) -> Result<()> { + let source_addr = IpAddress::Ipv4(ipv4.header().source_addr().into()); + let dest_addr = IpAddress::Ipv4(ipv4.header().destination_addr().into()); + + match ipv4.header().protocol() { + IpNumber::TCP => { + self.process_tcp(data, source_addr, dest_addr, ipv4.payload()) + .await?; + } + + IpNumber::UDP => { + self.process_udp(data, source_addr, dest_addr, ipv4.payload()) + .await?; + } + + _ => {} + } + + Ok(()) + } + + pub async fn process_tcp<'a>( + &mut self, + data: &'a [u8], + source_addr: IpAddress, + dest_addr: IpAddress, + payload: &IpPayloadSlice<'a>, + ) -> Result<()> { + let header = TcpHeaderSlice::from_slice(payload.payload)?; + let source = IpEndpoint::new(source_addr, header.source_port()); + let dest = IpEndpoint::new(dest_addr, header.destination_port()); + let key = NatKey::Tcp { + client: source, + external: dest, + }; + self.process_nat(data, key).await?; + Ok(()) + } + + pub async fn process_udp<'a>( + &mut self, + data: &'a [u8], + source_addr: IpAddress, + dest_addr: IpAddress, + payload: &IpPayloadSlice<'a>, + ) -> Result<()> { + let header = UdpHeaderSlice::from_slice(payload.payload)?; + let source = IpEndpoint::new(source_addr, header.source_port()); + let dest = IpEndpoint::new(dest_addr, header.destination_port()); + let key = NatKey::Udp { + client: source, + external: dest, + }; + self.process_nat(data, key).await?; + Ok(()) + } + + pub async fn process_nat(&mut self, data: &[u8], key: NatKey) -> Result<()> { + let handler: Option<&mut Box> = match self.table.inner.entry(key) { + Entry::Occupied(entry) => Some(entry.into_mut()), + Entry::Vacant(entry) => { + if let Some(handler) = self.factory.nat(key).await { + Some(entry.insert(handler)) + } else { + None + } + } + }; + + if let Some(handler) = handler { + handler.receive(data).await?; + } + + Ok(()) + } +} diff --git a/network/src/proxynat.rs b/network/src/proxynat.rs new file mode 100644 index 0000000..44a30ca --- /dev/null +++ b/network/src/proxynat.rs @@ -0,0 +1,135 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use etherparse::{SlicedPacket, UdpSlice}; +use log::{debug, warn}; +use smoltcp::{ + phy::{Checksum, ChecksumCapabilities}, + wire::{IpAddress, IpEndpoint}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + select, + sync::mpsc::channel, +}; +use tokio::{sync::mpsc::Receiver, sync::mpsc::Sender}; +use udp_stream::UdpStream; + +use crate::nat::{NatHandler, NatHandlerFactory, NatKey}; + +pub struct ProxyNatHandlerFactory {} + +struct ProxyUdpHandler { + external: IpEndpoint, + sender: Sender>, +} + +impl ProxyNatHandlerFactory { + pub fn new() -> Self { + Self {} + } +} + +#[async_trait] +impl NatHandlerFactory for ProxyNatHandlerFactory { + async fn nat(&self, key: NatKey) -> Option> { + debug!("creating proxy nat entry for key: {}", key); + + match key { + NatKey::Udp { + client: _, + external, + } => { + let (sender, receiver) = channel::>(4); + let mut handler = ProxyUdpHandler { external, sender }; + + if let Err(error) = handler.spawn(receiver).await { + warn!("unable to spawn udp proxy handler: {}", error); + None + } else { + Some(Box::new(handler)) + } + } + + _ => None, + } + } +} + +#[async_trait] +impl NatHandler for ProxyUdpHandler { + async fn receive(&self, data: &[u8]) -> Result<()> { + self.sender.try_send(data.to_vec())?; + Ok(()) + } +} + +enum ProxySelect { + External(usize), + Internal(Vec), + Closed, +} + +impl ProxyUdpHandler { + async fn spawn(&mut self, receiver: Receiver>) -> Result<()> { + let external_addr = match self.external.addr { + IpAddress::Ipv4(addr) => SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(addr.0[0], addr.0[1], addr.0[2], addr.0[3])), + self.external.port, + ), + IpAddress::Ipv6(_) => return Err(anyhow!("IPv6 unsupported")), + }; + + let socket = UdpStream::connect(external_addr).await?; + tokio::spawn(async move { + if let Err(error) = ProxyUdpHandler::process(socket, receiver).await { + warn!("processing of udp proxy failed: {}", error); + } + }); + Ok(()) + } + + async fn process(mut socket: UdpStream, mut receiver: Receiver>) -> Result<()> { + let mut checksum = ChecksumCapabilities::ignored(); + checksum.udp = Checksum::Tx; + checksum.ipv4 = Checksum::Tx; + checksum.tcp = Checksum::Tx; + + let mut external_buffer = vec![0u8; 2048]; + + loop { + let selection = select! { + x = receiver.recv() => if let Some(data) = x { + ProxySelect::Internal(data) + } else { + ProxySelect::Closed + }, + x = socket.read(&mut external_buffer) => ProxySelect::External(x?), + }; + + match selection { + ProxySelect::External(size) => { + let data = &external_buffer[0..size]; + debug!("UDP from external: {:?}", data); + } + ProxySelect::Internal(data) => { + debug!("udp socket to handle data: {:?}", data); + let packet = SlicedPacket::from_ethernet(&data)?; + let Some(ref net) = packet.net else { + continue; + }; + + let Some(ip) = net.ip_payload_ref() else { + continue; + }; + + let udp = UdpSlice::from_slice(ip.payload)?; + debug!("UDP from internal: {:?}", udp.payload()); + socket.write_all(udp.payload()).await?; + } + ProxySelect::Closed => warn!("UDP socket closed"), + } + } + } +} diff --git a/network/src/raw_socket.rs b/network/src/raw_socket.rs index 09fabd1..92faa47 100644 --- a/network/src/raw_socket.rs +++ b/network/src/raw_socket.rs @@ -1,12 +1,7 @@ use anyhow::Result; use futures::ready; -use log::debug; -use smoltcp::phy::{Device, DeviceCapabilities, Medium}; -use smoltcp::time::Instant; -use std::cell::RefCell; use std::os::unix::io::{AsRawFd, RawFd}; use std::pin::Pin; -use std::rc::Rc; use std::task::{Context, Poll}; use std::{io, mem}; use tokio::io::unix::AsyncFd; @@ -121,107 +116,6 @@ impl Drop for RawSocketHandle { } } -#[derive(Debug)] -pub struct RawSocket { - lower: Rc>, - mtu: usize, -} - -impl AsRawFd for RawSocket { - fn as_raw_fd(&self) -> RawFd { - self.lower.borrow().as_raw_fd() - } -} - -impl RawSocket { - pub fn new(name: &str) -> io::Result { - let mut lower = RawSocketHandle::new(name)?; - lower.bind_interface()?; - let mtu = lower.mtu; - Ok(RawSocket { - lower: Rc::new(RefCell::new(lower)), - mtu, - }) - } -} - -impl Device for RawSocket { - type RxToken<'a> = RxToken - where - Self: 'a; - type TxToken<'a> = TxToken - where - Self: 'a; - - fn capabilities(&self) -> DeviceCapabilities { - let mut capabilities = DeviceCapabilities::default(); - capabilities.medium = Medium::Ethernet; - capabilities.max_transmission_unit = self.mtu; - capabilities - } - - fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let lower = self.lower.borrow_mut(); - let mut buffer = vec![0; self.mtu]; - match lower.recv(&mut buffer[..]) { - Ok(size) => { - buffer.resize(size, 0); - let rx = RxToken { buffer }; - let tx = TxToken { - lower: self.lower.clone(), - }; - Some((rx, tx)) - } - Err(err) if err.kind() == io::ErrorKind::WouldBlock => None, - Err(err) => panic!("{}", err), - } - } - - fn transmit(&mut self, _timestamp: Instant) -> Option> { - Some(TxToken { - lower: self.lower.clone(), - }) - } -} - -#[doc(hidden)] -pub struct RxToken { - buffer: Vec, -} - -impl smoltcp::phy::RxToken for RxToken { - fn consume(mut self, f: F) -> R - where - F: FnOnce(&mut [u8]) -> R, - { - f(&mut self.buffer[..]) - } -} - -#[doc(hidden)] -pub struct TxToken { - lower: Rc>, -} - -impl smoltcp::phy::TxToken for TxToken { - fn consume(self, len: usize, f: F) -> R - where - F: FnOnce(&mut [u8]) -> R, - { - let lower = self.lower.borrow_mut(); - let mut buffer = vec![0; len]; - let result = f(&mut buffer); - match lower.send(&buffer[..]) { - Ok(_) => {} - Err(err) if err.kind() == io::ErrorKind::WouldBlock => { - debug!("phy: tx failed due to WouldBlock") - } - Err(err) => panic!("{}", err), - } - result - } -} - #[repr(C)] #[derive(Debug)] struct Ifreq {