mirror of
				https://github.com/edera-dev/krata.git
				synced 2025-11-03 23:29:39 +00:00 
			
		
		
		
	hypha: work in progress implementation of outbound internet access
This commit is contained in:
		@ -6,6 +6,7 @@ members = [
 | 
			
		||||
    "libs/xen/xenclient",
 | 
			
		||||
    "libs/advmac",
 | 
			
		||||
    "libs/loopdev",
 | 
			
		||||
    "libs/ipstack",
 | 
			
		||||
    "shared",
 | 
			
		||||
    "container",
 | 
			
		||||
    "network",
 | 
			
		||||
@ -49,7 +50,7 @@ rtnetlink = "0.14.1"
 | 
			
		||||
netlink-packet-route = "0.19.0"
 | 
			
		||||
futures = "0.3.30"
 | 
			
		||||
ipnetwork = "0.20.0"
 | 
			
		||||
smoltcp = "0.11.0"
 | 
			
		||||
udp-stream = "0.0.11"
 | 
			
		||||
 | 
			
		||||
[workspace.dependencies.uuid]
 | 
			
		||||
version = "1.6.1"
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,7 @@ use oci_spec::image::{Config, ImageConfiguration};
 | 
			
		||||
use std::ffi::{CStr, CString};
 | 
			
		||||
use std::fs;
 | 
			
		||||
use std::fs::{File, OpenOptions, Permissions};
 | 
			
		||||
use std::net::Ipv4Addr;
 | 
			
		||||
use std::os::fd::AsRawFd;
 | 
			
		||||
use std::os::linux::fs::MetadataExt;
 | 
			
		||||
use std::os::unix::fs::{chroot, PermissionsExt};
 | 
			
		||||
@ -304,11 +305,25 @@ impl ContainerInit {
 | 
			
		||||
                .execute()
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
            handle.link().set(link.header.index).up().execute().await?;
 | 
			
		||||
            handle
 | 
			
		||||
                .link()
 | 
			
		||||
                .set(link.header.index)
 | 
			
		||||
                .arp(false)
 | 
			
		||||
                .up()
 | 
			
		||||
                .execute()
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
            handle
 | 
			
		||||
                .route()
 | 
			
		||||
                .add()
 | 
			
		||||
                .v4()
 | 
			
		||||
                .destination_prefix(Ipv4Addr::new(0, 0, 0, 0), 0)
 | 
			
		||||
                .output_interface(link.header.index)
 | 
			
		||||
                .execute()
 | 
			
		||||
                .await?;
 | 
			
		||||
        } else {
 | 
			
		||||
            warn!("unable to find link named {}", network.link);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -137,6 +137,7 @@ impl Controller {
 | 
			
		||||
                    writable: false,
 | 
			
		||||
                },
 | 
			
		||||
            ],
 | 
			
		||||
            consoles: vec![],
 | 
			
		||||
            vifs: vec![DomainNetworkInterface {
 | 
			
		||||
                mac: &mac,
 | 
			
		||||
                mtu: 1500,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										47
									
								
								libs/ipstack/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								libs/ipstack/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,47 @@
 | 
			
		||||
# This package is from https://github.com/narrowlink/ipstack
 | 
			
		||||
# Mycelium maintains an in-tree version because we need to work at the ethernet layer
 | 
			
		||||
# rather than the standard tun layer of IP.
 | 
			
		||||
[package]
 | 
			
		||||
authors = ['Narrowlink <opensource@narrowlink.com>']
 | 
			
		||||
description = 'Asynchronous lightweight implementation of TCP/IP stack for Tun device'
 | 
			
		||||
name = "ipstack"
 | 
			
		||||
version = "0.0.3"
 | 
			
		||||
edition = "2021"
 | 
			
		||||
license = "Apache-2.0"
 | 
			
		||||
repository = 'https://github.com/narrowlink/ipstack'
 | 
			
		||||
# homepage = 'https://github.com/narrowlink/ipstack'
 | 
			
		||||
readme = "README.md"
 | 
			
		||||
 | 
			
		||||
[features]
 | 
			
		||||
default = []
 | 
			
		||||
log = ["tracing/log"]
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
tokio = { version = "1.35", features = [
 | 
			
		||||
    "sync",
 | 
			
		||||
    "rt",
 | 
			
		||||
    "time",
 | 
			
		||||
    "io-util",
 | 
			
		||||
    "macros",
 | 
			
		||||
], default-features = false }
 | 
			
		||||
etherparse = { version = "0.13", default-features = false }
 | 
			
		||||
thiserror = { version = "1.0", default-features = false }
 | 
			
		||||
tracing = { version = "0.1", default-features = false, features = [
 | 
			
		||||
    "log",
 | 
			
		||||
], optional = true }
 | 
			
		||||
 | 
			
		||||
[dev-dependencies]
 | 
			
		||||
clap = { version = "4.4", features = ["derive"] }
 | 
			
		||||
udp-stream = { version = "0.0", default-features = false }
 | 
			
		||||
tokio = { version = "1.35", features = [
 | 
			
		||||
    "rt-multi-thread",
 | 
			
		||||
], default-features = false }
 | 
			
		||||
 | 
			
		||||
#tun2.rs example
 | 
			
		||||
tun2 = { version = "1.0", features = ["async"] }
 | 
			
		||||
 | 
			
		||||
#tun_wintun.rs example
 | 
			
		||||
[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
 | 
			
		||||
tun = { version = "0.6.1", features = ["async"], default-features = false }
 | 
			
		||||
[target.'cfg(target_os = "windows")'.dev-dependencies]
 | 
			
		||||
wintun = { version = "0.4", default-features = false }
 | 
			
		||||
							
								
								
									
										201
									
								
								libs/ipstack/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								libs/ipstack/LICENSE
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,201 @@
 | 
			
		||||
                                 Apache License
 | 
			
		||||
                           Version 2.0, January 2004
 | 
			
		||||
                        http://www.apache.org/licenses/
 | 
			
		||||
 | 
			
		||||
   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 | 
			
		||||
 | 
			
		||||
   1. Definitions.
 | 
			
		||||
 | 
			
		||||
      "License" shall mean the terms and conditions for use, reproduction,
 | 
			
		||||
      and distribution as defined by Sections 1 through 9 of this document.
 | 
			
		||||
 | 
			
		||||
      "Licensor" shall mean the copyright owner or entity authorized by
 | 
			
		||||
      the copyright owner that is granting the License.
 | 
			
		||||
 | 
			
		||||
      "Legal Entity" shall mean the union of the acting entity and all
 | 
			
		||||
      other entities that control, are controlled by, or are under common
 | 
			
		||||
      control with that entity. For the purposes of this definition,
 | 
			
		||||
      "control" means (i) the power, direct or indirect, to cause the
 | 
			
		||||
      direction or management of such entity, whether by contract or
 | 
			
		||||
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
 | 
			
		||||
      outstanding shares, or (iii) beneficial ownership of such entity.
 | 
			
		||||
 | 
			
		||||
      "You" (or "Your") shall mean an individual or Legal Entity
 | 
			
		||||
      exercising permissions granted by this License.
 | 
			
		||||
 | 
			
		||||
      "Source" form shall mean the preferred form for making modifications,
 | 
			
		||||
      including but not limited to software source code, documentation
 | 
			
		||||
      source, and configuration files.
 | 
			
		||||
 | 
			
		||||
      "Object" form shall mean any form resulting from mechanical
 | 
			
		||||
      transformation or translation of a Source form, including but
 | 
			
		||||
      not limited to compiled object code, generated documentation,
 | 
			
		||||
      and conversions to other media types.
 | 
			
		||||
 | 
			
		||||
      "Work" shall mean the work of authorship, whether in Source or
 | 
			
		||||
      Object form, made available under the License, as indicated by a
 | 
			
		||||
      copyright notice that is included in or attached to the work
 | 
			
		||||
      (an example is provided in the Appendix below).
 | 
			
		||||
 | 
			
		||||
      "Derivative Works" shall mean any work, whether in Source or Object
 | 
			
		||||
      form, that is based on (or derived from) the Work and for which the
 | 
			
		||||
      editorial revisions, annotations, elaborations, or other modifications
 | 
			
		||||
      represent, as a whole, an original work of authorship. For the purposes
 | 
			
		||||
      of this License, Derivative Works shall not include works that remain
 | 
			
		||||
      separable from, or merely link (or bind by name) to the interfaces of,
 | 
			
		||||
      the Work and Derivative Works thereof.
 | 
			
		||||
 | 
			
		||||
      "Contribution" shall mean any work of authorship, including
 | 
			
		||||
      the original version of the Work and any modifications or additions
 | 
			
		||||
      to that Work or Derivative Works thereof, that is intentionally
 | 
			
		||||
      submitted to Licensor for inclusion in the Work by the copyright owner
 | 
			
		||||
      or by an individual or Legal Entity authorized to submit on behalf of
 | 
			
		||||
      the copyright owner. For the purposes of this definition, "submitted"
 | 
			
		||||
      means any form of electronic, verbal, or written communication sent
 | 
			
		||||
      to the Licensor or its representatives, including but not limited to
 | 
			
		||||
      communication on electronic mailing lists, source code control systems,
 | 
			
		||||
      and issue tracking systems that are managed by, or on behalf of, the
 | 
			
		||||
      Licensor for the purpose of discussing and improving the Work, but
 | 
			
		||||
      excluding communication that is conspicuously marked or otherwise
 | 
			
		||||
      designated in writing by the copyright owner as "Not a Contribution."
 | 
			
		||||
 | 
			
		||||
      "Contributor" shall mean Licensor and any individual or Legal Entity
 | 
			
		||||
      on behalf of whom a Contribution has been received by Licensor and
 | 
			
		||||
      subsequently incorporated within the Work.
 | 
			
		||||
 | 
			
		||||
   2. Grant of Copyright License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      copyright license to reproduce, prepare Derivative Works of,
 | 
			
		||||
      publicly display, publicly perform, sublicense, and distribute the
 | 
			
		||||
      Work and such Derivative Works in Source or Object form.
 | 
			
		||||
 | 
			
		||||
   3. Grant of Patent License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      (except as stated in this section) patent license to make, have made,
 | 
			
		||||
      use, offer to sell, sell, import, and otherwise transfer the Work,
 | 
			
		||||
      where such license applies only to those patent claims licensable
 | 
			
		||||
      by such Contributor that are necessarily infringed by their
 | 
			
		||||
      Contribution(s) alone or by combination of their Contribution(s)
 | 
			
		||||
      with the Work to which such Contribution(s) was submitted. If You
 | 
			
		||||
      institute patent litigation against any entity (including a
 | 
			
		||||
      cross-claim or counterclaim in a lawsuit) alleging that the Work
 | 
			
		||||
      or a Contribution incorporated within the Work constitutes direct
 | 
			
		||||
      or contributory patent infringement, then any patent licenses
 | 
			
		||||
      granted to You under this License for that Work shall terminate
 | 
			
		||||
      as of the date such litigation is filed.
 | 
			
		||||
 | 
			
		||||
   4. Redistribution. You may reproduce and distribute copies of the
 | 
			
		||||
      Work or Derivative Works thereof in any medium, with or without
 | 
			
		||||
      modifications, and in Source or Object form, provided that You
 | 
			
		||||
      meet the following conditions:
 | 
			
		||||
 | 
			
		||||
      (a) You must give any other recipients of the Work or
 | 
			
		||||
          Derivative Works a copy of this License; and
 | 
			
		||||
 | 
			
		||||
      (b) You must cause any modified files to carry prominent notices
 | 
			
		||||
          stating that You changed the files; and
 | 
			
		||||
 | 
			
		||||
      (c) You must retain, in the Source form of any Derivative Works
 | 
			
		||||
          that You distribute, all copyright, patent, trademark, and
 | 
			
		||||
          attribution notices from the Source form of the Work,
 | 
			
		||||
          excluding those notices that do not pertain to any part of
 | 
			
		||||
          the Derivative Works; and
 | 
			
		||||
 | 
			
		||||
      (d) If the Work includes a "NOTICE" text file as part of its
 | 
			
		||||
          distribution, then any Derivative Works that You distribute must
 | 
			
		||||
          include a readable copy of the attribution notices contained
 | 
			
		||||
          within such NOTICE file, excluding those notices that do not
 | 
			
		||||
          pertain to any part of the Derivative Works, in at least one
 | 
			
		||||
          of the following places: within a NOTICE text file distributed
 | 
			
		||||
          as part of the Derivative Works; within the Source form or
 | 
			
		||||
          documentation, if provided along with the Derivative Works; or,
 | 
			
		||||
          within a display generated by the Derivative Works, if and
 | 
			
		||||
          wherever such third-party notices normally appear. The contents
 | 
			
		||||
          of the NOTICE file are for informational purposes only and
 | 
			
		||||
          do not modify the License. You may add Your own attribution
 | 
			
		||||
          notices within Derivative Works that You distribute, alongside
 | 
			
		||||
          or as an addendum to the NOTICE text from the Work, provided
 | 
			
		||||
          that such additional attribution notices cannot be construed
 | 
			
		||||
          as modifying the License.
 | 
			
		||||
 | 
			
		||||
      You may add Your own copyright statement to Your modifications and
 | 
			
		||||
      may provide additional or different license terms and conditions
 | 
			
		||||
      for use, reproduction, or distribution of Your modifications, or
 | 
			
		||||
      for any such Derivative Works as a whole, provided Your use,
 | 
			
		||||
      reproduction, and distribution of the Work otherwise complies with
 | 
			
		||||
      the conditions stated in this License.
 | 
			
		||||
 | 
			
		||||
   5. Submission of Contributions. Unless You explicitly state otherwise,
 | 
			
		||||
      any Contribution intentionally submitted for inclusion in the Work
 | 
			
		||||
      by You to the Licensor shall be under the terms and conditions of
 | 
			
		||||
      this License, without any additional terms or conditions.
 | 
			
		||||
      Notwithstanding the above, nothing herein shall supersede or modify
 | 
			
		||||
      the terms of any separate license agreement you may have executed
 | 
			
		||||
      with Licensor regarding such Contributions.
 | 
			
		||||
 | 
			
		||||
   6. Trademarks. This License does not grant permission to use the trade
 | 
			
		||||
      names, trademarks, service marks, or product names of the Licensor,
 | 
			
		||||
      except as required for reasonable and customary use in describing the
 | 
			
		||||
      origin of the Work and reproducing the content of the NOTICE file.
 | 
			
		||||
 | 
			
		||||
   7. Disclaimer of Warranty. Unless required by applicable law or
 | 
			
		||||
      agreed to in writing, Licensor provides the Work (and each
 | 
			
		||||
      Contributor provides its Contributions) on an "AS IS" BASIS,
 | 
			
		||||
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 | 
			
		||||
      implied, including, without limitation, any warranties or conditions
 | 
			
		||||
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
 | 
			
		||||
      PARTICULAR PURPOSE. You are solely responsible for determining the
 | 
			
		||||
      appropriateness of using or redistributing the Work and assume any
 | 
			
		||||
      risks associated with Your exercise of permissions under this License.
 | 
			
		||||
 | 
			
		||||
   8. Limitation of Liability. In no event and under no legal theory,
 | 
			
		||||
      whether in tort (including negligence), contract, or otherwise,
 | 
			
		||||
      unless required by applicable law (such as deliberate and grossly
 | 
			
		||||
      negligent acts) or agreed to in writing, shall any Contributor be
 | 
			
		||||
      liable to You for damages, including any direct, indirect, special,
 | 
			
		||||
      incidental, or consequential damages of any character arising as a
 | 
			
		||||
      result of this License or out of the use or inability to use the
 | 
			
		||||
      Work (including but not limited to damages for loss of goodwill,
 | 
			
		||||
      work stoppage, computer failure or malfunction, or any and all
 | 
			
		||||
      other commercial damages or losses), even if such Contributor
 | 
			
		||||
      has been advised of the possibility of such damages.
 | 
			
		||||
 | 
			
		||||
   9. Accepting Warranty or Additional Liability. While redistributing
 | 
			
		||||
      the Work or Derivative Works thereof, You may choose to offer,
 | 
			
		||||
      and charge a fee for, acceptance of support, warranty, indemnity,
 | 
			
		||||
      or other liability obligations and/or rights consistent with this
 | 
			
		||||
      License. However, in accepting such obligations, You may act only
 | 
			
		||||
      on Your own behalf and on Your sole responsibility, not on behalf
 | 
			
		||||
      of any other Contributor, and only if You agree to indemnify,
 | 
			
		||||
      defend, and hold each Contributor harmless for any liability
 | 
			
		||||
      incurred by, or claims asserted against, such Contributor by reason
 | 
			
		||||
      of your accepting any such warranty or additional liability.
 | 
			
		||||
 | 
			
		||||
   END OF TERMS AND CONDITIONS
 | 
			
		||||
 | 
			
		||||
   APPENDIX: How to apply the Apache License to your work.
 | 
			
		||||
 | 
			
		||||
      To apply the Apache License to your work, attach the following
 | 
			
		||||
      boilerplate notice, with the fields enclosed by brackets "[]"
 | 
			
		||||
      replaced with your own identifying information. (Don't include
 | 
			
		||||
      the brackets!)  The text should be enclosed in the appropriate
 | 
			
		||||
      comment syntax for the file format. We also recommend that a
 | 
			
		||||
      file or class name and description of purpose be included on the
 | 
			
		||||
      same "printed page" as the copyright notice for easier
 | 
			
		||||
      identification within third-party archives.
 | 
			
		||||
 | 
			
		||||
   Copyright [yyyy] [name of copyright owner]
 | 
			
		||||
 | 
			
		||||
   Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
   you may not use this file except in compliance with the License.
 | 
			
		||||
   You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
       http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
   Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
   distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
   See the License for the specific language governing permissions and
 | 
			
		||||
   limitations under the License.
 | 
			
		||||
							
								
								
									
										30
									
								
								libs/ipstack/src/error.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								libs/ipstack/src/error.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,30 @@
 | 
			
		||||
#[allow(dead_code)]
 | 
			
		||||
#[derive(thiserror::Error, Debug)]
 | 
			
		||||
pub enum IpStackError {
 | 
			
		||||
    #[error("The transport protocol is not supported")]
 | 
			
		||||
    UnsupportedTransportProtocol,
 | 
			
		||||
    #[error("The packet is invalid")]
 | 
			
		||||
    InvalidPacket,
 | 
			
		||||
    #[error("Write error: {0}")]
 | 
			
		||||
    PacketWriteError(etherparse::WriteError),
 | 
			
		||||
    #[error("Invalid Tcp packet")]
 | 
			
		||||
    InvalidTcpPacket,
 | 
			
		||||
    #[error("IO error: {0}")]
 | 
			
		||||
    IoError(#[from] std::io::Error),
 | 
			
		||||
    #[error("Accept Error")]
 | 
			
		||||
    AcceptError,
 | 
			
		||||
 | 
			
		||||
    #[error("Send Error {0}")]
 | 
			
		||||
    SendError(#[from] tokio::sync::mpsc::error::SendError<crate::stream::IpStackStream>),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<IpStackError> for std::io::Error {
 | 
			
		||||
    fn from(e: IpStackError) -> Self {
 | 
			
		||||
        match e {
 | 
			
		||||
            IpStackError::IoError(e) => e,
 | 
			
		||||
            _ => std::io::Error::new(std::io::ErrorKind::Other, e),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub type Result<T, E = IpStackError> = std::result::Result<T, E>;
 | 
			
		||||
							
								
								
									
										194
									
								
								libs/ipstack/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								libs/ipstack/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,194 @@
 | 
			
		||||
pub use error::{IpStackError, Result};
 | 
			
		||||
use packet::{NetworkPacket, NetworkTuple};
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::{
 | 
			
		||||
        hash_map::Entry::{Occupied, Vacant},
 | 
			
		||||
        HashMap,
 | 
			
		||||
    },
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
use stream::IpStackStream;
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
 | 
			
		||||
    select,
 | 
			
		||||
    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
 | 
			
		||||
};
 | 
			
		||||
#[cfg(feature = "log")]
 | 
			
		||||
use tracing::{error, trace};
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    packet::IpStackPacketProtocol,
 | 
			
		||||
    stream::{IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport},
 | 
			
		||||
};
 | 
			
		||||
mod error;
 | 
			
		||||
mod packet;
 | 
			
		||||
pub mod stream;
 | 
			
		||||
 | 
			
		||||
const DROP_TTL: u8 = 0;
 | 
			
		||||
 | 
			
		||||
#[cfg(unix)]
 | 
			
		||||
const TTL: u8 = 64;
 | 
			
		||||
 | 
			
		||||
#[cfg(windows)]
 | 
			
		||||
const TTL: u8 = 128;
 | 
			
		||||
 | 
			
		||||
#[cfg(unix)]
 | 
			
		||||
const TUN_FLAGS: [u8; 2] = [0x00, 0x00];
 | 
			
		||||
 | 
			
		||||
#[cfg(any(target_os = "linux", target_os = "android"))]
 | 
			
		||||
const TUN_PROTO_IP6: [u8; 2] = [0x86, 0xdd];
 | 
			
		||||
#[cfg(any(target_os = "linux", target_os = "android"))]
 | 
			
		||||
const TUN_PROTO_IP4: [u8; 2] = [0x08, 0x00];
 | 
			
		||||
 | 
			
		||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
 | 
			
		||||
const TUN_PROTO_IP6: [u8; 2] = [0x00, 0x0A];
 | 
			
		||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
 | 
			
		||||
const TUN_PROTO_IP4: [u8; 2] = [0x00, 0x02];
 | 
			
		||||
 | 
			
		||||
pub struct IpStackConfig {
 | 
			
		||||
    pub mtu: u16,
 | 
			
		||||
    pub packet_information: bool,
 | 
			
		||||
    pub tcp_timeout: Duration,
 | 
			
		||||
    pub udp_timeout: Duration,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for IpStackConfig {
 | 
			
		||||
    fn default() -> Self {
 | 
			
		||||
        IpStackConfig {
 | 
			
		||||
            mtu: u16::MAX,
 | 
			
		||||
            packet_information: false,
 | 
			
		||||
            tcp_timeout: Duration::from_secs(60),
 | 
			
		||||
            udp_timeout: Duration::from_secs(30),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IpStackConfig {
 | 
			
		||||
    pub fn tcp_timeout(&mut self, timeout: Duration) {
 | 
			
		||||
        self.tcp_timeout = timeout;
 | 
			
		||||
    }
 | 
			
		||||
    pub fn udp_timeout(&mut self, timeout: Duration) {
 | 
			
		||||
        self.udp_timeout = timeout;
 | 
			
		||||
    }
 | 
			
		||||
    pub fn mtu(&mut self, mtu: u16) {
 | 
			
		||||
        self.mtu = mtu;
 | 
			
		||||
    }
 | 
			
		||||
    pub fn packet_information(&mut self, packet_information: bool) {
 | 
			
		||||
        self.packet_information = packet_information;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct IpStack {
 | 
			
		||||
    accept_receiver: UnboundedReceiver<IpStackStream>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IpStack {
 | 
			
		||||
    pub fn new<D>(config: IpStackConfig, mut device: D) -> IpStack
 | 
			
		||||
    where
 | 
			
		||||
        D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
 | 
			
		||||
    {
 | 
			
		||||
        let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();
 | 
			
		||||
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            let mut streams: HashMap<NetworkTuple, UnboundedSender<NetworkPacket>> = HashMap::new();
 | 
			
		||||
            let mut buffer = [0u8; u16::MAX as usize];
 | 
			
		||||
 | 
			
		||||
            let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
 | 
			
		||||
            loop {
 | 
			
		||||
                // dbg!(streams.len());
 | 
			
		||||
                select! {
 | 
			
		||||
                    Ok(n) = device.read(&mut buffer) => {
 | 
			
		||||
                        let offset = if config.packet_information && cfg!(unix) {4} else {0};
 | 
			
		||||
                        let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else {
 | 
			
		||||
                            accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?;
 | 
			
		||||
                            continue;
 | 
			
		||||
                        };
 | 
			
		||||
                        if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
 | 
			
		||||
                            accept_sender.send(IpStackStream::UnknownTransport(IpStackUnknownTransport::new(packet.src_addr().ip(),packet.dst_addr().ip(),packet.payload,&packet.ip,config.mtu,pkt_sender.clone())))?;
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        match streams.entry(packet.network_tuple()){
 | 
			
		||||
                            Occupied(entry) =>{
 | 
			
		||||
                                // let t = packet.transport_protocol();
 | 
			
		||||
                                if let Err(_x) = entry.get().send(packet){
 | 
			
		||||
                                    #[cfg(feature = "log")]
 | 
			
		||||
                                    trace!("{}", _x);
 | 
			
		||||
                                    // match t{
 | 
			
		||||
                                    //     IpStackPacketProtocol::Tcp(_t) => {
 | 
			
		||||
                                    //         // dbg!(t.flags());
 | 
			
		||||
                                    //     }
 | 
			
		||||
                                    //     IpStackPacketProtocol::Udp => {
 | 
			
		||||
                                    //         // dbg!("udp");
 | 
			
		||||
                                    //     }
 | 
			
		||||
                                    //     IpStackPacketProtocol::Unknown => {
 | 
			
		||||
                                    //         // dbg!("unknown");
 | 
			
		||||
                                    //     }
 | 
			
		||||
                                    // }
 | 
			
		||||
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                            Vacant(entry) => {
 | 
			
		||||
                                match packet.transport_protocol(){
 | 
			
		||||
                                    IpStackPacketProtocol::Tcp(h) => {
 | 
			
		||||
                                        match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{
 | 
			
		||||
                                            Ok(stream) => {
 | 
			
		||||
                                                entry.insert(stream.stream_sender());
 | 
			
		||||
                                                accept_sender.send(IpStackStream::Tcp(stream))?;
 | 
			
		||||
                                            }
 | 
			
		||||
                                            Err(_e) => {
 | 
			
		||||
                                                #[cfg(feature = "log")]
 | 
			
		||||
                                                error!("{}", _e);
 | 
			
		||||
                                            }
 | 
			
		||||
                                        }
 | 
			
		||||
                                    }
 | 
			
		||||
                                    IpStackPacketProtocol::Udp => {
 | 
			
		||||
                                        let stream = IpStackUdpStream::new(packet.src_addr(),packet.dst_addr(),packet.payload, pkt_sender.clone(),config.mtu,config.udp_timeout);
 | 
			
		||||
                                        entry.insert(stream.stream_sender());
 | 
			
		||||
                                        accept_sender.send(IpStackStream::Udp(stream))?;
 | 
			
		||||
                                    }
 | 
			
		||||
                                    IpStackPacketProtocol::Unknown => {
 | 
			
		||||
                                        unreachable!()
 | 
			
		||||
                                    }
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    Some(packet) = pkt_receiver.recv() => {
 | 
			
		||||
                        if packet.ttl() == 0{
 | 
			
		||||
                            streams.remove(&packet.reverse_network_tuple());
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                        #[allow(unused_mut)]
 | 
			
		||||
                        let Ok(mut packet_byte) = packet.to_bytes() else{
 | 
			
		||||
                            #[cfg(feature = "log")]
 | 
			
		||||
                            trace!("to_bytes error");
 | 
			
		||||
                            continue;
 | 
			
		||||
                        };
 | 
			
		||||
                        #[cfg(unix)]
 | 
			
		||||
                        if config.packet_information {
 | 
			
		||||
                            if packet.src_addr().is_ipv4(){
 | 
			
		||||
                                packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat());
 | 
			
		||||
                            } else{
 | 
			
		||||
                                packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat());
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        device.write_all(&packet_byte).await?;
 | 
			
		||||
                        // device.flush().await.unwrap();
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            #[allow(unreachable_code)]
 | 
			
		||||
            Ok::<(), IpStackError>(())
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        IpStack { accept_receiver }
 | 
			
		||||
    }
 | 
			
		||||
    pub async fn accept(&mut self) -> Result<IpStackStream, IpStackError> {
 | 
			
		||||
        if let Some(s) = self.accept_receiver.recv().await {
 | 
			
		||||
            Ok(s)
 | 
			
		||||
        } else {
 | 
			
		||||
            Err(IpStackError::AcceptError)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										217
									
								
								libs/ipstack/src/packet.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								libs/ipstack/src/packet.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,217 @@
 | 
			
		||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
 | 
			
		||||
 | 
			
		||||
use etherparse::{Ethernet2Header, IpHeader, PacketHeaders, TcpHeader, UdpHeader, WriteError};
 | 
			
		||||
use tracing::debug;
 | 
			
		||||
 | 
			
		||||
use crate::error::IpStackError;
 | 
			
		||||
 | 
			
		||||
#[derive(Eq, Hash, PartialEq, Debug)]
 | 
			
		||||
pub struct NetworkTuple {
 | 
			
		||||
    pub src: SocketAddr,
 | 
			
		||||
    pub dst: SocketAddr,
 | 
			
		||||
    pub tcp: bool,
 | 
			
		||||
}
 | 
			
		||||
pub mod tcp_flags {
 | 
			
		||||
    pub const CWR: u8 = 0b10000000;
 | 
			
		||||
    pub const ECE: u8 = 0b01000000;
 | 
			
		||||
    pub const URG: u8 = 0b00100000;
 | 
			
		||||
    pub const ACK: u8 = 0b00010000;
 | 
			
		||||
    pub const PSH: u8 = 0b00001000;
 | 
			
		||||
    pub const RST: u8 = 0b00000100;
 | 
			
		||||
    pub const SYN: u8 = 0b00000010;
 | 
			
		||||
    pub const FIN: u8 = 0b00000001;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub(crate) enum IpStackPacketProtocol {
 | 
			
		||||
    Tcp(TcpPacket),
 | 
			
		||||
    Unknown,
 | 
			
		||||
    Udp,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub(crate) enum TransportHeader {
 | 
			
		||||
    Tcp(TcpHeader),
 | 
			
		||||
    Udp(UdpHeader),
 | 
			
		||||
    Unknown,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct NetworkPacket {
 | 
			
		||||
    pub(crate) ip: IpHeader,
 | 
			
		||||
    pub(crate) transport: TransportHeader,
 | 
			
		||||
    pub(crate) payload: Vec<u8>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkPacket {
 | 
			
		||||
    pub fn parse(buf: &[u8]) -> Result<Self, IpStackError> {
 | 
			
		||||
        debug!("read: {:?}", buf);
 | 
			
		||||
        let p = PacketHeaders::from_ethernet_slice(buf).map_err(|_| IpStackError::InvalidPacket)?;
 | 
			
		||||
        let ip = p.ip.ok_or(IpStackError::InvalidPacket)?;
 | 
			
		||||
        let transport = match p.transport {
 | 
			
		||||
            Some(etherparse::TransportHeader::Tcp(h)) => TransportHeader::Tcp(h),
 | 
			
		||||
            Some(etherparse::TransportHeader::Udp(u)) => TransportHeader::Udp(u),
 | 
			
		||||
            _ => TransportHeader::Unknown,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let payload = if let TransportHeader::Unknown = transport {
 | 
			
		||||
            buf[ip.header_len()..].to_vec()
 | 
			
		||||
        } else {
 | 
			
		||||
            p.payload.to_vec()
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        Ok(NetworkPacket {
 | 
			
		||||
            ip,
 | 
			
		||||
            transport,
 | 
			
		||||
            payload,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
    pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol {
 | 
			
		||||
        match self.transport {
 | 
			
		||||
            TransportHeader::Udp(_) => IpStackPacketProtocol::Udp,
 | 
			
		||||
            TransportHeader::Tcp(ref h) => IpStackPacketProtocol::Tcp(h.into()),
 | 
			
		||||
            _ => IpStackPacketProtocol::Unknown,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn src_addr(&self) -> SocketAddr {
 | 
			
		||||
        let port = match &self.transport {
 | 
			
		||||
            TransportHeader::Udp(udp) => udp.source_port,
 | 
			
		||||
            TransportHeader::Tcp(tcp) => tcp.source_port,
 | 
			
		||||
            _ => 0,
 | 
			
		||||
        };
 | 
			
		||||
        match &self.ip {
 | 
			
		||||
            IpHeader::Version4(ip, _) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.source)), port)
 | 
			
		||||
            }
 | 
			
		||||
            IpHeader::Version6(ip, _) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.source)), port)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn dst_addr(&self) -> SocketAddr {
 | 
			
		||||
        let port = match &self.transport {
 | 
			
		||||
            TransportHeader::Udp(udp) => udp.destination_port,
 | 
			
		||||
            TransportHeader::Tcp(tcp) => tcp.destination_port,
 | 
			
		||||
            _ => 0,
 | 
			
		||||
        };
 | 
			
		||||
        match &self.ip {
 | 
			
		||||
            IpHeader::Version4(ip, _) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip.destination)), port)
 | 
			
		||||
            }
 | 
			
		||||
            IpHeader::Version6(ip, _) => {
 | 
			
		||||
                SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip.destination)), port)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn network_tuple(&self) -> NetworkTuple {
 | 
			
		||||
        NetworkTuple {
 | 
			
		||||
            src: self.src_addr(),
 | 
			
		||||
            dst: self.dst_addr(),
 | 
			
		||||
            tcp: matches!(self.transport, TransportHeader::Tcp(_)),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn reverse_network_tuple(&self) -> NetworkTuple {
 | 
			
		||||
        NetworkTuple {
 | 
			
		||||
            src: self.dst_addr(),
 | 
			
		||||
            dst: self.src_addr(),
 | 
			
		||||
            tcp: matches!(self.transport, TransportHeader::Tcp(_)),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn to_bytes(&self) -> Result<Vec<u8>, IpStackError> {
 | 
			
		||||
        let mut buf = Vec::new();
 | 
			
		||||
        let header = Ethernet2Header {
 | 
			
		||||
            source: [255; 6],
 | 
			
		||||
            destination: [255; 6],
 | 
			
		||||
            ether_type: 0x0800,
 | 
			
		||||
        };
 | 
			
		||||
        header.write(&mut buf).map_err(IpStackError::IoError)?;
 | 
			
		||||
        self.ip
 | 
			
		||||
            .write(&mut buf)
 | 
			
		||||
            .map_err(IpStackError::PacketWriteError)?;
 | 
			
		||||
        match self.transport {
 | 
			
		||||
            TransportHeader::Tcp(ref h) => h
 | 
			
		||||
                .write(&mut buf)
 | 
			
		||||
                .map_err(WriteError::from)
 | 
			
		||||
                .map_err(IpStackError::PacketWriteError)?,
 | 
			
		||||
            TransportHeader::Udp(ref h) => {
 | 
			
		||||
                h.write(&mut buf).map_err(IpStackError::PacketWriteError)?
 | 
			
		||||
            }
 | 
			
		||||
            _ => {}
 | 
			
		||||
        };
 | 
			
		||||
        // self.transport
 | 
			
		||||
        //     .write(&mut buf)
 | 
			
		||||
        //     .map_err(IpStackError::PacketWriteError)?;
 | 
			
		||||
        buf.extend_from_slice(&self.payload);
 | 
			
		||||
        debug!("write: {:?}", buf);
 | 
			
		||||
        Ok(buf)
 | 
			
		||||
    }
 | 
			
		||||
    pub fn ttl(&self) -> u8 {
 | 
			
		||||
        match &self.ip {
 | 
			
		||||
            IpHeader::Version4(ip, _) => ip.time_to_live,
 | 
			
		||||
            IpHeader::Version6(ip, _) => ip.hop_limit,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub(super) struct TcpPacket {
 | 
			
		||||
    header: TcpHeader,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TcpPacket {
 | 
			
		||||
    pub fn inner(&self) -> &TcpHeader {
 | 
			
		||||
        &self.header
 | 
			
		||||
    }
 | 
			
		||||
    pub fn flags(&self) -> u8 {
 | 
			
		||||
        let inner = self.inner();
 | 
			
		||||
        let mut flags = 0;
 | 
			
		||||
        if inner.cwr {
 | 
			
		||||
            flags |= tcp_flags::CWR;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.ece {
 | 
			
		||||
            flags |= tcp_flags::ECE;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.urg {
 | 
			
		||||
            flags |= tcp_flags::URG;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.ack {
 | 
			
		||||
            flags |= tcp_flags::ACK;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.psh {
 | 
			
		||||
            flags |= tcp_flags::PSH;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.rst {
 | 
			
		||||
            flags |= tcp_flags::RST;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.syn {
 | 
			
		||||
            flags |= tcp_flags::SYN;
 | 
			
		||||
        }
 | 
			
		||||
        if inner.fin {
 | 
			
		||||
            flags |= tcp_flags::FIN;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        flags
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<&TcpHeader> for TcpPacket {
 | 
			
		||||
    fn from(header: &TcpHeader) -> Self {
 | 
			
		||||
        TcpPacket {
 | 
			
		||||
            header: header.clone(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// pub struct UdpPacket {
 | 
			
		||||
//     header: UdpHeader,
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
// impl UdpPacket {
 | 
			
		||||
//     pub fn inner(&self) -> &UdpHeader {
 | 
			
		||||
//         &self.header
 | 
			
		||||
//     }
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
// impl From<&UdpHeader> for UdpPacket {
 | 
			
		||||
//     fn from(header: &UdpHeader) -> Self {
 | 
			
		||||
//         UdpPacket {
 | 
			
		||||
//             header: header.clone(),
 | 
			
		||||
//         }
 | 
			
		||||
//     }
 | 
			
		||||
// }
 | 
			
		||||
							
								
								
									
										46
									
								
								libs/ipstack/src/stream/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								libs/ipstack/src/stream/mod.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,46 @@
 | 
			
		||||
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
 | 
			
		||||
 | 
			
		||||
pub use self::tcp::IpStackTcpStream;
 | 
			
		||||
pub use self::udp::IpStackUdpStream;
 | 
			
		||||
pub use self::unknown::IpStackUnknownTransport;
 | 
			
		||||
 | 
			
		||||
mod tcb;
 | 
			
		||||
mod tcp;
 | 
			
		||||
mod udp;
 | 
			
		||||
mod unknown;
 | 
			
		||||
 | 
			
		||||
pub enum IpStackStream {
 | 
			
		||||
    Tcp(IpStackTcpStream),
 | 
			
		||||
    Udp(IpStackUdpStream),
 | 
			
		||||
    UnknownTransport(IpStackUnknownTransport),
 | 
			
		||||
    UnknownNetwork(Vec<u8>),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IpStackStream {
 | 
			
		||||
    pub fn local_addr(&self) -> SocketAddr {
 | 
			
		||||
        match self {
 | 
			
		||||
            IpStackStream::Tcp(tcp) => tcp.local_addr(),
 | 
			
		||||
            IpStackStream::Udp(udp) => udp.local_addr(),
 | 
			
		||||
            IpStackStream::UnknownNetwork(_) => {
 | 
			
		||||
                SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0))
 | 
			
		||||
            }
 | 
			
		||||
            IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() {
 | 
			
		||||
                std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
 | 
			
		||||
                std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn peer_addr(&self) -> SocketAddr {
 | 
			
		||||
        match self {
 | 
			
		||||
            IpStackStream::Tcp(tcp) => tcp.peer_addr(),
 | 
			
		||||
            IpStackStream::Udp(udp) => udp.peer_addr(),
 | 
			
		||||
            IpStackStream::UnknownNetwork(_) => {
 | 
			
		||||
                SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0))
 | 
			
		||||
            }
 | 
			
		||||
            IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() {
 | 
			
		||||
                std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
 | 
			
		||||
                std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										234
									
								
								libs/ipstack/src/stream/tcb.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								libs/ipstack/src/stream/tcb.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,234 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::BTreeMap,
 | 
			
		||||
    pin::Pin,
 | 
			
		||||
    time::{Duration, SystemTime},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use tokio::time::Sleep;
 | 
			
		||||
 | 
			
		||||
use crate::packet::TcpPacket;
 | 
			
		||||
 | 
			
		||||
const MAX_UNACK: u32 = 1024 * 16; // 16KB
 | 
			
		||||
const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Debug)]
 | 
			
		||||
pub enum TcpState {
 | 
			
		||||
    SynReceived(bool), // bool means if syn/ack is sent
 | 
			
		||||
    Established,
 | 
			
		||||
    FinWait1,
 | 
			
		||||
    FinWait2(bool), // bool means waiting for ack
 | 
			
		||||
    Closed,
 | 
			
		||||
}
 | 
			
		||||
#[derive(Clone, Debug)]
 | 
			
		||||
pub(super) enum PacketStatus {
 | 
			
		||||
    WindowUpdate,
 | 
			
		||||
    Invalid,
 | 
			
		||||
    RetransmissionRequest,
 | 
			
		||||
    NewPacket,
 | 
			
		||||
    Ack,
 | 
			
		||||
    KeepAlive,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub(super) struct Tcb {
 | 
			
		||||
    pub(super) seq: u32,
 | 
			
		||||
    pub(super) retransmission: Option<u32>,
 | 
			
		||||
    pub(super) ack: u32,
 | 
			
		||||
    pub(super) last_ack: u32,
 | 
			
		||||
    pub(super) timeout: Pin<Box<Sleep>>,
 | 
			
		||||
    tcp_timeout: Duration,
 | 
			
		||||
    recv_window: u16,
 | 
			
		||||
    pub(super) send_window: u16,
 | 
			
		||||
    state: TcpState,
 | 
			
		||||
    pub(super) avg_send_window: (u64, u64),
 | 
			
		||||
    pub(super) inflight_packets: Vec<InflightPacket>,
 | 
			
		||||
    pub(super) unordered_packets: BTreeMap<u32, UnorderedPacket>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Tcb {
 | 
			
		||||
    pub(super) fn new(ack: u32, tcp_timeout: Duration) -> Tcb {
 | 
			
		||||
        let seq = 100;
 | 
			
		||||
        Tcb {
 | 
			
		||||
            seq,
 | 
			
		||||
            retransmission: None,
 | 
			
		||||
            ack,
 | 
			
		||||
            last_ack: seq,
 | 
			
		||||
            tcp_timeout,
 | 
			
		||||
            timeout: Box::pin(tokio::time::sleep_until(
 | 
			
		||||
                tokio::time::Instant::now() + tcp_timeout,
 | 
			
		||||
            )),
 | 
			
		||||
            send_window: u16::MAX,
 | 
			
		||||
            recv_window: 0,
 | 
			
		||||
            state: TcpState::SynReceived(false),
 | 
			
		||||
            avg_send_window: (1, 1),
 | 
			
		||||
            inflight_packets: Vec::new(),
 | 
			
		||||
            unordered_packets: BTreeMap::new(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn add_inflight_packet(&mut self, seq: u32, buf: &[u8]) {
 | 
			
		||||
        self.inflight_packets
 | 
			
		||||
            .push(InflightPacket::new(seq, buf.to_vec()));
 | 
			
		||||
        self.seq = self.seq.wrapping_add(buf.len() as u32);
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn add_unordered_packet(&mut self, seq: u32, buf: &[u8]) {
 | 
			
		||||
        if seq < self.ack {
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
        self.unordered_packets
 | 
			
		||||
            .insert(seq, UnorderedPacket::new(buf.to_vec()));
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_available_read_buffer_size(&self) -> usize {
 | 
			
		||||
        READ_BUFFER_SIZE.saturating_sub(
 | 
			
		||||
            self.unordered_packets
 | 
			
		||||
                .iter()
 | 
			
		||||
                .fold(0, |acc, (_, p)| acc + p.payload.len()),
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_unordered_packets(&mut self) -> Option<Vec<u8>> {
 | 
			
		||||
        // dbg!(self.ack);
 | 
			
		||||
        // for (seq,_) in self.unordered_packets.iter() {
 | 
			
		||||
        //     dbg!(seq);
 | 
			
		||||
        // }
 | 
			
		||||
        self.unordered_packets
 | 
			
		||||
            .remove(&self.ack)
 | 
			
		||||
            .map(|p| p.payload.clone())
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn add_seq_one(&mut self) {
 | 
			
		||||
        self.seq = self.seq.wrapping_add(1);
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_seq(&self) -> u32 {
 | 
			
		||||
        self.seq
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn add_ack(&mut self, add: u32) {
 | 
			
		||||
        self.ack = self.ack.wrapping_add(add);
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_ack(&self) -> u32 {
 | 
			
		||||
        self.ack
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn change_state(&mut self, state: TcpState) {
 | 
			
		||||
        self.state = state;
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_state(&self) -> &TcpState {
 | 
			
		||||
        &self.state
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn change_send_window(&mut self, window: u16) {
 | 
			
		||||
        let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64)
 | 
			
		||||
            / (self.avg_send_window.1 + 1);
 | 
			
		||||
        self.avg_send_window.0 = avg_send_window;
 | 
			
		||||
        self.avg_send_window.1 += 1;
 | 
			
		||||
        self.send_window = window;
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_send_window(&self) -> u16 {
 | 
			
		||||
        self.send_window
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn change_recv_window(&mut self, window: u16) {
 | 
			
		||||
        self.recv_window = window;
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn get_recv_window(&self) -> u16 {
 | 
			
		||||
        self.recv_window
 | 
			
		||||
    }
 | 
			
		||||
    // #[inline(always)]
 | 
			
		||||
    // pub(super) fn buffer_size(&self, payload_len: u16) -> u16 {
 | 
			
		||||
    //     match MAX_UNACK - self.inflight_packets.len() as u32 {
 | 
			
		||||
    //         // b if b.saturating_sub(payload_len as u32 + 64) != 0 => payload_len,
 | 
			
		||||
    //         // b if b < 128 && b >= 4 => (b / 2) as u16,
 | 
			
		||||
    //         // b if b < 4 => b as u16,
 | 
			
		||||
    //         // b => (b - 64) as u16,
 | 
			
		||||
    //         b if b >= payload_len as u32 * 2 && b > 0 => payload_len,
 | 
			
		||||
    //         b if b < 4 => b as u16,
 | 
			
		||||
    //         b => (b / 2) as u16,
 | 
			
		||||
    //     }
 | 
			
		||||
    // }
 | 
			
		||||
 | 
			
		||||
    pub(super) fn check_pkt_type(&self, incoming_packet: &TcpPacket, p: &[u8]) -> PacketStatus {
 | 
			
		||||
        let received_ack_distance = self
 | 
			
		||||
            .seq
 | 
			
		||||
            .wrapping_sub(incoming_packet.inner().acknowledgment_number);
 | 
			
		||||
 | 
			
		||||
        let current_ack_distance = self.seq.wrapping_sub(self.last_ack);
 | 
			
		||||
        if received_ack_distance > current_ack_distance
 | 
			
		||||
            || (incoming_packet.inner().acknowledgment_number != self.seq
 | 
			
		||||
                && self
 | 
			
		||||
                    .seq
 | 
			
		||||
                    .saturating_sub(incoming_packet.inner().acknowledgment_number)
 | 
			
		||||
                    == 0)
 | 
			
		||||
        {
 | 
			
		||||
            PacketStatus::Invalid
 | 
			
		||||
        } else if self.last_ack == incoming_packet.inner().acknowledgment_number {
 | 
			
		||||
            if !p.is_empty() {
 | 
			
		||||
                PacketStatus::NewPacket
 | 
			
		||||
            } else if self.send_window == incoming_packet.inner().window_size
 | 
			
		||||
                && self.seq != self.last_ack
 | 
			
		||||
            {
 | 
			
		||||
                PacketStatus::RetransmissionRequest
 | 
			
		||||
            } else if self.ack.wrapping_sub(1) == incoming_packet.inner().sequence_number {
 | 
			
		||||
                PacketStatus::KeepAlive
 | 
			
		||||
            } else {
 | 
			
		||||
                PacketStatus::WindowUpdate
 | 
			
		||||
            }
 | 
			
		||||
        } else if self.last_ack < incoming_packet.inner().acknowledgment_number {
 | 
			
		||||
            if !p.is_empty() {
 | 
			
		||||
                PacketStatus::NewPacket
 | 
			
		||||
            } else {
 | 
			
		||||
                PacketStatus::Ack
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            PacketStatus::Invalid
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub(super) fn change_last_ack(&mut self, ack: u32) {
 | 
			
		||||
        self.timeout
 | 
			
		||||
            .as_mut()
 | 
			
		||||
            .reset(tokio::time::Instant::now() + self.tcp_timeout);
 | 
			
		||||
        let distance = ack.wrapping_sub(self.last_ack);
 | 
			
		||||
 | 
			
		||||
        if matches!(self.state, TcpState::Established) {
 | 
			
		||||
            if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) {
 | 
			
		||||
                let mut inflight_packet = self.inflight_packets.remove(i);
 | 
			
		||||
                let distance = ack.wrapping_sub(inflight_packet.seq);
 | 
			
		||||
                if (distance as usize) < inflight_packet.payload.len() {
 | 
			
		||||
                    inflight_packet.payload.drain(0..distance as usize);
 | 
			
		||||
                    inflight_packet.seq = ack;
 | 
			
		||||
                    self.inflight_packets.push(inflight_packet);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.last_ack = self.last_ack.wrapping_add(distance);
 | 
			
		||||
    }
 | 
			
		||||
    pub fn is_send_buffer_full(&self) -> bool {
 | 
			
		||||
        self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct InflightPacket {
 | 
			
		||||
    pub seq: u32,
 | 
			
		||||
    pub payload: Vec<u8>,
 | 
			
		||||
    pub send_time: SystemTime,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl InflightPacket {
 | 
			
		||||
    fn new(seq: u32, payload: Vec<u8>) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            seq,
 | 
			
		||||
            payload,
 | 
			
		||||
            send_time: SystemTime::now(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub(crate) fn contains(&self, seq: u32) -> bool {
 | 
			
		||||
        self.seq < seq && self.seq + self.payload.len() as u32 >= seq
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct UnorderedPacket {
 | 
			
		||||
    pub payload: Vec<u8>,
 | 
			
		||||
    pub recv_time: SystemTime,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl UnorderedPacket {
 | 
			
		||||
    pub(crate) fn new(payload: Vec<u8>) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            payload,
 | 
			
		||||
            recv_time: SystemTime::now(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										509
									
								
								libs/ipstack/src/stream/tcp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										509
									
								
								libs/ipstack/src/stream/tcp.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,509 @@
 | 
			
		||||
use crate::{
 | 
			
		||||
    error::IpStackError,
 | 
			
		||||
    packet::{tcp_flags, IpStackPacketProtocol, TcpPacket, TransportHeader},
 | 
			
		||||
    stream::tcb::{Tcb, TcpState},
 | 
			
		||||
    DROP_TTL, TTL,
 | 
			
		||||
};
 | 
			
		||||
use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions};
 | 
			
		||||
use std::{
 | 
			
		||||
    cmp,
 | 
			
		||||
    future::Future,
 | 
			
		||||
    io::{Error, ErrorKind},
 | 
			
		||||
    net::SocketAddr,
 | 
			
		||||
    pin::Pin,
 | 
			
		||||
    task::Waker,
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncRead, AsyncWrite},
 | 
			
		||||
    sync::{
 | 
			
		||||
        mpsc::{self, UnboundedReceiver, UnboundedSender},
 | 
			
		||||
        Notify,
 | 
			
		||||
    },
 | 
			
		||||
};
 | 
			
		||||
#[cfg(feature = "log")]
 | 
			
		||||
use tracing::{trace, warn};
 | 
			
		||||
 | 
			
		||||
use crate::packet::NetworkPacket;
 | 
			
		||||
 | 
			
		||||
use super::tcb::PacketStatus;
 | 
			
		||||
 | 
			
		||||
pub struct IpStackTcpStream {
 | 
			
		||||
    src_addr: SocketAddr,
 | 
			
		||||
    dst_addr: SocketAddr,
 | 
			
		||||
    stream_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
    stream_receiver: UnboundedReceiver<NetworkPacket>,
 | 
			
		||||
    packet_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
    packet_to_send: Option<NetworkPacket>,
 | 
			
		||||
    tcb: Tcb,
 | 
			
		||||
    mtu: u16,
 | 
			
		||||
    shutdown: Option<Notify>,
 | 
			
		||||
    write_notify: Option<Waker>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IpStackTcpStream {
 | 
			
		||||
    pub(crate) async fn new(
 | 
			
		||||
        src_addr: SocketAddr,
 | 
			
		||||
        dst_addr: SocketAddr,
 | 
			
		||||
        tcp: TcpPacket,
 | 
			
		||||
        pkt_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
        mtu: u16,
 | 
			
		||||
        tcp_timeout: Duration,
 | 
			
		||||
    ) -> Result<IpStackTcpStream, IpStackError> {
 | 
			
		||||
        let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
 | 
			
		||||
 | 
			
		||||
        let mut stream = IpStackTcpStream {
 | 
			
		||||
            src_addr,
 | 
			
		||||
            dst_addr,
 | 
			
		||||
            stream_sender,
 | 
			
		||||
            stream_receiver,
 | 
			
		||||
            packet_sender: pkt_sender.clone(),
 | 
			
		||||
            packet_to_send: None,
 | 
			
		||||
            tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout),
 | 
			
		||||
            mtu,
 | 
			
		||||
            shutdown: None,
 | 
			
		||||
            write_notify: None,
 | 
			
		||||
        };
 | 
			
		||||
        if !tcp.inner().syn {
 | 
			
		||||
            pkt_sender
 | 
			
		||||
                .send(stream.create_rev_packet(
 | 
			
		||||
                    tcp_flags::RST | tcp_flags::ACK,
 | 
			
		||||
                    TTL,
 | 
			
		||||
                    None,
 | 
			
		||||
                    Vec::new(),
 | 
			
		||||
                )?)
 | 
			
		||||
                .map_err(|_| IpStackError::InvalidTcpPacket)?;
 | 
			
		||||
            stream.tcb.change_state(TcpState::Closed);
 | 
			
		||||
        }
 | 
			
		||||
        Ok(stream)
 | 
			
		||||
    }
 | 
			
		||||
    pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
 | 
			
		||||
        self.stream_sender.clone()
 | 
			
		||||
    }
 | 
			
		||||
    fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 {
 | 
			
		||||
        cmp::min(
 | 
			
		||||
            self.tcb.get_send_window(),
 | 
			
		||||
            self.mtu.saturating_sub(ip_header_size + tcp_header_size),
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
    fn create_rev_packet(
 | 
			
		||||
        &self,
 | 
			
		||||
        flags: u8,
 | 
			
		||||
        ttl: u8,
 | 
			
		||||
        seq: Option<u32>,
 | 
			
		||||
        mut payload: Vec<u8>,
 | 
			
		||||
    ) -> Result<NetworkPacket, Error> {
 | 
			
		||||
        let mut tcp_header = etherparse::TcpHeader::new(
 | 
			
		||||
            self.dst_addr.port(),
 | 
			
		||||
            self.src_addr.port(),
 | 
			
		||||
            seq.unwrap_or(self.tcb.get_seq()),
 | 
			
		||||
            self.tcb.get_recv_window(),
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        tcp_header.acknowledgment_number = self.tcb.get_ack();
 | 
			
		||||
        if flags & tcp_flags::SYN != 0 {
 | 
			
		||||
            tcp_header.syn = true;
 | 
			
		||||
        }
 | 
			
		||||
        if flags & tcp_flags::ACK != 0 {
 | 
			
		||||
            tcp_header.ack = true;
 | 
			
		||||
        }
 | 
			
		||||
        if flags & tcp_flags::RST != 0 {
 | 
			
		||||
            tcp_header.rst = true;
 | 
			
		||||
        }
 | 
			
		||||
        if flags & tcp_flags::FIN != 0 {
 | 
			
		||||
            tcp_header.fin = true;
 | 
			
		||||
        }
 | 
			
		||||
        if flags & tcp_flags::PSH != 0 {
 | 
			
		||||
            tcp_header.psh = true;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) {
 | 
			
		||||
            (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
 | 
			
		||||
                let mut ip_h = Ipv4Header::new(0, ttl, 6, dst.octets(), src.octets());
 | 
			
		||||
                let payload_len =
 | 
			
		||||
                    self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len());
 | 
			
		||||
                payload.truncate(payload_len as usize);
 | 
			
		||||
                ip_h.payload_len = payload.len() as u16 + tcp_header.header_len();
 | 
			
		||||
                ip_h.dont_fragment = true;
 | 
			
		||||
                etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default())
 | 
			
		||||
            }
 | 
			
		||||
            (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
 | 
			
		||||
                let mut ip_h = etherparse::Ipv6Header {
 | 
			
		||||
                    traffic_class: 0,
 | 
			
		||||
                    flow_label: 0,
 | 
			
		||||
                    payload_length: 0,
 | 
			
		||||
                    next_header: 6,
 | 
			
		||||
                    hop_limit: ttl,
 | 
			
		||||
                    source: dst.octets(),
 | 
			
		||||
                    destination: src.octets(),
 | 
			
		||||
                };
 | 
			
		||||
                let payload_len =
 | 
			
		||||
                    self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len());
 | 
			
		||||
                payload.truncate(payload_len as usize);
 | 
			
		||||
                ip_h.payload_length = payload.len() as u16 + tcp_header.header_len();
 | 
			
		||||
 | 
			
		||||
                etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default())
 | 
			
		||||
            }
 | 
			
		||||
            _ => unreachable!(),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        match ip_header {
 | 
			
		||||
            etherparse::IpHeader::Version4(ref ip_header, _) => {
 | 
			
		||||
                tcp_header.checksum = tcp_header
 | 
			
		||||
                    .calc_checksum_ipv4(ip_header, &payload)
 | 
			
		||||
                    .map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
 | 
			
		||||
            }
 | 
			
		||||
            etherparse::IpHeader::Version6(ref ip_header, _) => {
 | 
			
		||||
                tcp_header.checksum = tcp_header
 | 
			
		||||
                    .calc_checksum_ipv6(ip_header, &payload)
 | 
			
		||||
                    .map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(NetworkPacket {
 | 
			
		||||
            ip: ip_header,
 | 
			
		||||
            transport: TransportHeader::Tcp(tcp_header),
 | 
			
		||||
            payload,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
    pub fn local_addr(&self) -> SocketAddr {
 | 
			
		||||
        self.src_addr
 | 
			
		||||
    }
 | 
			
		||||
    pub fn peer_addr(&self) -> SocketAddr {
 | 
			
		||||
        self.dst_addr
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncRead for IpStackTcpStream {
 | 
			
		||||
    fn poll_read(
 | 
			
		||||
        mut self: std::pin::Pin<&mut Self>,
 | 
			
		||||
        cx: &mut std::task::Context<'_>,
 | 
			
		||||
        buf: &mut tokio::io::ReadBuf<'_>,
 | 
			
		||||
    ) -> std::task::Poll<std::io::Result<()>> {
 | 
			
		||||
        loop {
 | 
			
		||||
            if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) {
 | 
			
		||||
                self.packet_to_send =
 | 
			
		||||
                    Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
 | 
			
		||||
                self.tcb.change_state(TcpState::Closed);
 | 
			
		||||
                return std::task::Poll::Ready(Ok(()));
 | 
			
		||||
            }
 | 
			
		||||
            let min = cmp::min(self.tcb.get_available_read_buffer_size() as u16, u16::MAX);
 | 
			
		||||
            self.tcb.change_recv_window(min);
 | 
			
		||||
            if matches!(
 | 
			
		||||
                Pin::new(&mut self.tcb.timeout).poll(cx),
 | 
			
		||||
                std::task::Poll::Ready(_)
 | 
			
		||||
            ) {
 | 
			
		||||
                #[cfg(feature = "log")]
 | 
			
		||||
                trace!("timeout reached for {:?}", self.dst_addr);
 | 
			
		||||
                self.packet_sender
 | 
			
		||||
                    .send(self.create_rev_packet(
 | 
			
		||||
                        tcp_flags::RST | tcp_flags::ACK,
 | 
			
		||||
                        TTL,
 | 
			
		||||
                        None,
 | 
			
		||||
                        Vec::new(),
 | 
			
		||||
                    )?)
 | 
			
		||||
                    .map_err(|_| ErrorKind::UnexpectedEof)?;
 | 
			
		||||
                return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut)));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) {
 | 
			
		||||
                self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                    tcp_flags::SYN | tcp_flags::ACK,
 | 
			
		||||
                    TTL,
 | 
			
		||||
                    None,
 | 
			
		||||
                    Vec::new(),
 | 
			
		||||
                )?);
 | 
			
		||||
                self.tcb.add_seq_one();
 | 
			
		||||
                self.tcb.change_state(TcpState::SynReceived(true));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if let Some(packet) = self.packet_to_send.take() {
 | 
			
		||||
                self.packet_sender
 | 
			
		||||
                    .send(packet)
 | 
			
		||||
                    .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
 | 
			
		||||
                if matches!(self.tcb.get_state(), TcpState::Closed) {
 | 
			
		||||
                    if let Some(shutdown) = self.shutdown.take() {
 | 
			
		||||
                        shutdown.notify_one();
 | 
			
		||||
                    }
 | 
			
		||||
                    return std::task::Poll::Ready(Ok(()));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            if let Some(b) = self.tcb.get_unordered_packets() {
 | 
			
		||||
                self.tcb.add_ack(b.len() as u32);
 | 
			
		||||
                buf.put_slice(&b);
 | 
			
		||||
                self.packet_sender
 | 
			
		||||
                    .send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?)
 | 
			
		||||
                    .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
 | 
			
		||||
                return std::task::Poll::Ready(Ok(()));
 | 
			
		||||
            }
 | 
			
		||||
            if self.shutdown.is_some() && matches!(self.tcb.get_state(), TcpState::Established) {
 | 
			
		||||
                self.tcb.change_state(TcpState::FinWait1);
 | 
			
		||||
                self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                    tcp_flags::FIN | tcp_flags::ACK,
 | 
			
		||||
                    TTL,
 | 
			
		||||
                    None,
 | 
			
		||||
                    Vec::new(),
 | 
			
		||||
                )?);
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
            match self.stream_receiver.poll_recv(cx) {
 | 
			
		||||
                std::task::Poll::Ready(Some(p)) => {
 | 
			
		||||
                    let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else {
 | 
			
		||||
                        unreachable!()
 | 
			
		||||
                    };
 | 
			
		||||
                    if t.flags() & tcp_flags::RST != 0 {
 | 
			
		||||
                        self.packet_to_send =
 | 
			
		||||
                            Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
 | 
			
		||||
                        self.tcb.change_state(TcpState::Closed);
 | 
			
		||||
                        return std::task::Poll::Ready(Err(Error::from(
 | 
			
		||||
                            ErrorKind::ConnectionReset,
 | 
			
		||||
                        )));
 | 
			
		||||
                    }
 | 
			
		||||
                    if matches!(
 | 
			
		||||
                        self.tcb.check_pkt_type(&t, &p.payload),
 | 
			
		||||
                        PacketStatus::Invalid
 | 
			
		||||
                    ) {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) {
 | 
			
		||||
                        if t.flags() == tcp_flags::ACK {
 | 
			
		||||
                            self.tcb.change_last_ack(t.inner().acknowledgment_number);
 | 
			
		||||
                            self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                            self.tcb.change_state(TcpState::Established);
 | 
			
		||||
                        }
 | 
			
		||||
                    } else if matches!(self.tcb.get_state(), TcpState::Established) {
 | 
			
		||||
                        if t.flags() == tcp_flags::ACK {
 | 
			
		||||
                            match self.tcb.check_pkt_type(&t, &p.payload) {
 | 
			
		||||
                                PacketStatus::WindowUpdate => {
 | 
			
		||||
                                    self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                                    if let Some(ref n) = self.write_notify {
 | 
			
		||||
                                        n.wake_by_ref();
 | 
			
		||||
                                        self.write_notify = None;
 | 
			
		||||
                                    };
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                }
 | 
			
		||||
                                PacketStatus::Invalid => continue,
 | 
			
		||||
                                PacketStatus::KeepAlive => {
 | 
			
		||||
                                    self.tcb.change_last_ack(t.inner().acknowledgment_number);
 | 
			
		||||
                                    self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                                    self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                                        tcp_flags::ACK,
 | 
			
		||||
                                        TTL,
 | 
			
		||||
                                        None,
 | 
			
		||||
                                        Vec::new(),
 | 
			
		||||
                                    )?);
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                }
 | 
			
		||||
                                PacketStatus::RetransmissionRequest => {
 | 
			
		||||
                                    self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                                    self.tcb.retransmission = Some(t.inner().acknowledgment_number);
 | 
			
		||||
                                    if matches!(
 | 
			
		||||
                                        self.as_mut().poll_flush(cx),
 | 
			
		||||
                                        std::task::Poll::Pending
 | 
			
		||||
                                    ) {
 | 
			
		||||
                                        return std::task::Poll::Pending;
 | 
			
		||||
                                    }
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                }
 | 
			
		||||
                                PacketStatus::NewPacket => {
 | 
			
		||||
                                    // if t.inner().sequence_number != self.tcb.get_ack() {
 | 
			
		||||
                                    //     dbg!(t.inner().sequence_number);
 | 
			
		||||
                                    //     self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                                    //         tcp_flags::ACK,
 | 
			
		||||
                                    //         TTL,
 | 
			
		||||
                                    //         None,
 | 
			
		||||
                                    //         Vec::new(),
 | 
			
		||||
                                    //     )?);
 | 
			
		||||
                                    //     continue;
 | 
			
		||||
                                    // }
 | 
			
		||||
 | 
			
		||||
                                    self.tcb.change_last_ack(t.inner().acknowledgment_number);
 | 
			
		||||
                                    self.tcb.add_unordered_packet(
 | 
			
		||||
                                        t.inner().sequence_number,
 | 
			
		||||
                                        &p.payload,
 | 
			
		||||
                                    );
 | 
			
		||||
                                    // buf.put_slice(&p.payload);
 | 
			
		||||
                                    // self.tcb.add_ack(p.payload.len() as u32);
 | 
			
		||||
                                    // self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                                    //     tcp_flags::ACK,
 | 
			
		||||
                                    //     TTL,
 | 
			
		||||
                                    //     None,
 | 
			
		||||
                                    //     Vec::new(),
 | 
			
		||||
                                    // )?);
 | 
			
		||||
                                    self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                                    if let Some(ref n) = self.write_notify {
 | 
			
		||||
                                        n.wake_by_ref();
 | 
			
		||||
                                        self.write_notify = None;
 | 
			
		||||
                                    };
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                    // return std::task::Poll::Ready(Ok(()));
 | 
			
		||||
                                }
 | 
			
		||||
                                PacketStatus::Ack => {
 | 
			
		||||
                                    self.tcb.change_last_ack(t.inner().acknowledgment_number);
 | 
			
		||||
                                    self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                                    if let Some(ref n) = self.write_notify {
 | 
			
		||||
                                        n.wake_by_ref();
 | 
			
		||||
                                        self.write_notify = None;
 | 
			
		||||
                                    };
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                }
 | 
			
		||||
                            };
 | 
			
		||||
                        }
 | 
			
		||||
                        if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
 | 
			
		||||
                            self.tcb.add_ack(1);
 | 
			
		||||
                            self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                                tcp_flags::FIN | tcp_flags::ACK,
 | 
			
		||||
                                TTL,
 | 
			
		||||
                                None,
 | 
			
		||||
                                Vec::new(),
 | 
			
		||||
                            )?);
 | 
			
		||||
                            self.tcb.add_seq_one();
 | 
			
		||||
                            self.tcb.change_state(TcpState::FinWait2(true));
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                        if t.flags() == (tcp_flags::PSH | tcp_flags::ACK) {
 | 
			
		||||
                            if !matches!(
 | 
			
		||||
                                self.tcb.check_pkt_type(&t, &p.payload),
 | 
			
		||||
                                PacketStatus::NewPacket
 | 
			
		||||
                            ) {
 | 
			
		||||
                                continue;
 | 
			
		||||
                            }
 | 
			
		||||
                            self.tcb.change_last_ack(t.inner().acknowledgment_number);
 | 
			
		||||
 | 
			
		||||
                            if p.payload.is_empty()
 | 
			
		||||
                                || self.tcb.get_ack() != t.inner().sequence_number
 | 
			
		||||
                            {
 | 
			
		||||
                                continue;
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            // self.tcb.add_ack(p.payload.len() as u32);
 | 
			
		||||
                            self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                            // buf.put_slice(&p.payload);
 | 
			
		||||
                            // self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                            //     tcp_flags::ACK,
 | 
			
		||||
                            //     TTL,
 | 
			
		||||
                            //     None,
 | 
			
		||||
                            //     Vec::new(),
 | 
			
		||||
                            // )?);
 | 
			
		||||
                            // return std::task::Poll::Ready(Ok(()));
 | 
			
		||||
                            self.tcb
 | 
			
		||||
                                .add_unordered_packet(t.inner().sequence_number, &p.payload);
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                    } else if matches!(self.tcb.get_state(), TcpState::FinWait1) {
 | 
			
		||||
                        if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
 | 
			
		||||
                            self.packet_to_send = Some(self.create_rev_packet(
 | 
			
		||||
                                tcp_flags::ACK,
 | 
			
		||||
                                TTL,
 | 
			
		||||
                                None,
 | 
			
		||||
                                Vec::new(),
 | 
			
		||||
                            )?);
 | 
			
		||||
                            self.tcb.change_send_window(t.inner().window_size);
 | 
			
		||||
                            self.tcb.add_seq_one();
 | 
			
		||||
                            self.tcb.change_state(TcpState::FinWait2(false));
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                    } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true))
 | 
			
		||||
                        && t.flags() == tcp_flags::ACK
 | 
			
		||||
                    {
 | 
			
		||||
                        self.tcb.change_state(TcpState::FinWait2(false));
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
 | 
			
		||||
                std::task::Poll::Pending => return std::task::Poll::Pending,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncWrite for IpStackTcpStream {
 | 
			
		||||
    fn poll_write(
 | 
			
		||||
        mut self: std::pin::Pin<&mut Self>,
 | 
			
		||||
        cx: &mut std::task::Context<'_>,
 | 
			
		||||
        buf: &[u8],
 | 
			
		||||
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
 | 
			
		||||
        if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2
 | 
			
		||||
            || self.tcb.is_send_buffer_full()
 | 
			
		||||
        {
 | 
			
		||||
            self.write_notify = Some(cx.waker().clone());
 | 
			
		||||
            return std::task::Poll::Pending;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if self.tcb.retransmission.is_some() {
 | 
			
		||||
            self.write_notify = Some(cx.waker().clone());
 | 
			
		||||
            if matches!(self.as_mut().poll_flush(cx), std::task::Poll::Pending) {
 | 
			
		||||
                return std::task::Poll::Pending;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let packet =
 | 
			
		||||
            self.create_rev_packet(tcp_flags::PSH | tcp_flags::ACK, TTL, None, buf.to_vec())?;
 | 
			
		||||
        let seq = self.tcb.seq;
 | 
			
		||||
        let payload_len = packet.payload.len();
 | 
			
		||||
        let payload = packet.payload.clone();
 | 
			
		||||
 | 
			
		||||
        self.packet_sender
 | 
			
		||||
            .send(packet)
 | 
			
		||||
            .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
 | 
			
		||||
        self.tcb.add_inflight_packet(seq, &payload);
 | 
			
		||||
 | 
			
		||||
        std::task::Poll::Ready(Ok(payload_len))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_flush(
 | 
			
		||||
        mut self: std::pin::Pin<&mut Self>,
 | 
			
		||||
        _cx: &mut std::task::Context<'_>,
 | 
			
		||||
    ) -> std::task::Poll<Result<(), std::io::Error>> {
 | 
			
		||||
        if let Some(i) = self
 | 
			
		||||
            .tcb
 | 
			
		||||
            .retransmission
 | 
			
		||||
            .and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s))
 | 
			
		||||
            .and_then(|p| self.tcb.inflight_packets.get(p))
 | 
			
		||||
        {
 | 
			
		||||
            let packet = self.create_rev_packet(
 | 
			
		||||
                tcp_flags::PSH | tcp_flags::ACK,
 | 
			
		||||
                TTL,
 | 
			
		||||
                Some(i.seq),
 | 
			
		||||
                i.payload.to_vec(),
 | 
			
		||||
            )?;
 | 
			
		||||
 | 
			
		||||
            self.packet_sender
 | 
			
		||||
                .send(packet)
 | 
			
		||||
                .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
 | 
			
		||||
            self.tcb.retransmission = None;
 | 
			
		||||
        } else if let Some(_i) = self.tcb.retransmission {
 | 
			
		||||
            #[cfg(feature = "log")]
 | 
			
		||||
            {
 | 
			
		||||
                warn!(_i);
 | 
			
		||||
                warn!(self.tcb.seq);
 | 
			
		||||
                warn!(self.tcb.last_ack);
 | 
			
		||||
                warn!(self.tcb.ack);
 | 
			
		||||
                for p in self.tcb.inflight_packets.iter() {
 | 
			
		||||
                    warn!(p.seq);
 | 
			
		||||
                    warn!("{}", p.payload.len());
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            panic!("Please report these values at: https://github.com/narrowlink/ipstack/");
 | 
			
		||||
        }
 | 
			
		||||
        std::task::Poll::Ready(Ok(()))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_shutdown(
 | 
			
		||||
        mut self: std::pin::Pin<&mut Self>,
 | 
			
		||||
        cx: &mut std::task::Context<'_>,
 | 
			
		||||
    ) -> std::task::Poll<Result<(), std::io::Error>> {
 | 
			
		||||
        let notified = self.shutdown.get_or_insert(Notify::new()).notified();
 | 
			
		||||
        match Pin::new(&mut Box::pin(notified)).poll(cx) {
 | 
			
		||||
            std::task::Poll::Ready(_) => std::task::Poll::Ready(Ok(())),
 | 
			
		||||
            std::task::Poll::Pending => std::task::Poll::Pending,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for IpStackTcpStream {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        if let Ok(p) = self.create_rev_packet(0, DROP_TTL, None, Vec::new()) {
 | 
			
		||||
            _ = self.packet_sender.send(p);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										181
									
								
								libs/ipstack/src/stream/udp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										181
									
								
								libs/ipstack/src/stream/udp.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,181 @@
 | 
			
		||||
use core::task;
 | 
			
		||||
use std::{
 | 
			
		||||
    future::Future,
 | 
			
		||||
    io::{self, Error, ErrorKind},
 | 
			
		||||
    net::SocketAddr,
 | 
			
		||||
    pin::Pin,
 | 
			
		||||
    task::Poll,
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use etherparse::{Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header, UdpHeader};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncRead, AsyncWrite},
 | 
			
		||||
    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
 | 
			
		||||
    time::Sleep,
 | 
			
		||||
};
 | 
			
		||||
// use crate::packet::TransportHeader;
 | 
			
		||||
use crate::{
 | 
			
		||||
    packet::{NetworkPacket, TransportHeader},
 | 
			
		||||
    TTL,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
pub struct IpStackUdpStream {
 | 
			
		||||
    src_addr: SocketAddr,
 | 
			
		||||
    dst_addr: SocketAddr,
 | 
			
		||||
    stream_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
    stream_receiver: UnboundedReceiver<NetworkPacket>,
 | 
			
		||||
    packet_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
    first_paload: Option<Vec<u8>>,
 | 
			
		||||
    timeout: Pin<Box<Sleep>>,
 | 
			
		||||
    udp_timeout: Duration,
 | 
			
		||||
    mtu: u16,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IpStackUdpStream {
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        src_addr: SocketAddr,
 | 
			
		||||
        dst_addr: SocketAddr,
 | 
			
		||||
        payload: Vec<u8>,
 | 
			
		||||
        pkt_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
        mtu: u16,
 | 
			
		||||
        udp_timeout: Duration,
 | 
			
		||||
    ) -> Self {
 | 
			
		||||
        let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
 | 
			
		||||
        IpStackUdpStream {
 | 
			
		||||
            src_addr,
 | 
			
		||||
            dst_addr,
 | 
			
		||||
            stream_sender,
 | 
			
		||||
            stream_receiver,
 | 
			
		||||
            packet_sender: pkt_sender.clone(),
 | 
			
		||||
            first_paload: Some(payload),
 | 
			
		||||
            timeout: Box::pin(tokio::time::sleep_until(
 | 
			
		||||
                tokio::time::Instant::now() + udp_timeout,
 | 
			
		||||
            )),
 | 
			
		||||
            udp_timeout,
 | 
			
		||||
            mtu,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
 | 
			
		||||
        self.stream_sender.clone()
 | 
			
		||||
    }
 | 
			
		||||
    fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> Result<NetworkPacket, Error> {
 | 
			
		||||
        match (self.dst_addr.ip(), self.src_addr.ip()) {
 | 
			
		||||
            (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
 | 
			
		||||
                let mut ip_h = Ipv4Header::new(0, ttl, 17, dst.octets(), src.octets());
 | 
			
		||||
                let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size
 | 
			
		||||
                payload.truncate(line_buffer as usize);
 | 
			
		||||
                ip_h.payload_len = payload.len() as u16 + 8; // 8 is udp header size
 | 
			
		||||
                let udp_header = UdpHeader::with_ipv4_checksum(
 | 
			
		||||
                    self.dst_addr.port(),
 | 
			
		||||
                    self.src_addr.port(),
 | 
			
		||||
                    &ip_h,
 | 
			
		||||
                    &payload,
 | 
			
		||||
                )
 | 
			
		||||
                .map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
 | 
			
		||||
                Ok(NetworkPacket {
 | 
			
		||||
                    ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()),
 | 
			
		||||
                    transport: TransportHeader::Udp(udp_header),
 | 
			
		||||
                    payload,
 | 
			
		||||
                })
 | 
			
		||||
            }
 | 
			
		||||
            (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
 | 
			
		||||
                let mut ip_h = Ipv6Header {
 | 
			
		||||
                    traffic_class: 0,
 | 
			
		||||
                    flow_label: 0,
 | 
			
		||||
                    payload_length: 0,
 | 
			
		||||
                    next_header: 17,
 | 
			
		||||
                    hop_limit: ttl,
 | 
			
		||||
                    source: dst.octets(),
 | 
			
		||||
                    destination: src.octets(),
 | 
			
		||||
                };
 | 
			
		||||
                let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16 + 8); // 8 is udp header size
 | 
			
		||||
 | 
			
		||||
                payload.truncate(line_buffer as usize);
 | 
			
		||||
 | 
			
		||||
                ip_h.payload_length = payload.len() as u16 + 8; // 8 is udp header size
 | 
			
		||||
                let udp_header = UdpHeader::with_ipv6_checksum(
 | 
			
		||||
                    self.dst_addr.port(),
 | 
			
		||||
                    self.src_addr.port(),
 | 
			
		||||
                    &ip_h,
 | 
			
		||||
                    &payload,
 | 
			
		||||
                )
 | 
			
		||||
                .map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
 | 
			
		||||
                Ok(NetworkPacket {
 | 
			
		||||
                    ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()),
 | 
			
		||||
                    transport: TransportHeader::Udp(udp_header),
 | 
			
		||||
                    payload,
 | 
			
		||||
                })
 | 
			
		||||
            }
 | 
			
		||||
            _ => unreachable!(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn local_addr(&self) -> SocketAddr {
 | 
			
		||||
        self.src_addr
 | 
			
		||||
    }
 | 
			
		||||
    pub fn peer_addr(&self) -> SocketAddr {
 | 
			
		||||
        self.dst_addr
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncRead for IpStackUdpStream {
 | 
			
		||||
    fn poll_read(
 | 
			
		||||
        mut self: Pin<&mut Self>,
 | 
			
		||||
        cx: &mut task::Context<'_>,
 | 
			
		||||
        buf: &mut tokio::io::ReadBuf<'_>,
 | 
			
		||||
    ) -> task::Poll<io::Result<()>> {
 | 
			
		||||
        if let Some(p) = self.first_paload.take() {
 | 
			
		||||
            buf.put_slice(&p);
 | 
			
		||||
            return Poll::Ready(Ok(()));
 | 
			
		||||
        }
 | 
			
		||||
        if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) {
 | 
			
		||||
            return Poll::Ready(Ok(())); // todo: return timeout error
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let udp_timeout = self.udp_timeout;
 | 
			
		||||
        match self.stream_receiver.poll_recv(cx) {
 | 
			
		||||
            Poll::Ready(Some(p)) => {
 | 
			
		||||
                buf.put_slice(&p.payload);
 | 
			
		||||
                self.timeout
 | 
			
		||||
                    .as_mut()
 | 
			
		||||
                    .reset(tokio::time::Instant::now() + udp_timeout);
 | 
			
		||||
                Poll::Ready(Ok(()))
 | 
			
		||||
            }
 | 
			
		||||
            Poll::Ready(None) => Poll::Ready(Ok(())),
 | 
			
		||||
            Poll::Pending => Poll::Pending,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncWrite for IpStackUdpStream {
 | 
			
		||||
    fn poll_write(
 | 
			
		||||
        mut self: Pin<&mut Self>,
 | 
			
		||||
        _cx: &mut task::Context<'_>,
 | 
			
		||||
        buf: &[u8],
 | 
			
		||||
    ) -> task::Poll<Result<usize, io::Error>> {
 | 
			
		||||
        let udp_timeout = self.udp_timeout;
 | 
			
		||||
        self.timeout
 | 
			
		||||
            .as_mut()
 | 
			
		||||
            .reset(tokio::time::Instant::now() + udp_timeout);
 | 
			
		||||
        let packet = self.create_rev_packet(TTL, buf.to_vec())?;
 | 
			
		||||
        let payload_len = packet.payload.len();
 | 
			
		||||
        self.packet_sender
 | 
			
		||||
            .send(packet)
 | 
			
		||||
            .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
 | 
			
		||||
        std::task::Poll::Ready(Ok(payload_len))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_flush(
 | 
			
		||||
        self: Pin<&mut Self>,
 | 
			
		||||
        _cx: &mut task::Context<'_>,
 | 
			
		||||
    ) -> task::Poll<Result<(), io::Error>> {
 | 
			
		||||
        Poll::Ready(Ok(()))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_shutdown(
 | 
			
		||||
        self: Pin<&mut Self>,
 | 
			
		||||
        _cx: &mut task::Context<'_>,
 | 
			
		||||
    ) -> task::Poll<Result<(), io::Error>> {
 | 
			
		||||
        Poll::Ready(Ok(()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										111
									
								
								libs/ipstack/src/stream/unknown.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								libs/ipstack/src/stream/unknown.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,111 @@
 | 
			
		||||
use std::{io::Error, mem, net::IpAddr};
 | 
			
		||||
 | 
			
		||||
use etherparse::{IpHeader, Ipv4Extensions, Ipv4Header, Ipv6Extensions, Ipv6Header};
 | 
			
		||||
use tokio::sync::mpsc::UnboundedSender;
 | 
			
		||||
 | 
			
		||||
use crate::{
 | 
			
		||||
    packet::{NetworkPacket, TransportHeader},
 | 
			
		||||
    TTL,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
pub struct IpStackUnknownTransport {
 | 
			
		||||
    src_addr: IpAddr,
 | 
			
		||||
    dst_addr: IpAddr,
 | 
			
		||||
    payload: Vec<u8>,
 | 
			
		||||
    protocol: u8,
 | 
			
		||||
    mtu: u16,
 | 
			
		||||
    packet_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl IpStackUnknownTransport {
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        src_addr: IpAddr,
 | 
			
		||||
        dst_addr: IpAddr,
 | 
			
		||||
        payload: Vec<u8>,
 | 
			
		||||
        ip: &IpHeader,
 | 
			
		||||
        mtu: u16,
 | 
			
		||||
        packet_sender: UnboundedSender<NetworkPacket>,
 | 
			
		||||
    ) -> Self {
 | 
			
		||||
        let protocol = match ip {
 | 
			
		||||
            IpHeader::Version4(ip, _) => ip.protocol,
 | 
			
		||||
            IpHeader::Version6(ip, _) => ip.next_header,
 | 
			
		||||
        };
 | 
			
		||||
        IpStackUnknownTransport {
 | 
			
		||||
            src_addr,
 | 
			
		||||
            dst_addr,
 | 
			
		||||
            payload,
 | 
			
		||||
            protocol,
 | 
			
		||||
            mtu,
 | 
			
		||||
            packet_sender,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn src_addr(&self) -> IpAddr {
 | 
			
		||||
        self.src_addr
 | 
			
		||||
    }
 | 
			
		||||
    pub fn dst_addr(&self) -> IpAddr {
 | 
			
		||||
        self.dst_addr
 | 
			
		||||
    }
 | 
			
		||||
    pub fn payload(&self) -> &[u8] {
 | 
			
		||||
        &self.payload
 | 
			
		||||
    }
 | 
			
		||||
    pub fn ip_protocol(&self) -> u8 {
 | 
			
		||||
        self.protocol
 | 
			
		||||
    }
 | 
			
		||||
    pub async fn send(&self, mut payload: Vec<u8>) -> Result<(), Error> {
 | 
			
		||||
        loop {
 | 
			
		||||
            let packet = self.create_rev_packet(&mut payload)?;
 | 
			
		||||
            self.packet_sender
 | 
			
		||||
                .send(packet)
 | 
			
		||||
                .map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?;
 | 
			
		||||
            if payload.is_empty() {
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn create_rev_packet(&self, payload: &mut Vec<u8>) -> Result<NetworkPacket, Error> {
 | 
			
		||||
        match (self.dst_addr, self.src_addr) {
 | 
			
		||||
            (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
 | 
			
		||||
                let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets());
 | 
			
		||||
                let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
 | 
			
		||||
 | 
			
		||||
                let p = if payload.len() > line_buffer as usize {
 | 
			
		||||
                    payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
 | 
			
		||||
                } else {
 | 
			
		||||
                    mem::take(payload)
 | 
			
		||||
                };
 | 
			
		||||
                ip_h.payload_len = p.len() as u16;
 | 
			
		||||
                Ok(NetworkPacket {
 | 
			
		||||
                    ip: etherparse::IpHeader::Version4(ip_h, Ipv4Extensions::default()),
 | 
			
		||||
                    transport: TransportHeader::Unknown,
 | 
			
		||||
                    payload: p,
 | 
			
		||||
                })
 | 
			
		||||
            }
 | 
			
		||||
            (std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
 | 
			
		||||
                let mut ip_h = Ipv6Header {
 | 
			
		||||
                    traffic_class: 0,
 | 
			
		||||
                    flow_label: 0,
 | 
			
		||||
                    payload_length: 0,
 | 
			
		||||
                    next_header: 17,
 | 
			
		||||
                    hop_limit: TTL,
 | 
			
		||||
                    source: dst.octets(),
 | 
			
		||||
                    destination: src.octets(),
 | 
			
		||||
                };
 | 
			
		||||
                let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16);
 | 
			
		||||
                payload.truncate(line_buffer as usize);
 | 
			
		||||
                ip_h.payload_length = payload.len() as u16;
 | 
			
		||||
                let p = if payload.len() > line_buffer as usize {
 | 
			
		||||
                    payload.drain(0..line_buffer as usize).collect::<Vec<u8>>()
 | 
			
		||||
                } else {
 | 
			
		||||
                    mem::take(payload)
 | 
			
		||||
                };
 | 
			
		||||
                Ok(NetworkPacket {
 | 
			
		||||
                    ip: etherparse::IpHeader::Version6(ip_h, Ipv6Extensions::default()),
 | 
			
		||||
                    transport: TransportHeader::Unknown,
 | 
			
		||||
                    payload: p,
 | 
			
		||||
                })
 | 
			
		||||
            }
 | 
			
		||||
            _ => unreachable!(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -22,6 +22,7 @@ fn main() -> Result<()> {
 | 
			
		||||
        initrd_path: initrd_path.as_str(),
 | 
			
		||||
        cmdline: "debug elevator=noop",
 | 
			
		||||
        disks: vec![],
 | 
			
		||||
        consoles: vec![],
 | 
			
		||||
        vifs: vec![],
 | 
			
		||||
        filesystems: vec![],
 | 
			
		||||
        extra_keys: vec![],
 | 
			
		||||
 | 
			
		||||
@ -56,6 +56,9 @@ pub struct DomainNetworkInterface<'a> {
 | 
			
		||||
    pub script: Option<&'a str>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct DomainConsole {}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct DomainConfig<'a> {
 | 
			
		||||
    pub backend_domid: u32,
 | 
			
		||||
@ -66,6 +69,7 @@ pub struct DomainConfig<'a> {
 | 
			
		||||
    pub initrd_path: &'a str,
 | 
			
		||||
    pub cmdline: &'a str,
 | 
			
		||||
    pub disks: Vec<DomainDisk<'a>>,
 | 
			
		||||
    pub consoles: Vec<DomainConsole>,
 | 
			
		||||
    pub vifs: Vec<DomainNetworkInterface<'a>>,
 | 
			
		||||
    pub filesystems: Vec<DomainFilesystem<'a>>,
 | 
			
		||||
    pub extra_keys: Vec<(String, String)>,
 | 
			
		||||
@ -348,9 +352,23 @@ impl XenClient {
 | 
			
		||||
            &backend_dom_path,
 | 
			
		||||
            config.backend_domid,
 | 
			
		||||
            domid,
 | 
			
		||||
            console_evtchn,
 | 
			
		||||
            console_mfn,
 | 
			
		||||
            0,
 | 
			
		||||
            Some(console_evtchn),
 | 
			
		||||
            Some(console_mfn),
 | 
			
		||||
        )?;
 | 
			
		||||
 | 
			
		||||
        for (index, _) in config.consoles.iter().enumerate() {
 | 
			
		||||
            self.console_device_add(
 | 
			
		||||
                &dom_path,
 | 
			
		||||
                &backend_dom_path,
 | 
			
		||||
                config.backend_domid,
 | 
			
		||||
                domid,
 | 
			
		||||
                index + 1,
 | 
			
		||||
                None,
 | 
			
		||||
                None,
 | 
			
		||||
            )?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for (index, disk) in config.disks.iter().enumerate() {
 | 
			
		||||
            self.disk_device_add(
 | 
			
		||||
                &dom_path,
 | 
			
		||||
@ -438,35 +456,54 @@ impl XenClient {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::too_many_arguments, clippy::unnecessary_unwrap)]
 | 
			
		||||
    fn console_device_add(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        dom_path: &str,
 | 
			
		||||
        backend_dom_path: &str,
 | 
			
		||||
        backend_domid: u32,
 | 
			
		||||
        domid: u32,
 | 
			
		||||
        port: u32,
 | 
			
		||||
        mfn: u64,
 | 
			
		||||
        index: usize,
 | 
			
		||||
        port: Option<u32>,
 | 
			
		||||
        mfn: Option<u64>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let backend_entries = vec![
 | 
			
		||||
        let mut backend_entries = vec![
 | 
			
		||||
            ("frontend-id", domid.to_string()),
 | 
			
		||||
            ("online", "1".to_string()),
 | 
			
		||||
            ("state", "1".to_string()),
 | 
			
		||||
            ("protocol", "vt100".to_string()),
 | 
			
		||||
        ];
 | 
			
		||||
 | 
			
		||||
        let frontend_entries = vec![
 | 
			
		||||
        let mut frontend_entries = vec![
 | 
			
		||||
            ("backend-id", backend_domid.to_string()),
 | 
			
		||||
            ("limit", "1048576".to_string()),
 | 
			
		||||
            ("type", "xenconsoled".to_string()),
 | 
			
		||||
            ("output", "pty".to_string()),
 | 
			
		||||
            ("tty", "".to_string()),
 | 
			
		||||
            ("port", port.to_string()),
 | 
			
		||||
            ("ring-ref", mfn.to_string()),
 | 
			
		||||
        ];
 | 
			
		||||
 | 
			
		||||
        if index == 0 {
 | 
			
		||||
            frontend_entries.push(("type", "xenconsoled".to_string()));
 | 
			
		||||
        } else {
 | 
			
		||||
            frontend_entries.push(("type", "ioemu".to_string()));
 | 
			
		||||
            backend_entries.push(("connection", "pty".to_string()));
 | 
			
		||||
            backend_entries.push(("output", "pty".to_string()));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if port.is_some() && mfn.is_some() {
 | 
			
		||||
            frontend_entries.extend_from_slice(&[
 | 
			
		||||
                ("port", port.unwrap().to_string()),
 | 
			
		||||
                ("ring-ref", mfn.unwrap().to_string()),
 | 
			
		||||
            ]);
 | 
			
		||||
        } else {
 | 
			
		||||
            frontend_entries.extend_from_slice(&[
 | 
			
		||||
                ("state", "1".to_string()),
 | 
			
		||||
                ("protocol", "vt100".to_string()),
 | 
			
		||||
            ]);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.device_add(
 | 
			
		||||
            "console",
 | 
			
		||||
            0,
 | 
			
		||||
            index as u64,
 | 
			
		||||
            dom_path,
 | 
			
		||||
            backend_dom_path,
 | 
			
		||||
            backend_domid,
 | 
			
		||||
 | 
			
		||||
@ -13,11 +13,16 @@ rtnetlink = { workspace = true }
 | 
			
		||||
netlink-packet-route = { workspace = true }
 | 
			
		||||
tokio = { workspace = true }
 | 
			
		||||
futures = { workspace = true }
 | 
			
		||||
smoltcp = { workspace = true }
 | 
			
		||||
libc = { workspace = true }
 | 
			
		||||
udp-stream = { workspace = true }
 | 
			
		||||
 | 
			
		||||
[dependencies.advmac]
 | 
			
		||||
path = "../libs/advmac"
 | 
			
		||||
 | 
			
		||||
[dependencies.ipstack]
 | 
			
		||||
path = "../libs/ipstack"
 | 
			
		||||
features = ["log"]
 | 
			
		||||
 | 
			
		||||
[lib]
 | 
			
		||||
path = "src/lib.rs"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,30 +1,21 @@
 | 
			
		||||
use std::os::fd::AsRawFd;
 | 
			
		||||
use std::panic::UnwindSafe;
 | 
			
		||||
use std::str::FromStr;
 | 
			
		||||
use std::sync::{Arc, Mutex};
 | 
			
		||||
use std::time::Duration;
 | 
			
		||||
use std::{panic, thread};
 | 
			
		||||
 | 
			
		||||
use advmac::MacAddr6;
 | 
			
		||||
use anyhow::{anyhow, Result};
 | 
			
		||||
use futures::TryStreamExt;
 | 
			
		||||
use log::{error, info, warn};
 | 
			
		||||
use ipstack::stream::IpStackStream;
 | 
			
		||||
use log::{debug, error, info, warn};
 | 
			
		||||
use netlink_packet_route::link::LinkAttribute;
 | 
			
		||||
use smoltcp::iface::{Config, Interface, SocketSet};
 | 
			
		||||
use smoltcp::phy::{self, RawSocket};
 | 
			
		||||
use smoltcp::time::Instant;
 | 
			
		||||
use smoltcp::wire::{EthernetAddress, HardwareAddress, IpCidr};
 | 
			
		||||
use raw_socket::{AsyncRawSocket, RawSocket};
 | 
			
		||||
use tokio::net::TcpStream;
 | 
			
		||||
use tokio::time::sleep;
 | 
			
		||||
use udp_stream::UdpStream;
 | 
			
		||||
 | 
			
		||||
mod raw_socket;
 | 
			
		||||
 | 
			
		||||
pub struct NetworkBackend {
 | 
			
		||||
    pub interface: String,
 | 
			
		||||
    pub device: RawSocket,
 | 
			
		||||
    pub addresses: Vec<IpCidr>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
unsafe impl Send for NetworkBackend {}
 | 
			
		||||
impl UnwindSafe for NetworkBackend {}
 | 
			
		||||
 | 
			
		||||
pub struct NetworkService {
 | 
			
		||||
    pub network: String,
 | 
			
		||||
}
 | 
			
		||||
@ -36,18 +27,9 @@ impl NetworkService {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl NetworkBackend {
 | 
			
		||||
    pub fn new(iface: &str, cidrs: &[&str]) -> Result<NetworkBackend> {
 | 
			
		||||
        let device = RawSocket::new(iface, smoltcp::phy::Medium::Ethernet)?;
 | 
			
		||||
        let mut addresses: Vec<IpCidr> = Vec::new();
 | 
			
		||||
        for cidr in cidrs {
 | 
			
		||||
            let address =
 | 
			
		||||
                IpCidr::from_str(cidr).map_err(|_| anyhow!("failed to parse cidr: {}", *cidr))?;
 | 
			
		||||
            addresses.push(address);
 | 
			
		||||
        }
 | 
			
		||||
    pub fn new(iface: &str) -> Result<NetworkBackend> {
 | 
			
		||||
        Ok(NetworkBackend {
 | 
			
		||||
            interface: iface.to_string(),
 | 
			
		||||
            device,
 | 
			
		||||
            addresses,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -73,34 +55,56 @@ impl NetworkBackend {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn run(mut self) -> Result<()> {
 | 
			
		||||
        let result = panic::catch_unwind(move || self.run_maybe_panic());
 | 
			
		||||
    pub async fn run(&mut self) -> Result<()> {
 | 
			
		||||
        let mut config = ipstack::IpStackConfig::default();
 | 
			
		||||
        config.mtu(1500);
 | 
			
		||||
        config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes
 | 
			
		||||
        config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds
 | 
			
		||||
 | 
			
		||||
        if result.is_err() {
 | 
			
		||||
            return Err(anyhow!("network backend has terminated"));
 | 
			
		||||
        let mut socket = RawSocket::new(&self.interface)?;
 | 
			
		||||
        socket.bind_interface()?;
 | 
			
		||||
        let socket = AsyncRawSocket::new(socket)?;
 | 
			
		||||
        let mut stack = ipstack::IpStack::new(config, socket);
 | 
			
		||||
 | 
			
		||||
        while let Ok(stream) = stack.accept().await {
 | 
			
		||||
            self.process_stream(stream).await?
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
        result.unwrap()
 | 
			
		||||
    async fn process_stream(&mut self, stream: IpStackStream) -> Result<()> {
 | 
			
		||||
        match stream {
 | 
			
		||||
            IpStackStream::Tcp(mut tcp) => {
 | 
			
		||||
                debug!("tcp: {}", tcp.peer_addr());
 | 
			
		||||
                tokio::spawn(async move {
 | 
			
		||||
                    if let Ok(mut stream) = TcpStream::connect(tcp.peer_addr()).await {
 | 
			
		||||
                        let _ = tokio::io::copy_bidirectional(&mut stream, &mut tcp).await;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        warn!("failed to connect to tcp address: {}", tcp.peer_addr());
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
    fn run_maybe_panic(&mut self) -> Result<()> {
 | 
			
		||||
        let mac = MacAddr6::random();
 | 
			
		||||
        let mac = HardwareAddress::Ethernet(EthernetAddress(mac.to_array()));
 | 
			
		||||
        let config = Config::new(mac);
 | 
			
		||||
        let mut iface = Interface::new(config, &mut self.device, Instant::now());
 | 
			
		||||
        iface.update_ip_addrs(|addrs| {
 | 
			
		||||
            addrs
 | 
			
		||||
                .extend_from_slice(&self.addresses)
 | 
			
		||||
                .expect("failed to set ip addresses");
 | 
			
		||||
                });
 | 
			
		||||
 | 
			
		||||
        let mut sockets = SocketSet::new(vec![]);
 | 
			
		||||
        let fd = self.device.as_raw_fd();
 | 
			
		||||
        loop {
 | 
			
		||||
            let timestamp = Instant::now();
 | 
			
		||||
            iface.poll(timestamp, &mut self.device, &mut sockets);
 | 
			
		||||
            phy::wait(fd, iface.poll_delay(timestamp, &sockets))?;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpStackStream::Udp(mut udp) => {
 | 
			
		||||
                debug!("udp: {}", udp.peer_addr());
 | 
			
		||||
                tokio::spawn(async move {
 | 
			
		||||
                    if let Ok(mut stream) = UdpStream::connect(udp.peer_addr()).await {
 | 
			
		||||
                        let _ = tokio::io::copy_bidirectional(&mut stream, &mut udp).await;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        warn!("failed to connect to udp address: {}", udp.peer_addr());
 | 
			
		||||
                    }
 | 
			
		||||
                });
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpStackStream::UnknownTransport(u) => {
 | 
			
		||||
                debug!("unknown transport: {}", u.dst_addr());
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            IpStackStream::UnknownNetwork(packet) => {
 | 
			
		||||
                debug!("unknown network: {:?}", packet);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -156,15 +160,15 @@ impl NetworkService {
 | 
			
		||||
        spawned: Arc<Mutex<Vec<String>>>,
 | 
			
		||||
    ) -> Result<()> {
 | 
			
		||||
        let interface = interface.to_string();
 | 
			
		||||
        let mut network = NetworkBackend::new(&interface, &[&self.network])?;
 | 
			
		||||
        let mut network = NetworkBackend::new(&interface)?;
 | 
			
		||||
        info!("initializing network backend for interface {}", interface);
 | 
			
		||||
        network.init().await?;
 | 
			
		||||
        tokio::time::sleep(Duration::from_secs(1)).await;
 | 
			
		||||
        info!("spawning network backend for interface {}", interface);
 | 
			
		||||
        thread::spawn(move || {
 | 
			
		||||
            if let Err(error) = network.run() {
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            if let Err(error) = network.run().await {
 | 
			
		||||
                error!(
 | 
			
		||||
                    "failed to run network backend for interface {}: {}",
 | 
			
		||||
                    "network backend for interface {} has been stopped: {}",
 | 
			
		||||
                    interface, error
 | 
			
		||||
                );
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										202
									
								
								network/src/raw_socket.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										202
									
								
								network/src/raw_socket.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,202 @@
 | 
			
		||||
use futures::ready;
 | 
			
		||||
use std::os::unix::io::{AsRawFd, RawFd};
 | 
			
		||||
use std::pin::Pin;
 | 
			
		||||
use std::task::{Context, Poll};
 | 
			
		||||
use std::{io, mem};
 | 
			
		||||
 | 
			
		||||
use anyhow::Result;
 | 
			
		||||
use tokio::io::unix::AsyncFd;
 | 
			
		||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 | 
			
		||||
 | 
			
		||||
const SIOCGIFINDEX: libc::c_ulong = 0x8933;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct RawSocket {
 | 
			
		||||
    protocol: libc::c_short,
 | 
			
		||||
    lower: libc::c_int,
 | 
			
		||||
    ifreq: Ifreq,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsRawFd for RawSocket {
 | 
			
		||||
    fn as_raw_fd(&self) -> RawFd {
 | 
			
		||||
        self.lower
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl RawSocket {
 | 
			
		||||
    pub fn new(name: &str) -> io::Result<RawSocket> {
 | 
			
		||||
        let protocol: libc::c_short = 0x0003;
 | 
			
		||||
        let lower = unsafe {
 | 
			
		||||
            let lower = libc::socket(
 | 
			
		||||
                libc::AF_PACKET,
 | 
			
		||||
                libc::SOCK_RAW | libc::SOCK_NONBLOCK,
 | 
			
		||||
                protocol.to_be() as i32,
 | 
			
		||||
            );
 | 
			
		||||
            if lower == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
            lower
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        Ok(RawSocket {
 | 
			
		||||
            protocol,
 | 
			
		||||
            lower,
 | 
			
		||||
            ifreq: ifreq_for(name),
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn bind_interface(&mut self) -> io::Result<()> {
 | 
			
		||||
        let sockaddr = libc::sockaddr_ll {
 | 
			
		||||
            sll_family: libc::AF_PACKET as u16,
 | 
			
		||||
            sll_protocol: self.protocol.to_be() as u16,
 | 
			
		||||
            sll_ifindex: ifreq_ioctl(self.lower, &mut self.ifreq, SIOCGIFINDEX)?,
 | 
			
		||||
            sll_hatype: 1,
 | 
			
		||||
            sll_pkttype: 0,
 | 
			
		||||
            sll_halen: 6,
 | 
			
		||||
            sll_addr: [0; 8],
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        unsafe {
 | 
			
		||||
            let res = libc::bind(
 | 
			
		||||
                self.lower,
 | 
			
		||||
                &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr,
 | 
			
		||||
                mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t,
 | 
			
		||||
            );
 | 
			
		||||
            if res == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn recv(&self, buffer: &mut [u8]) -> io::Result<usize> {
 | 
			
		||||
        unsafe {
 | 
			
		||||
            let len = libc::recv(
 | 
			
		||||
                self.lower,
 | 
			
		||||
                buffer.as_mut_ptr() as *mut libc::c_void,
 | 
			
		||||
                buffer.len(),
 | 
			
		||||
                0,
 | 
			
		||||
            );
 | 
			
		||||
            if len == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
            Ok(len as usize)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn send(&self, buffer: &[u8]) -> io::Result<usize> {
 | 
			
		||||
        unsafe {
 | 
			
		||||
            let len = libc::send(
 | 
			
		||||
                self.lower,
 | 
			
		||||
                buffer.as_ptr() as *const libc::c_void,
 | 
			
		||||
                buffer.len(),
 | 
			
		||||
                0,
 | 
			
		||||
            );
 | 
			
		||||
            if len == -1 {
 | 
			
		||||
                return Err(io::Error::last_os_error());
 | 
			
		||||
            }
 | 
			
		||||
            Ok(len as usize)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for RawSocket {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        unsafe {
 | 
			
		||||
            libc::close(self.lower);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[repr(C)]
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct Ifreq {
 | 
			
		||||
    ifr_name: [libc::c_char; libc::IF_NAMESIZE],
 | 
			
		||||
    ifr_data: libc::c_int, /* ifr_ifindex or ifr_mtu */
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn ifreq_for(name: &str) -> Ifreq {
 | 
			
		||||
    let mut ifreq = Ifreq {
 | 
			
		||||
        ifr_name: [0; libc::IF_NAMESIZE],
 | 
			
		||||
        ifr_data: 0,
 | 
			
		||||
    };
 | 
			
		||||
    for (i, byte) in name.as_bytes().iter().enumerate() {
 | 
			
		||||
        ifreq.ifr_name[i] = *byte as libc::c_char
 | 
			
		||||
    }
 | 
			
		||||
    ifreq
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn ifreq_ioctl(
 | 
			
		||||
    lower: libc::c_int,
 | 
			
		||||
    ifreq: &mut Ifreq,
 | 
			
		||||
    cmd: libc::c_ulong,
 | 
			
		||||
) -> io::Result<libc::c_int> {
 | 
			
		||||
    unsafe {
 | 
			
		||||
        let res = libc::ioctl(lower, cmd as _, ifreq as *mut Ifreq);
 | 
			
		||||
        if res == -1 {
 | 
			
		||||
            return Err(io::Error::last_os_error());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(ifreq.ifr_data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct AsyncRawSocket {
 | 
			
		||||
    inner: AsyncFd<RawSocket>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncRawSocket {
 | 
			
		||||
    pub fn new(socket: RawSocket) -> Result<Self> {
 | 
			
		||||
        Ok(Self {
 | 
			
		||||
            inner: AsyncFd::new(socket)?,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncRead for AsyncRawSocket {
 | 
			
		||||
    fn poll_read(
 | 
			
		||||
        self: Pin<&mut Self>,
 | 
			
		||||
        cx: &mut Context<'_>,
 | 
			
		||||
        buf: &mut ReadBuf<'_>,
 | 
			
		||||
    ) -> Poll<io::Result<()>> {
 | 
			
		||||
        loop {
 | 
			
		||||
            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
 | 
			
		||||
 | 
			
		||||
            let unfilled = buf.initialize_unfilled();
 | 
			
		||||
            match guard.try_io(|inner| inner.get_ref().recv(unfilled)) {
 | 
			
		||||
                Ok(Ok(len)) => {
 | 
			
		||||
                    buf.advance(len);
 | 
			
		||||
                    return Poll::Ready(Ok(()));
 | 
			
		||||
                }
 | 
			
		||||
                Ok(Err(err)) => return Poll::Ready(Err(err)),
 | 
			
		||||
                Err(_would_block) => continue,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncWrite for AsyncRawSocket {
 | 
			
		||||
    fn poll_write(
 | 
			
		||||
        self: Pin<&mut Self>,
 | 
			
		||||
        cx: &mut Context<'_>,
 | 
			
		||||
        buf: &[u8],
 | 
			
		||||
    ) -> Poll<io::Result<usize>> {
 | 
			
		||||
        loop {
 | 
			
		||||
            let mut guard = ready!(self.inner.poll_write_ready(cx))?;
 | 
			
		||||
 | 
			
		||||
            match guard.try_io(|inner| inner.get_ref().send(buf)) {
 | 
			
		||||
                Ok(result) => return Poll::Ready(result),
 | 
			
		||||
                Err(_would_block) => continue,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
 | 
			
		||||
        Poll::Ready(Ok(()))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
 | 
			
		||||
        Poll::Ready(Ok(()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user