diff --git a/Cargo.toml b/Cargo.toml index 6b8941f..d28f024 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "libs/xen/xenclient", "libs/advmac", "libs/loopdev", + "libs/ipstack", "shared", "container", "network", @@ -49,7 +50,7 @@ rtnetlink = "0.14.1" netlink-packet-route = "0.19.0" futures = "0.3.30" ipnetwork = "0.20.0" -smoltcp = "0.11.0" +udp-stream = "0.0.11" [workspace.dependencies.uuid] version = "1.6.1" diff --git a/container/src/init.rs b/container/src/init.rs index 882006f..c0712e9 100644 --- a/container/src/init.rs +++ b/container/src/init.rs @@ -9,6 +9,7 @@ use oci_spec::image::{Config, ImageConfiguration}; use std::ffi::{CStr, CString}; use std::fs; use std::fs::{File, OpenOptions, Permissions}; +use std::net::Ipv4Addr; use std::os::fd::AsRawFd; use std::os::linux::fs::MetadataExt; use std::os::unix::fs::{chroot, PermissionsExt}; @@ -304,11 +305,25 @@ impl ContainerInit { .execute() .await?; - handle.link().set(link.header.index).up().execute().await?; + handle + .link() + .set(link.header.index) + .arp(false) + .up() + .execute() + .await?; + + handle + .route() + .add() + .v4() + .destination_prefix(Ipv4Addr::new(0, 0, 0, 0), 0) + .output_interface(link.header.index) + .execute() + .await?; } else { warn!("unable to find link named {}", network.link); } - Ok(()) } diff --git a/controller/src/ctl/mod.rs b/controller/src/ctl/mod.rs index fa1b4da..c254646 100644 --- a/controller/src/ctl/mod.rs +++ b/controller/src/ctl/mod.rs @@ -137,6 +137,7 @@ impl Controller { writable: false, }, ], + consoles: vec![], vifs: vec![DomainNetworkInterface { mac: &mac, mtu: 1500, diff --git a/libs/ipstack/Cargo.toml b/libs/ipstack/Cargo.toml new file mode 100644 index 0000000..e40df38 --- /dev/null +++ b/libs/ipstack/Cargo.toml @@ -0,0 +1,47 @@ +# This package is from https://github.com/narrowlink/ipstack +# Mycelium maintains an in-tree version because we need to work at the ethernet layer +# rather than the standard tun layer of IP. 
+[package] +authors = ['Narrowlink '] +description = 'Asynchronous lightweight implementation of TCP/IP stack for Tun device' +name = "ipstack" +version = "0.0.3" +edition = "2021" +license = "Apache-2.0" +repository = 'https://github.com/narrowlink/ipstack' +# homepage = 'https://github.com/narrowlink/ipstack' +readme = "README.md" + +[features] +default = [] +log = ["tracing/log"] + +[dependencies] +tokio = { version = "1.35", features = [ + "sync", + "rt", + "time", + "io-util", + "macros", +], default-features = false } +etherparse = { version = "0.13", default-features = false } +thiserror = { version = "1.0", default-features = false } +tracing = { version = "0.1", default-features = false, features = [ + "log", +] } # not optional: src/packet.rs imports tracing::debug unconditionally + +[dev-dependencies] +clap = { version = "4.4", features = ["derive"] } +udp-stream = { version = "0.0", default-features = false } +tokio = { version = "1.35", features = [ + "rt-multi-thread", +], default-features = false } + +#tun2.rs example +tun2 = { version = "1.0", features = ["async"] } + +#tun_wintun.rs example +[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] +tun = { version = "0.6.1", features = ["async"], default-features = false } +[target.'cfg(target_os = "windows")'.dev-dependencies] +wintun = { version = "0.4", default-features = false } diff --git a/libs/ipstack/LICENSE b/libs/ipstack/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/libs/ipstack/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/libs/ipstack/src/error.rs b/libs/ipstack/src/error.rs new file mode 100644 index 0000000..2b7f242 --- /dev/null +++ b/libs/ipstack/src/error.rs @@ -0,0 +1,30 @@ +#[allow(dead_code)] +#[derive(thiserror::Error, Debug)] +pub enum IpStackError { + #[error("The transport protocol is not supported")] + UnsupportedTransportProtocol, + #[error("The packet is invalid")] + InvalidPacket, + #[error("Write error: {0}")] + PacketWriteError(etherparse::WriteError), + #[error("Invalid Tcp packet")] + InvalidTcpPacket, + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + #[error("Accept Error")] + AcceptError, + + #[error("Send Error {0}")] + SendError(#[from] tokio::sync::mpsc::error::SendError), +} + +impl From for std::io::Error { + fn from(e: IpStackError) -> Self { + match e { + IpStackError::IoError(e) => e, + _ => std::io::Error::new(std::io::ErrorKind::Other, e), + } + } +} + +pub type Result = std::result::Result; diff --git a/libs/ipstack/src/lib.rs b/libs/ipstack/src/lib.rs new file mode 100644 index 0000000..9139097 --- /dev/null +++ b/libs/ipstack/src/lib.rs @@ -0,0 +1,194 @@ +pub use error::{IpStackError, Result}; +use packet::{NetworkPacket, NetworkTuple}; +use std::{ + collections::{ + hash_map::Entry::{Occupied, Vacant}, + HashMap, + }, + time::Duration, +}; +use stream::IpStackStream; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + select, + sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, +}; +#[cfg(feature = "log")] +use tracing::{error, trace}; + +use crate::{ + packet::IpStackPacketProtocol, + stream::{IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, +}; +mod error; +mod packet; +pub mod stream; + +const DROP_TTL: u8 = 0; + +#[cfg(unix)] +const TTL: u8 = 64; + +#[cfg(windows)] +const TTL: u8 = 128; + +#[cfg(unix)] +const TUN_FLAGS: [u8; 2] = [0x00, 0x00]; + +#[cfg(any(target_os = "linux", target_os = "android"))] +const TUN_PROTO_IP6: [u8; 2] = [0x86, 0xdd]; +#[cfg(any(target_os = 
"linux", target_os = "android"))] +const TUN_PROTO_IP4: [u8; 2] = [0x08, 0x00]; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x0A]; +#[cfg(any(target_os = "macos", target_os = "ios"))] +const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02]; + +pub struct IpStackConfig { + pub mtu: u16, + pub packet_information: bool, + pub tcp_timeout: Duration, + pub udp_timeout: Duration, +} + +impl Default for IpStackConfig { + fn default() -> Self { + IpStackConfig { + mtu: u16::MAX, + packet_information: false, + tcp_timeout: Duration::from_secs(60), + udp_timeout: Duration::from_secs(30), + } + } +} + +impl IpStackConfig { + pub fn tcp_timeout(&mut self, timeout: Duration) { + self.tcp_timeout = timeout; + } + pub fn udp_timeout(&mut self, timeout: Duration) { + self.udp_timeout = timeout; + } + pub fn mtu(&mut self, mtu: u16) { + self.mtu = mtu; + } + pub fn packet_information(&mut self, packet_information: bool) { + self.packet_information = packet_information; + } +} + +pub struct IpStack { + accept_receiver: UnboundedReceiver, +} + +impl IpStack { + pub fn new(config: IpStackConfig, mut device: D) -> IpStack + where + D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static, + { + let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); + + tokio::spawn(async move { + let mut streams: HashMap> = HashMap::new(); + let mut buffer = [0u8; u16::MAX as usize]; + + let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); + loop { + // dbg!(streams.len()); + select! 
{ + Ok(n) = device.read(&mut buffer) => { + let offset = usize::min(n, if config.packet_information && cfg!(unix) {4} else {0}); // clamp: a read shorter than the 4-byte packet-information prefix must not make the slices below panic + let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else { + accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?; + continue; + }; + if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { + accept_sender.send(IpStackStream::UnknownTransport(IpStackUnknownTransport::new(packet.src_addr().ip(),packet.dst_addr().ip(),packet.payload,&packet.ip,config.mtu,pkt_sender.clone())))?; + continue; + } + + match streams.entry(packet.network_tuple()){ + Occupied(entry) =>{ + // let t = packet.transport_protocol(); + if let Err(_x) = entry.get().send(packet){ + #[cfg(feature = "log")] + trace!("{}", _x); + // match t{ + // IpStackPacketProtocol::Tcp(_t) => { + // // dbg!(t.flags()); + // } + // IpStackPacketProtocol::Udp => { + // // dbg!("udp"); + // } + // IpStackPacketProtocol::Unknown => { + // // dbg!("unknown"); + // } + // } + + } + } + Vacant(entry) => { + match packet.transport_protocol(){ + IpStackPacketProtocol::Tcp(h) => { + match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{ + Ok(stream) => { + entry.insert(stream.stream_sender()); + accept_sender.send(IpStackStream::Tcp(stream))?; + } + Err(_e) => { + #[cfg(feature = "log")] + error!("{}", _e); + } + } + } + IpStackPacketProtocol::Udp => { + let stream = IpStackUdpStream::new(packet.src_addr(),packet.dst_addr(),packet.payload, pkt_sender.clone(),config.mtu,config.udp_timeout); + entry.insert(stream.stream_sender()); + accept_sender.send(IpStackStream::Udp(stream))?; + } + IpStackPacketProtocol::Unknown => { + unreachable!() + } + } + } + } + } + Some(packet) = pkt_receiver.recv() => { + if packet.ttl() == 0{ + streams.remove(&packet.reverse_network_tuple()); + continue; + } + #[allow(unused_mut)] + let Ok(mut packet_byte) = packet.to_bytes() else{ + #[cfg(feature = "log")] + 
trace!("to_bytes error"); + continue; + }; + #[cfg(unix)] + if config.packet_information { + if packet.src_addr().is_ipv4(){ + packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); + } else{ + packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); + } + } + device.write_all(&packet_byte).await?; + // device.flush().await.unwrap(); + } + } + } + #[allow(unreachable_code)] + Ok::<(), IpStackError>(()) + }); + + IpStack { accept_receiver } + } + pub async fn accept(&mut self) -> Result { + if let Some(s) = self.accept_receiver.recv().await { + Ok(s) + } else { + Err(IpStackError::AcceptError) + } + } +} diff --git a/libs/ipstack/src/packet.rs b/libs/ipstack/src/packet.rs new file mode 100644 index 0000000..509342f --- /dev/null +++ b/libs/ipstack/src/packet.rs @@ -0,0 +1,217 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use etherparse::{Ethernet2Header, IpHeader, PacketHeaders, TcpHeader, UdpHeader, WriteError}; +use tracing::debug; + +use crate::error::IpStackError; + +#[derive(Eq, Hash, PartialEq, Debug)] +pub struct NetworkTuple { + pub src: SocketAddr, + pub dst: SocketAddr, + pub tcp: bool, +} +pub mod tcp_flags { + pub const CWR: u8 = 0b10000000; + pub const ECE: u8 = 0b01000000; + pub const URG: u8 = 0b00100000; + pub const ACK: u8 = 0b00010000; + pub const PSH: u8 = 0b00001000; + pub const RST: u8 = 0b00000100; + pub const SYN: u8 = 0b00000010; + pub const FIN: u8 = 0b00000001; +} + +pub(crate) enum IpStackPacketProtocol { + Tcp(TcpPacket), + Unknown, + Udp, +} + +pub(crate) enum TransportHeader { + Tcp(TcpHeader), + Udp(UdpHeader), + Unknown, +} + +pub struct NetworkPacket { + pub(crate) ip: IpHeader, + pub(crate) transport: TransportHeader, + pub(crate) payload: Vec, +} + +impl NetworkPacket { + pub fn parse(buf: &[u8]) -> Result { + debug!("read: {:?}", buf); + let p = PacketHeaders::from_ethernet_slice(buf).map_err(|_| IpStackError::InvalidPacket)?; + let ip = p.ip.ok_or(IpStackError::InvalidPacket)?; + let transport = 
match p.transport { + Some(etherparse::TransportHeader::Tcp(h)) => TransportHeader::Tcp(h), + Some(etherparse::TransportHeader::Udp(u)) => TransportHeader::Udp(u), + _ => TransportHeader::Unknown, + }; + + let payload = if let TransportHeader::Unknown = transport { + buf[14..][ip.header_len()..].to_vec() // buf is a whole Ethernet frame: skip the 14-byte Ethernet II header before the IP header (NOTE(review): assumes an untagged frame, a VLAN tag would shift this further) + } else { + p.payload.to_vec() + }; + + Ok(NetworkPacket { + ip, + transport, + payload, + }) + } + pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol { + match self.transport { + TransportHeader::Udp(_) => IpStackPacketProtocol::Udp, + TransportHeader::Tcp(ref h) => IpStackPacketProtocol::Tcp(h.into()), + _ => IpStackPacketProtocol::Unknown, + } + } + pub fn src_addr(&self) -> SocketAddr { + let port = match &self.transport { + TransportHeader::Udp(udp) => udp.source_port, + TransportHeader::Tcp(tcp) => tcp.source_port, + _ => 0, + }; + match &self.ip { + IpHeader::Version4(ip, _) => { + SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.source)), port) + } + IpHeader::Version6(ip, _) => { + SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.source)), port) + } + } + } + pub fn dst_addr(&self) -> SocketAddr { + let port = match &self.transport { + TransportHeader::Udp(udp) => udp.destination_port, + TransportHeader::Tcp(tcp) => tcp.destination_port, + _ => 0, + }; + match &self.ip { + IpHeader::Version4(ip, _) => { + SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.destination)), port) + } + IpHeader::Version6(ip, _) => { + SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.destination)), port) + } + } + } + pub fn network_tuple(&self) -> NetworkTuple { + NetworkTuple { + src: self.src_addr(), + dst: self.dst_addr(), + tcp: matches!(self.transport, TransportHeader::Tcp(_)), + } + } + pub fn reverse_network_tuple(&self) -> NetworkTuple { + NetworkTuple { + src: self.dst_addr(), + dst: self.src_addr(), + tcp: matches!(self.transport, TransportHeader::Tcp(_)), + } + } + pub fn to_bytes(&self) -> Result, IpStackError> { + let mut buf = Vec::new(); + let header = Ethernet2Header { 
+ source: [255; 6], + destination: [255; 6], + ether_type: match self.ip { IpHeader::Version4(..) => 0x0800, IpHeader::Version6(..) => 0x86DD }, // EtherType has to match the IP version being written + }; + header.write(&mut buf).map_err(IpStackError::IoError)?; + self.ip + .write(&mut buf) + .map_err(IpStackError::PacketWriteError)?; + match self.transport { + TransportHeader::Tcp(ref h) => h + .write(&mut buf) + .map_err(WriteError::from) + .map_err(IpStackError::PacketWriteError)?, + TransportHeader::Udp(ref h) => { + h.write(&mut buf).map_err(IpStackError::PacketWriteError)? + } + _ => {} + }; + // self.transport + // .write(&mut buf) + // .map_err(IpStackError::PacketWriteError)?; + buf.extend_from_slice(&self.payload); + debug!("write: {:?}", buf); + Ok(buf) + } + pub fn ttl(&self) -> u8 { + match &self.ip { + IpHeader::Version4(ip, _) => ip.time_to_live, + IpHeader::Version6(ip, _) => ip.hop_limit, + } + } +} + +pub(super) struct TcpPacket { + header: TcpHeader, +} + +impl TcpPacket { + pub fn inner(&self) -> &TcpHeader { + &self.header + } + pub fn flags(&self) -> u8 { + let inner = self.inner(); + let mut flags = 0; + if inner.cwr { + flags |= tcp_flags::CWR; + } + if inner.ece { + flags |= tcp_flags::ECE; + } + if inner.urg { + flags |= tcp_flags::URG; + } + if inner.ack { + flags |= tcp_flags::ACK; + } + if inner.psh { + flags |= tcp_flags::PSH; + } + if inner.rst { + flags |= tcp_flags::RST; + } + if inner.syn { + flags |= tcp_flags::SYN; + } + if inner.fin { + flags |= tcp_flags::FIN; + } + + flags + } +} + +impl From<&TcpHeader> for TcpPacket { + fn from(header: &TcpHeader) -> Self { + TcpPacket { + header: header.clone(), + } + } +} + +// pub struct UdpPacket { +// header: UdpHeader, +// } + +// impl UdpPacket { +// pub fn inner(&self) -> &UdpHeader { +// &self.header +// } +// } + +// impl From<&UdpHeader> for UdpPacket { +// fn from(header: &UdpHeader) -> Self { +// UdpPacket { +// header: header.clone(), +// } +// } +// } diff --git a/libs/ipstack/src/stream/mod.rs b/libs/ipstack/src/stream/mod.rs new file mode 100644 index 0000000..9878f99 --- /dev/null +++ 
b/libs/ipstack/src/stream/mod.rs @@ -0,0 +1,46 @@ +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; + +pub use self::tcp::IpStackTcpStream; +pub use self::udp::IpStackUdpStream; +pub use self::unknown::IpStackUnknownTransport; + +mod tcb; +mod tcp; +mod udp; +mod unknown; + +pub enum IpStackStream { + Tcp(IpStackTcpStream), + Udp(IpStackUdpStream), + UnknownTransport(IpStackUnknownTransport), + UnknownNetwork(Vec), +} + +impl IpStackStream { + pub fn local_addr(&self) -> SocketAddr { + match self { + IpStackStream::Tcp(tcp) => tcp.local_addr(), + IpStackStream::Udp(udp) => udp.local_addr(), + IpStackStream::UnknownNetwork(_) => { + SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0)) + } + IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { + std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), + std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), + }, + } + } + pub fn peer_addr(&self) -> SocketAddr { + match self { + IpStackStream::Tcp(tcp) => tcp.peer_addr(), + IpStackStream::Udp(udp) => udp.peer_addr(), + IpStackStream::UnknownNetwork(_) => { + SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0)) + } + IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { + std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), + std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), + }, + } + } +} diff --git a/libs/ipstack/src/stream/tcb.rs b/libs/ipstack/src/stream/tcb.rs new file mode 100644 index 0000000..f3adbe2 --- /dev/null +++ b/libs/ipstack/src/stream/tcb.rs @@ -0,0 +1,234 @@ +use std::{ + collections::BTreeMap, + pin::Pin, + time::{Duration, SystemTime}, +}; + +use tokio::time::Sleep; + +use crate::packet::TcpPacket; + +const MAX_UNACK: u32 = 1024 * 16; // 16KB +const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB + +#[derive(Clone, Debug)] +pub enum TcpState { + SynReceived(bool), // bool means 
if syn/ack is sent + Established, + FinWait1, + FinWait2(bool), // bool means waiting for ack + Closed, +} +#[derive(Clone, Debug)] +pub(super) enum PacketStatus { + WindowUpdate, + Invalid, + RetransmissionRequest, + NewPacket, + Ack, + KeepAlive, +} + +pub(super) struct Tcb { + pub(super) seq: u32, + pub(super) retransmission: Option, + pub(super) ack: u32, + pub(super) last_ack: u32, + pub(super) timeout: Pin>, + tcp_timeout: Duration, + recv_window: u16, + pub(super) send_window: u16, + state: TcpState, + pub(super) avg_send_window: (u64, u64), + pub(super) inflight_packets: Vec, + pub(super) unordered_packets: BTreeMap, +} + +impl Tcb { + pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb { + let seq = 100; + Tcb { + seq, + retransmission: None, + ack, + last_ack: seq, + tcp_timeout, + timeout: Box::pin(tokio::time::sleep_until( + tokio::time::Instant::now() + tcp_timeout, + )), + send_window: u16::MAX, + recv_window: 0, + state: TcpState::SynReceived(false), + avg_send_window: (1, 1), + inflight_packets: Vec::new(), + unordered_packets: BTreeMap::new(), + } + } + pub(super) fn add_inflight_packet(&mut self, seq: u32, buf: &[u8]) { + self.inflight_packets + .push(InflightPacket::new(seq, buf.to_vec())); + self.seq = self.seq.wrapping_add(buf.len() as u32); + } + pub(super) fn add_unordered_packet(&mut self, seq: u32, buf: &[u8]) { + if (seq.wrapping_sub(self.ack) as i32) < 0 { // serial-number comparison: drop only segments that precede ack modulo 2^32 + return; + } + self.unordered_packets + .insert(seq, UnorderedPacket::new(buf.to_vec())); + } + pub(super) fn get_available_read_buffer_size(&self) -> usize { + READ_BUFFER_SIZE.saturating_sub( + self.unordered_packets + .iter() + .fold(0, |acc, (_, p)| acc + p.payload.len()), + ) + } + pub(super) fn get_unordered_packets(&mut self) -> Option> { + // dbg!(self.ack); + // for (seq,_) in self.unordered_packets.iter() { + // dbg!(seq); + // } + self.unordered_packets + .remove(&self.ack) + .map(|p| p.payload) + } + pub(super) fn add_seq_one(&mut self) { + self.seq = self.seq.wrapping_add(1); + } + 
pub(super) fn get_seq(&self) -> u32 { + self.seq + } + pub(super) fn add_ack(&mut self, add: u32) { + self.ack = self.ack.wrapping_add(add); + } + pub(super) fn get_ack(&self) -> u32 { + self.ack + } + pub(super) fn change_state(&mut self, state: TcpState) { + self.state = state; + } + pub(super) fn get_state(&self) -> &TcpState { + &self.state + } + pub(super) fn change_send_window(&mut self, window: u16) { + let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) + / (self.avg_send_window.1 + 1); + self.avg_send_window.0 = avg_send_window; + self.avg_send_window.1 += 1; + self.send_window = window; + } + pub(super) fn get_send_window(&self) -> u16 { + self.send_window + } + pub(super) fn change_recv_window(&mut self, window: u16) { + self.recv_window = window; + } + pub(super) fn get_recv_window(&self) -> u16 { + self.recv_window + } + // #[inline(always)] + // pub(super) fn buffer_size(&self, payload_len: u16) -> u16 { + // match MAX_UNACK - self.inflight_packets.len() as u32 { + // // b if b.saturating_sub(payload_len as u32 + 64) != 0 => payload_len, + // // b if b < 128 && b >= 4 => (b / 2) as u16, + // // b if b < 4 => b as u16, + // // b => (b - 64) as u16, + // b if b >= payload_len as u32 * 2 && b > 0 => payload_len, + // b if b < 4 => b as u16, + // b => (b / 2) as u16, + // } + // } + + pub(super) fn check_pkt_type(&self, incoming_packet: &TcpPacket, p: &[u8]) -> PacketStatus { + let received_ack_distance = self + .seq + .wrapping_sub(incoming_packet.inner().acknowledgment_number); + + let current_ack_distance = self.seq.wrapping_sub(self.last_ack); + if received_ack_distance > current_ack_distance + || (incoming_packet.inner().acknowledgment_number != self.seq + && self + .seq + .saturating_sub(incoming_packet.inner().acknowledgment_number) + == 0) + { + PacketStatus::Invalid + } else if self.last_ack == incoming_packet.inner().acknowledgment_number { + if !p.is_empty() { + PacketStatus::NewPacket + } else if 
self.send_window == incoming_packet.inner().window_size + && self.seq != self.last_ack + { + PacketStatus::RetransmissionRequest + } else if self.ack.wrapping_sub(1) == incoming_packet.inner().sequence_number { + PacketStatus::KeepAlive + } else { + PacketStatus::WindowUpdate + } + } else if self.last_ack < incoming_packet.inner().acknowledgment_number { + if !p.is_empty() { + PacketStatus::NewPacket + } else { + PacketStatus::Ack + } + } else { + PacketStatus::Invalid + } + } + pub(super) fn change_last_ack(&mut self, ack: u32) { + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + self.tcp_timeout); + let distance = ack.wrapping_sub(self.last_ack); + + if matches!(self.state, TcpState::Established) { + if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { + let mut inflight_packet = self.inflight_packets.remove(i); + let distance = ack.wrapping_sub(inflight_packet.seq); + if (distance as usize) < inflight_packet.payload.len() { + inflight_packet.payload.drain(0..distance as usize); + inflight_packet.seq = ack; + self.inflight_packets.push(inflight_packet); + } + } + } + + self.last_ack = self.last_ack.wrapping_add(distance); + } + pub fn is_send_buffer_full(&self) -> bool { + self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK + } +} + +pub struct InflightPacket { + pub seq: u32, + pub payload: Vec, + pub send_time: SystemTime, +} + +impl InflightPacket { + fn new(seq: u32, payload: Vec) -> Self { + Self { + seq, + payload, + send_time: SystemTime::now(), + } + } + pub(crate) fn contains(&self, seq: u32) -> bool { + (1..=self.payload.len() as u32).contains(&seq.wrapping_sub(self.seq)) // true when seq lies 1 to len bytes past self.seq, computed modulo 2^32 so it cannot overflow + } +} + +pub struct UnorderedPacket { + pub payload: Vec, + pub recv_time: SystemTime, +} + +impl UnorderedPacket { + pub(crate) fn new(payload: Vec) -> Self { + Self { + payload, + recv_time: SystemTime::now(), + } + } +} diff --git a/libs/ipstack/src/stream/tcp.rs b/libs/ipstack/src/stream/tcp.rs new file mode 100644 index 0000000..9c739ca --- /dev/null 
+++ b/libs/ipstack/src/stream/tcp.rs @@ -0,0 +1,509 @@ +use crate::{ + error::IpStackError, + packet::{tcp_flags, IpStackPacketProtocol, TcpPacket, TransportHeader}, + stream::tcb::{Tcb, TcpState}, + DROP_TTL, TTL, +}; +use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions}; +use std::{ + cmp, + future::Future, + io::{Error, ErrorKind}, + net::SocketAddr, + pin::Pin, + task::Waker, + time::Duration, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{ + mpsc::{self, UnboundedReceiver, UnboundedSender}, + Notify, + }, +}; +#[cfg(feature = "log")] +use tracing::{trace, warn}; + +use crate::packet::NetworkPacket; + +use super::tcb::PacketStatus; + +pub struct IpStackTcpStream { + src_addr: SocketAddr, + dst_addr: SocketAddr, + stream_sender: UnboundedSender, + stream_receiver: UnboundedReceiver, + packet_sender: UnboundedSender, + packet_to_send: Option, + tcb: Tcb, + mtu: u16, + shutdown: Option, + write_notify: Option, +} + +impl IpStackTcpStream { + pub(crate) async fn new( + src_addr: SocketAddr, + dst_addr: SocketAddr, + tcp: TcpPacket, + pkt_sender: UnboundedSender, + mtu: u16, + tcp_timeout: Duration, + ) -> Result { + let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); + + let mut stream = IpStackTcpStream { + src_addr, + dst_addr, + stream_sender, + stream_receiver, + packet_sender: pkt_sender.clone(), + packet_to_send: None, + tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), + mtu, + shutdown: None, + write_notify: None, + }; + if !tcp.inner().syn { + pkt_sender + .send(stream.create_rev_packet( + tcp_flags::RST | tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?) 
+ .map_err(|_| IpStackError::InvalidTcpPacket)?; + stream.tcb.change_state(TcpState::Closed); + } + Ok(stream) + } + pub(crate) fn stream_sender(&self) -> UnboundedSender { + self.stream_sender.clone() + } + fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { + cmp::min( + self.tcb.get_send_window(), + self.mtu.saturating_sub(ip_header_size + tcp_header_size), + ) + } + fn create_rev_packet( + &self, + flags: u8, + ttl: u8, + seq: Option, + mut payload: Vec, + ) -> Result { + let mut tcp_header = etherparse::TcpHeader::new( + self.dst_addr.port(), + self.src_addr.port(), + seq.unwrap_or(self.tcb.get_seq()), + self.tcb.get_recv_window(), + ); + + tcp_header.acknowledgment_number = self.tcb.get_ack(); + if flags & tcp_flags::SYN != 0 { + tcp_header.syn = true; + } + if flags & tcp_flags::ACK != 0 { + tcp_header.ack = true; + } + if flags & tcp_flags::RST != 0 { + tcp_header.rst = true; + } + if flags & tcp_flags::FIN != 0 { + tcp_header.fin = true; + } + if flags & tcp_flags::PSH != 0 { + tcp_header.psh = true; + } + + let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { + (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { + let mut ip_h = Ipv4Header::new(0, ttl, 6, dst.octets(), src.octets()); + let payload_len = + self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len()); + payload.truncate(payload_len as usize); + ip_h.payload_len = payload.len() as u16 + tcp_header.header_len(); + ip_h.dont_fragment = true; + etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()) + } + (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { + let mut ip_h = etherparse::Ipv6Header { + traffic_class: 0, + flow_label: 0, + payload_length: 0, + next_header: 6, + hop_limit: ttl, + source: dst.octets(), + destination: src.octets(), + }; + let payload_len = + self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len()); + payload.truncate(payload_len as usize); + ip_h.payload_length = 
payload.len() as u16 + tcp_header.header_len(); + + etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()) + } + _ => unreachable!(), + }; + + match ip_header { + etherparse::IpHeader::Version4(ref ip_header, _) => { + tcp_header.checksum = tcp_header + .calc_checksum_ipv4(ip_header, &payload) + .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + } + etherparse::IpHeader::Version6(ref ip_header, _) => { + tcp_header.checksum = tcp_header + .calc_checksum_ipv6(ip_header, &payload) + .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + } + } + Ok(NetworkPacket { + ip: ip_header, + transport: TransportHeader::Tcp(tcp_header), + payload, + }) + } + pub fn local_addr(&self) -> SocketAddr { + self.src_addr + } + pub fn peer_addr(&self) -> SocketAddr { + self.dst_addr + } +} + +impl AsyncRead for IpStackTcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + loop { + if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) { + self.packet_to_send = + Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); + self.tcb.change_state(TcpState::Closed); + return std::task::Poll::Ready(Ok(())); + } + let min = cmp::min(self.tcb.get_available_read_buffer_size() as u16, u16::MAX); + self.tcb.change_recv_window(min); + if matches!( + Pin::new(&mut self.tcb.timeout).poll(cx), + std::task::Poll::Ready(_) + ) { + #[cfg(feature = "log")] + trace!("timeout reached for {:?}", self.dst_addr); + self.packet_sender + .send(self.create_rev_packet( + tcp_flags::RST | tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?) 
+ .map_err(|_| ErrorKind::UnexpectedEof)?; + return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); + } + + if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { + self.packet_to_send = Some(self.create_rev_packet( + tcp_flags::SYN | tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?); + self.tcb.add_seq_one(); + self.tcb.change_state(TcpState::SynReceived(true)); + } + + if let Some(packet) = self.packet_to_send.take() { + self.packet_sender + .send(packet) + .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + if matches!(self.tcb.get_state(), TcpState::Closed) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.notify_one(); + } + return std::task::Poll::Ready(Ok(())); + } + } + if let Some(b) = self.tcb.get_unordered_packets() { + self.tcb.add_ack(b.len() as u32); + buf.put_slice(&b); + self.packet_sender + .send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?) + .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + return std::task::Poll::Ready(Ok(())); + } + if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) { + self.tcb.change_state(TcpState::FinWait1); + self.packet_to_send = Some(self.create_rev_packet( + tcp_flags::FIN | tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?); + continue; + } + match self.stream_receiver.poll_recv(cx) { + std::task::Poll::Ready(Some(p)) => { + let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else { + unreachable!() + }; + if t.flags() & tcp_flags::RST != 0 { + self.packet_to_send = + Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); + self.tcb.change_state(TcpState::Closed); + return std::task::Poll::Ready(Err(Error::from( + ErrorKind::ConnectionReset, + ))); + } + if matches!( + self.tcb.check_pkt_type(&t, &p.payload), + PacketStatus::Invalid + ) { + continue; + } + + if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) { + if t.flags() == tcp_flags::ACK { + 
self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_send_window(t.inner().window_size); + self.tcb.change_state(TcpState::Established); + } + } else if matches!(self.tcb.get_state(), TcpState::Established) { + if t.flags() == tcp_flags::ACK { + match self.tcb.check_pkt_type(&t, &p.payload) { + PacketStatus::WindowUpdate => { + self.tcb.change_send_window(t.inner().window_size); + if let Some(ref n) = self.write_notify { + n.wake_by_ref(); + self.write_notify = None; + }; + continue; + } + PacketStatus::Invalid => continue, + PacketStatus::KeepAlive => { + self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_send_window(t.inner().window_size); + self.packet_to_send = Some(self.create_rev_packet( + tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?); + continue; + } + PacketStatus::RetransmissionRequest => { + self.tcb.change_send_window(t.inner().window_size); + self.tcb.retransmission = Some(t.inner().acknowledgment_number); + if matches!( + self.as_mut().poll_flush(cx), + std::task::Poll::Pending + ) { + return std::task::Poll::Pending; + } + continue; + } + PacketStatus::NewPacket => { + // if t.inner().sequence_number != self.tcb.get_ack() { + // dbg!(t.inner().sequence_number); + // self.packet_to_send = Some(self.create_rev_packet( + // tcp_flags::ACK, + // TTL, + // None, + // Vec::new(), + // )?); + // continue; + // } + + self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.add_unordered_packet( + t.inner().sequence_number, + &p.payload, + ); + // buf.put_slice(&p.payload); + // self.tcb.add_ack(p.payload.len() as u32); + // self.packet_to_send = Some(self.create_rev_packet( + // tcp_flags::ACK, + // TTL, + // None, + // Vec::new(), + // )?); + self.tcb.change_send_window(t.inner().window_size); + if let Some(ref n) = self.write_notify { + n.wake_by_ref(); + self.write_notify = None; + }; + continue; + // return std::task::Poll::Ready(Ok(())); + } + PacketStatus::Ack => { + 
self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.change_send_window(t.inner().window_size); + if let Some(ref n) = self.write_notify { + n.wake_by_ref(); + self.write_notify = None; + }; + continue; + } + }; + } + if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { + self.tcb.add_ack(1); + self.packet_to_send = Some(self.create_rev_packet( + tcp_flags::FIN | tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?); + self.tcb.add_seq_one(); + self.tcb.change_state(TcpState::FinWait2(true)); + continue; + } + if t.flags() == (tcp_flags::PSH | tcp_flags::ACK) { + if !matches!( + self.tcb.check_pkt_type(&t, &p.payload), + PacketStatus::NewPacket + ) { + continue; + } + self.tcb.change_last_ack(t.inner().acknowledgment_number); + + if p.payload.is_empty() + || self.tcb.get_ack() != t.inner().sequence_number + { + continue; + } + + // self.tcb.add_ack(p.payload.len() as u32); + self.tcb.change_send_window(t.inner().window_size); + // buf.put_slice(&p.payload); + // self.packet_to_send = Some(self.create_rev_packet( + // tcp_flags::ACK, + // TTL, + // None, + // Vec::new(), + // )?); + // return std::task::Poll::Ready(Ok(())); + self.tcb + .add_unordered_packet(t.inner().sequence_number, &p.payload); + continue; + } + } else if matches!(self.tcb.get_state(), TcpState::FinWait1) { + if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { + self.packet_to_send = Some(self.create_rev_packet( + tcp_flags::ACK, + TTL, + None, + Vec::new(), + )?); + self.tcb.change_send_window(t.inner().window_size); + self.tcb.add_seq_one(); + self.tcb.change_state(TcpState::FinWait2(false)); + continue; + } + } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) + && t.flags() == tcp_flags::ACK + { + self.tcb.change_state(TcpState::FinWait2(false)); + } + } + std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())), + std::task::Poll::Pending => return std::task::Poll::Pending, + } + } + } +} + +impl AsyncWrite for IpStackTcpStream { + fn poll_write( + 
mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2 + || self.tcb.is_send_buffer_full() + { + self.write_notify = Some(cx.waker().clone()); + return std::task::Poll::Pending; + } + + if self.tcb.retransmission.is_some() { + self.write_notify = Some(cx.waker().clone()); + if matches!(self.as_mut().poll_flush(cx), std::task::Poll::Pending) { + return std::task::Poll::Pending; + } + } + + let packet = + self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?; + let seq = self.tcb.seq; + let payload_len = packet.payload.len(); + let payload = packet.payload.clone(); + + self.packet_sender + .send(packet) + .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + self.tcb.add_inflight_packet(seq, &payload); + + std::task::Poll::Ready(Ok(payload_len)) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if let Some(i) = self + .tcb + .retransmission + .and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s)) + .and_then(|p| self.tcb.inflight_packets.get(p)) + { + let packet = self.create_rev_packet( + tcp_flags::PSH | tcp_flags::ACK, + TTL, + Some(i.seq), + i.payload.to_vec(), + )?; + + self.packet_sender + .send(packet) + .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + self.tcb.retransmission = None; + } else if let Some(_i) = self.tcb.retransmission { + #[cfg(feature = "log")] + { + warn!(_i); + warn!(self.tcb.seq); + warn!(self.tcb.last_ack); + warn!(self.tcb.ack); + for p in self.tcb.inflight_packets.iter() { + warn!(p.seq); + warn!("{}", p.payload.len()); + } + } + panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); + } + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let notified = 
self.shutdown.get_or_insert(Notify::new()).notified(); + match Pin::new(&mut Box::pin(notified)).poll(cx) { + std::task::Poll::Ready(_) => std::task::Poll::Ready(Ok(())), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} + +impl Drop for IpStackTcpStream { + fn drop(&mut self) { + if let Ok(p) = self.create_rev_packet(0, DROP_TTL, None, Vec::new()) { + _ = self.packet_sender.send(p); + } + } +} diff --git a/libs/ipstack/src/stream/udp.rs b/libs/ipstack/src/stream/udp.rs new file mode 100644 index 0000000..bf7a14d --- /dev/null +++ b/libs/ipstack/src/stream/udp.rs @@ -0,0 +1,181 @@ +use core::task; +use std::{ + future::Future, + io::{self, Error, ErrorKind}, + net::SocketAddr, + pin::Pin, + task::Poll, + time::Duration, +}; + +use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + time::Sleep, +}; +// use crate::packet::TransportHeader; +use crate::{ + packet::{NetworkPacket, TransportHeader}, + TTL, +}; + +pub struct IpStackUdpStream { + src_addr: SocketAddr, + dst_addr: SocketAddr, + stream_sender: UnboundedSender, + stream_receiver: UnboundedReceiver, + packet_sender: UnboundedSender, + first_paload: Option>, + timeout: Pin>, + udp_timeout: Duration, + mtu: u16, +} + +impl IpStackUdpStream { + pub fn new( + src_addr: SocketAddr, + dst_addr: SocketAddr, + payload: Vec, + pkt_sender: UnboundedSender, + mtu: u16, + udp_timeout: Duration, + ) -> Self { + let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); + IpStackUdpStream { + src_addr, + dst_addr, + stream_sender, + stream_receiver, + packet_sender: pkt_sender.clone(), + first_paload: Some(payload), + timeout: Box::pin(tokio::time::sleep_until( + tokio::time::Instant::now() + udp_timeout, + )), + udp_timeout, + mtu, + } + } + pub(crate) fn stream_sender(&self) -> UnboundedSender { + self.stream_sender.clone() + } + fn create_rev_packet(&self, ttl: 
u8, mut payload: Vec) -> Result { + match (self.dst_addr.ip(), self.src_addr.ip()) { + (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { + let mut ip_h = Ipv4Header::new(0, ttl, 17, dst.octets(), src.octets()); + let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + payload.truncate(line_buffer as usize); + ip_h.payload_len = payload.len() as u16 + 8; // 8 is udp header size + let udp_header = UdpHeader::with_ipv4_checksum( + self.dst_addr.port(), + self.src_addr.port(), + &ip_h, + &payload, + ) + .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + Ok(NetworkPacket { + ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()), + transport: TransportHeader::Udp(udp_header), + payload, + }) + } + (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { + let mut ip_h = Ipv6Header { + traffic_class: 0, + flow_label: 0, + payload_length: 0, + next_header: 17, + hop_limit: ttl, + source: dst.octets(), + destination: src.octets(), + }; + let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size + + payload.truncate(line_buffer as usize); + + ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size + let udp_header = UdpHeader::with_ipv6_checksum( + self.dst_addr.port(), + self.src_addr.port(), + &ip_h, + &payload, + ) + .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + Ok(NetworkPacket { + ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()), + transport: TransportHeader::Udp(udp_header), + payload, + }) + } + _ => unreachable!(), + } + } + pub fn local_addr(&self) -> SocketAddr { + self.src_addr + } + pub fn peer_addr(&self) -> SocketAddr { + self.dst_addr + } +} + +impl AsyncRead for IpStackUdpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> task::Poll> { + if let Some(p) = self.first_paload.take() { + buf.put_slice(&p); + return 
Poll::Ready(Ok(())); + } + if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) { + return Poll::Ready(Ok(())); // todo: return timeout error + } + + let udp_timeout = self.udp_timeout; + match self.stream_receiver.poll_recv(cx) { + Poll::Ready(Some(p)) => { + buf.put_slice(&p.payload); + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + udp_timeout); + Poll::Ready(Ok(())) + } + Poll::Ready(None) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } +} + +impl AsyncWrite for IpStackUdpStream { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &[u8], + ) -> task::Poll> { + let udp_timeout = self.udp_timeout; + self.timeout + .as_mut() + .reset(tokio::time::Instant::now() + udp_timeout); + let packet = self.create_rev_packet(TTL, buf.to_vec())?; + let payload_len = packet.payload.len(); + self.packet_sender + .send(packet) + .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + std::task::Poll::Ready(Ok(payload_len)) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> task::Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> task::Poll> { + Poll::Ready(Ok(())) + } +} diff --git a/libs/ipstack/src/stream/unknown.rs b/libs/ipstack/src/stream/unknown.rs new file mode 100644 index 0000000..4f5fae3 --- /dev/null +++ b/libs/ipstack/src/stream/unknown.rs @@ -0,0 +1,111 @@ +use std::{io::Error, mem, net::IpAddr}; + +use etherparse::{IpHeader, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header}; +use tokio::sync::mpsc::UnboundedSender; + +use crate::{ + packet::{NetworkPacket, TransportHeader}, + TTL, +}; + +pub struct IpStackUnknownTransport { + src_addr: IpAddr, + dst_addr: IpAddr, + payload: Vec, + protocol: u8, + mtu: u16, + packet_sender: UnboundedSender, +} + +impl IpStackUnknownTransport { + pub fn new( + src_addr: IpAddr, + dst_addr: IpAddr, + payload: Vec, + ip: &IpHeader, + mtu: u16, + 
packet_sender: UnboundedSender, + ) -> Self { + let protocol = match ip { + IpHeader::Version4(ip, _) => ip.protocol, + IpHeader::Version6(ip, _) => ip.next_header, + }; + IpStackUnknownTransport { + src_addr, + dst_addr, + payload, + protocol, + mtu, + packet_sender, + } + } + pub fn src_addr(&self) -> IpAddr { + self.src_addr + } + pub fn dst_addr(&self) -> IpAddr { + self.dst_addr + } + pub fn payload(&self) -> &[u8] { + &self.payload + } + pub fn ip_protocol(&self) -> u8 { + self.protocol + } + pub async fn send(&self, mut payload: Vec) -> Result<(), Error> { + loop { + let packet = self.create_rev_packet(&mut payload)?; + self.packet_sender + .send(packet) + .map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?; + if payload.is_empty() { + return Ok(()); + } + } + } + + pub fn create_rev_packet(&self, payload: &mut Vec) -> Result { + match (self.dst_addr, self.src_addr) { + (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { + let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()); + let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); + + let p = if payload.len() > line_buffer as usize { + payload.drain(0..line_buffer as usize).collect::>() + } else { + mem::take(payload) + }; + ip_h.payload_len = p.len() as u16; + Ok(NetworkPacket { + ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()), + transport: TransportHeader::Unknown, + payload: p, + }) + } + (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => { + let mut ip_h = Ipv6Header { + traffic_class: 0, + flow_label: 0, + payload_length: 0, + next_header: 17, + hop_limit: TTL, + source: dst.octets(), + destination: src.octets(), + }; + let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); + payload.truncate(line_buffer as usize); + ip_h.payload_length = payload.len() as u16; + let p = if payload.len() > line_buffer as usize { + payload.drain(0..line_buffer as usize).collect::>() + } else { + mem::take(payload) 
+ }; + Ok(NetworkPacket { + ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()), + transport: TransportHeader::Unknown, + payload: p, + }) + } + _ => unreachable!(), + } + } +} diff --git a/libs/xen/xenclient/examples/boot.rs b/libs/xen/xenclient/examples/boot.rs index 299a4ce..10e1806 100644 --- a/libs/xen/xenclient/examples/boot.rs +++ b/libs/xen/xenclient/examples/boot.rs @@ -22,6 +22,7 @@ fn main() -> Result<()> { initrd_path: initrd_path.as_str(), cmdline: "debug elevator=noop", disks: vec![], + consoles: vec![], vifs: vec![], filesystems: vec![], extra_keys: vec![], diff --git a/libs/xen/xenclient/src/lib.rs b/libs/xen/xenclient/src/lib.rs index a12343e..c20a63d 100644 --- a/libs/xen/xenclient/src/lib.rs +++ b/libs/xen/xenclient/src/lib.rs @@ -56,6 +56,9 @@ pub struct DomainNetworkInterface<'a> { pub script: Option<&'a str>, } +#[derive(Debug)] +pub struct DomainConsole {} + #[derive(Debug)] pub struct DomainConfig<'a> { pub backend_domid: u32, @@ -66,6 +69,7 @@ pub struct DomainConfig<'a> { pub initrd_path: &'a str, pub cmdline: &'a str, pub disks: Vec>, + pub consoles: Vec, pub vifs: Vec>, pub filesystems: Vec>, pub extra_keys: Vec<(String, String)>, @@ -348,9 +352,23 @@ impl XenClient { &backend_dom_path, config.backend_domid, domid, - console_evtchn, - console_mfn, + 0, + Some(console_evtchn), + Some(console_mfn), )?; + + for (index, _) in config.consoles.iter().enumerate() { + self.console_device_add( + &dom_path, + &backend_dom_path, + config.backend_domid, + domid, + index + 1, + None, + None, + )?; + } + for (index, disk) in config.disks.iter().enumerate() { self.disk_device_add( &dom_path, @@ -438,35 +456,54 @@ impl XenClient { Ok(()) } + #[allow(clippy::too_many_arguments, clippy::unnecessary_unwrap)] fn console_device_add( &mut self, dom_path: &str, backend_dom_path: &str, backend_domid: u32, domid: u32, - port: u32, - mfn: u64, + index: usize, + port: Option, + mfn: Option, ) -> Result<()> { - let backend_entries = vec![ + let mut 
backend_entries = vec![ ("frontend-id", domid.to_string()), ("online", "1".to_string()), ("state", "1".to_string()), ("protocol", "vt100".to_string()), ]; - let frontend_entries = vec![ + let mut frontend_entries = vec![ ("backend-id", backend_domid.to_string()), ("limit", "1048576".to_string()), - ("type", "xenconsoled".to_string()), ("output", "pty".to_string()), ("tty", "".to_string()), - ("port", port.to_string()), - ("ring-ref", mfn.to_string()), ]; + if index == 0 { + frontend_entries.push(("type", "xenconsoled".to_string())); + } else { + frontend_entries.push(("type", "ioemu".to_string())); + backend_entries.push(("connection", "pty".to_string())); + backend_entries.push(("output", "pty".to_string())); + } + + if port.is_some() && mfn.is_some() { + frontend_entries.extend_from_slice(&[ + ("port", port.unwrap().to_string()), + ("ring-ref", mfn.unwrap().to_string()), + ]); + } else { + frontend_entries.extend_from_slice(&[ + ("state", "1".to_string()), + ("protocol", "vt100".to_string()), + ]); + } + self.device_add( "console", - 0, + index as u64, dom_path, backend_dom_path, backend_domid, diff --git a/network/Cargo.toml b/network/Cargo.toml index c63e6f4..17b179e 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -13,11 +13,16 @@ rtnetlink = { workspace = true } netlink-packet-route = { workspace = true } tokio = { workspace = true } futures = { workspace = true } -smoltcp = { workspace = true } +libc = { workspace = true } +udp-stream = { workspace = true } [dependencies.advmac] path = "../libs/advmac" +[dependencies.ipstack] +path = "../libs/ipstack" +features = ["log"] + [lib] path = "src/lib.rs" diff --git a/network/src/lib.rs b/network/src/lib.rs index 3f27f40..0b64124 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -1,30 +1,21 @@ -use std::os::fd::AsRawFd; -use std::panic::UnwindSafe; -use std::str::FromStr; use std::sync::{Arc, Mutex}; use std::time::Duration; -use std::{panic, thread}; -use advmac::MacAddr6; use anyhow::{anyhow, 
Result}; use futures::TryStreamExt; -use log::{error, info, warn}; +use ipstack::stream::IpStackStream; +use log::{debug, error, info, warn}; use netlink_packet_route::link::LinkAttribute; -use smoltcp::iface::{Config, Interface, SocketSet}; -use smoltcp::phy::{self, RawSocket}; -use smoltcp::time::Instant; -use smoltcp::wire::{EthernetAddress, HardwareAddress, IpCidr}; +use raw_socket::{AsyncRawSocket, RawSocket}; +use tokio::net::TcpStream; use tokio::time::sleep; +use udp_stream::UdpStream; + +mod raw_socket; pub struct NetworkBackend { pub interface: String, - pub device: RawSocket, - pub addresses: Vec, } - -unsafe impl Send for NetworkBackend {} -impl UnwindSafe for NetworkBackend {} - pub struct NetworkService { pub network: String, } @@ -36,18 +27,9 @@ impl NetworkService { } impl NetworkBackend { - pub fn new(iface: &str, cidrs: &[&str]) -> Result { - let device = RawSocket::new(iface, smoltcp::phy::Medium::Ethernet)?; - let mut addresses: Vec = Vec::new(); - for cidr in cidrs { - let address = - IpCidr::from_str(cidr).map_err(|_| anyhow!("failed to parse cidr: {}", *cidr))?; - addresses.push(address); - } + pub fn new(iface: &str) -> Result { Ok(NetworkBackend { interface: iface.to_string(), - device, - addresses, }) } @@ -73,34 +55,56 @@ impl NetworkBackend { Ok(()) } - pub fn run(mut self) -> Result<()> { - let result = panic::catch_unwind(move || self.run_maybe_panic()); + pub async fn run(&mut self) -> Result<()> { + let mut config = ipstack::IpStackConfig::default(); + config.mtu(1500); + config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes + config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds - if result.is_err() { - return Err(anyhow!("network backend has terminated")); + let mut socket = RawSocket::new(&self.interface)?; + socket.bind_interface()?; + let socket = AsyncRawSocket::new(socket)?; + let mut stack = ipstack::IpStack::new(config, socket); + + while let Ok(stream) = stack.accept().await { + 
self.process_stream(stream).await? } - - result.unwrap() + Ok(()) } - fn run_maybe_panic(&mut self) -> Result<()> { - let mac = MacAddr6::random(); - let mac = HardwareAddress::Ethernet(EthernetAddress(mac.to_array())); - let config = Config::new(mac); - let mut iface = Interface::new(config, &mut self.device, Instant::now()); - iface.update_ip_addrs(|addrs| { - addrs - .extend_from_slice(&self.addresses) - .expect("failed to set ip addresses"); - }); + async fn process_stream(&mut self, stream: IpStackStream) -> Result<()> { + match stream { + IpStackStream::Tcp(mut tcp) => { + debug!("tcp: {}", tcp.peer_addr()); + tokio::spawn(async move { + if let Ok(mut stream) = TcpStream::connect(tcp.peer_addr()).await { + let _ = tokio::io::copy_bidirectional(&mut stream, &mut tcp).await; + } else { + warn!("failed to connect to tcp address: {}", tcp.peer_addr()); + } + }); + } - let mut sockets = SocketSet::new(vec![]); - let fd = self.device.as_raw_fd(); - loop { - let timestamp = Instant::now(); - iface.poll(timestamp, &mut self.device, &mut sockets); - phy::wait(fd, iface.poll_delay(timestamp, &sockets))?; + IpStackStream::Udp(mut udp) => { + debug!("udp: {}", udp.peer_addr()); + tokio::spawn(async move { + if let Ok(mut stream) = UdpStream::connect(udp.peer_addr()).await { + let _ = tokio::io::copy_bidirectional(&mut stream, &mut udp).await; + } else { + warn!("failed to connect to udp address: {}", udp.peer_addr()); + } + }); + } + + IpStackStream::UnknownTransport(u) => { + debug!("unknown transport: {}", u.dst_addr()); + } + + IpStackStream::UnknownNetwork(packet) => { + debug!("unknown network: {:?}", packet); + } } + Ok(()) } } @@ -156,15 +160,15 @@ impl NetworkService { spawned: Arc>>, ) -> Result<()> { let interface = interface.to_string(); - let mut network = NetworkBackend::new(&interface, &[&self.network])?; + let mut network = NetworkBackend::new(&interface)?; info!("initializing network backend for interface {}", interface); network.init().await?; 
tokio::time::sleep(Duration::from_secs(1)).await; info!("spawning network backend for interface {}", interface); - thread::spawn(move || { - if let Err(error) = network.run() { + tokio::spawn(async move { + if let Err(error) = network.run().await { error!( - "failed to run network backend for interface {}: {}", + "network backend for interface {} has been stopped: {}", interface, error ); } diff --git a/network/src/raw_socket.rs b/network/src/raw_socket.rs new file mode 100644 index 0000000..a362369 --- /dev/null +++ b/network/src/raw_socket.rs @@ -0,0 +1,202 @@ +use futures::ready; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, mem}; + +use anyhow::Result; +use tokio::io::unix::AsyncFd; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +const SIOCGIFINDEX: libc::c_ulong = 0x8933; + +#[derive(Debug)] +pub struct RawSocket { + protocol: libc::c_short, + lower: libc::c_int, + ifreq: Ifreq, +} + +impl AsRawFd for RawSocket { + fn as_raw_fd(&self) -> RawFd { + self.lower + } +} + +impl RawSocket { + pub fn new(name: &str) -> io::Result { + let protocol: libc::c_short = 0x0003; + let lower = unsafe { + let lower = libc::socket( + libc::AF_PACKET, + libc::SOCK_RAW | libc::SOCK_NONBLOCK, + protocol.to_be() as i32, + ); + if lower == -1 { + return Err(io::Error::last_os_error()); + } + lower + }; + + Ok(RawSocket { + protocol, + lower, + ifreq: ifreq_for(name), + }) + } + + pub fn bind_interface(&mut self) -> io::Result<()> { + let sockaddr = libc::sockaddr_ll { + sll_family: libc::AF_PACKET as u16, + sll_protocol: self.protocol.to_be() as u16, + sll_ifindex: ifreq_ioctl(self.lower, &mut self.ifreq, SIOCGIFINDEX)?, + sll_hatype: 1, + sll_pkttype: 0, + sll_halen: 6, + sll_addr: [0; 8], + }; + + unsafe { + let res = libc::bind( + self.lower, + &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr, + mem::size_of::() as libc::socklen_t, + ); + if res == -1 { + return Err(io::Error::last_os_error()); + } 
+ } + + Ok(()) + } + + pub fn recv(&self, buffer: &mut [u8]) -> io::Result { + unsafe { + let len = libc::recv( + self.lower, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } + + pub fn send(&self, buffer: &[u8]) -> io::Result { + unsafe { + let len = libc::send( + self.lower, + buffer.as_ptr() as *const libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } +} + +impl Drop for RawSocket { + fn drop(&mut self) { + unsafe { + libc::close(self.lower); + } + } +} + +#[repr(C)] +#[derive(Debug)] +struct Ifreq { + ifr_name: [libc::c_char; libc::IF_NAMESIZE], + ifr_data: libc::c_int, /* ifr_ifindex or ifr_mtu */ +} + +fn ifreq_for(name: &str) -> Ifreq { + let mut ifreq = Ifreq { + ifr_name: [0; libc::IF_NAMESIZE], + ifr_data: 0, + }; + for (i, byte) in name.as_bytes().iter().enumerate() { + ifreq.ifr_name[i] = *byte as libc::c_char + } + ifreq +} + +fn ifreq_ioctl( + lower: libc::c_int, + ifreq: &mut Ifreq, + cmd: libc::c_ulong, +) -> io::Result { + unsafe { + let res = libc::ioctl(lower, cmd as _, ifreq as *mut Ifreq); + if res == -1 { + return Err(io::Error::last_os_error()); + } + } + + Ok(ifreq.ifr_data) +} + +pub struct AsyncRawSocket { + inner: AsyncFd, +} + +impl AsyncRawSocket { + pub fn new(socket: RawSocket) -> Result { + Ok(Self { + inner: AsyncFd::new(socket)?, + }) + } +} + +impl AsyncRead for AsyncRawSocket { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + let mut guard = ready!(self.inner.poll_read_ready(cx))?; + + let unfilled = buf.initialize_unfilled(); + match guard.try_io(|inner| inner.get_ref().recv(unfilled)) { + Ok(Ok(len)) => { + buf.advance(len); + return Poll::Ready(Ok(())); + } + Ok(Err(err)) => return Poll::Ready(Err(err)), + Err(_would_block) => continue, + } + } + } +} + +impl AsyncWrite for 
AsyncRawSocket { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + let mut guard = ready!(self.inner.poll_write_ready(cx))?; + + match guard.try_io(|inner| inner.get_ref().send(buf)) { + Ok(result) => return Poll::Ready(result), + Err(_would_block) => continue, + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +}