diff --git a/controller/bin/control.rs b/controller/bin/control.rs index 1716ce9..f6ec6ae 100644 --- a/controller/bin/control.rs +++ b/controller/bin/control.rs @@ -110,10 +110,15 @@ fn main() -> Result<()> { Commands::List { .. } => { let containers = controller.list()?; let mut table = cli_tables::Table::new(); - let header = vec!["uuid", "ipv4", "image"]; + let header = vec!["uuid", "ipv4", "ipv6", "image"]; table.push_row(&header)?; for container in containers { - let row = vec![container.uuid.to_string(), container.ipv4, container.image]; + let row = vec![ + container.uuid.to_string(), + container.ipv4, + container.ipv6, + container.image, + ]; table.push_row_string(&row)?; } diff --git a/controller/src/ctl/mod.rs b/controller/src/ctl/mod.rs index fff968a..26b5fbc 100644 --- a/controller/src/ctl/mod.rs +++ b/controller/src/ctl/mod.rs @@ -41,6 +41,7 @@ pub struct ContainerInfo { pub image: String, pub loops: Vec, pub ipv4: String, + pub ipv6: String, } impl Controller { @@ -83,21 +84,30 @@ impl Controller { let name = format!("hypha-{uuid}"); let image_info = self.compile(image)?; - let mut mac = MacAddr6::random(); - mac.set_local(true); - mac.set_multicast(false); - let ipv4 = self.allocate_ipv4()?; - let ipv6 = mac.to_link_local_ipv6(); + let mut gateway_mac = MacAddr6::random(); + gateway_mac.set_local(true); + gateway_mac.set_multicast(false); + let mut container_mac = MacAddr6::random(); + container_mac.set_local(true); + container_mac.set_multicast(false); + + let guest_ipv4 = self.allocate_ipv4()?; + let guest_ipv6 = container_mac.to_link_local_ipv6(); + let gateway_ipv4 = "192.168.42.1"; + let gateway_ipv6 = "fe80::1"; + let ipv4_network_mask: u32 = 24; + let ipv6_network_mask: u32 = 10; + let launch_config = LaunchInfo { network: Some(LaunchNetwork { link: "eth0".to_string(), ipv4: LaunchNetworkIpv4 { - address: format!("{}/24", ipv4), - gateway: "192.168.42.1".to_string(), + address: format!("{}/{}", guest_ipv4, ipv4_network_mask), + gateway: gateway_ipv4.to_string(), }, ipv6: LaunchNetworkIpv6 { - address: format!("{}/10", ipv6), - gateway: "fe80::1".to_string(), + address: format!("{}/{}", guest_ipv6, ipv6_network_mask), + gateway: gateway_ipv6.to_string(), }, resolver: LaunchNetworkResolver { nameservers: vec![ @@ -135,7 +145,8 @@ impl Controller { let cmdline_options = [if debug { "debug" } else { "quiet" }, "elevator=noop"]; let cmdline = cmdline_options.join(" "); - let mac = mac.to_string().replace('-', ":"); + let container_mac_string = container_mac.to_string().replace('-', ":"); + let gateway_mac_string = gateway_mac.to_string().replace('-', ":"); let config = DomainConfig { backend_domid: 0, name: &name, @@ -158,7 +169,7 @@ impl Controller { ], consoles: vec![], vifs: vec![DomainNetworkInterface { - mac: &mac, + mac: &container_mac_string, mtu: 1500, bridge: None, script: None, @@ -178,7 +189,30 @@ impl Controller { ), ), ("hypha/image".to_string(), image.to_string()), - ("hypha/ipv4".to_string(), ipv4.to_string()), + ( + "hypha/network/guest/ipv4".to_string(), + format!("{}/{}", guest_ipv4, ipv4_network_mask), + ), + ( + "hypha/network/guest/ipv6".to_string(), + format!("{}/{}", guest_ipv6, ipv6_network_mask), + ), + ( + "hypha/network/guest/mac".to_string(), + container_mac_string.clone(), + ), + ( + "hypha/network/gateway/ipv4".to_string(), + format!("{}/{}", gateway_ipv4, ipv4_network_mask), + ), + ( + "hypha/network/gateway/ipv6".to_string(), + format!("{}/{}", gateway_ipv6, ipv6_network_mask), + ), + ( + "hypha/network/gateway/mac".to_string(), + gateway_mac_string.clone(), + ), ], }; match self.client.create(&config) { @@ -305,7 +339,12 @@ impl Controller { let ipv4 = self .client .store - .read_string_optional(&format!("{}/hypha/ipv4", &dom_path))? + .read_string_optional(&format!("{}/hypha/network/guest/ipv4", &dom_path))? + .unwrap_or("unknown".to_string()); + let ipv6 = self + .client + .store + .read_string_optional(&format!("{}/hypha/network/guest/ipv6", &dom_path))? .unwrap_or("unknown".to_string()); let loops = Controller::parse_loop_set(&loops); containers.push(ContainerInfo { @@ -314,6 +353,7 @@ impl Controller { image, loops, ipv4, + ipv6, }); } Ok(containers) @@ -359,10 +399,11 @@ impl Controller { ]; for domid_candidate in self.client.store.list_any("/local/domain")? { let dom_path = format!("/local/domain/{}", domid_candidate); - let ip_path = format!("{}/hypha/ipv4", dom_path); + let ip_path = format!("{}/hypha/network/guest/ipv4", dom_path); let existing_ip = self.client.store.read_string_optional(&ip_path)?; if let Some(existing_ip) = existing_ip { - used.push(Ipv4Addr::from_str(&existing_ip)?); + let ipv4_network = Ipv4Network::from_str(&existing_ip)?; + used.push(ipv4_network.ip()); } } diff --git a/network/Cargo.toml b/network/Cargo.toml index 972048c..f401bcf 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -18,10 +18,14 @@ udp-stream = { workspace = true } smoltcp = { workspace = true } etherparse = { workspace = true } async-trait = { workspace = true } +uuid = { workspace = true } [dependencies.advmac] path = "../libs/advmac" +[dependencies.xenstore] +path = "../libs/xen/xenstore" + [lib] path = "src/lib.rs" @@ -32,3 +36,7 @@ path = "bin/network.rs" [[example]] name = "ping" path = "examples/ping.rs" + +[[example]] +name = "autonet" +path = "examples/autonet.rs" diff --git a/network/bin/network.rs b/network/bin/network.rs index 77ec2bd..ab30740 100644 --- a/network/bin/network.rs +++ b/network/bin/network.rs @@ -1,35 +1,15 @@ -use std::str::FromStr; - -use advmac::MacAddr6; use anyhow::Result; use clap::Parser; use env_logger::Env; use hyphanet::NetworkService; #[derive(Parser, Debug)] -struct NetworkArgs { - #[arg(long, default_value = "192.168.42.1/24")] - ipv4_network: String, +struct NetworkArgs {} - #[arg(long, default_value = "fe80::1/10")] - ipv6_network: String, - - #[arg(long)] - force_mac_address: Option, -} - -#[tokio::main] +#[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() -> Result<()> { env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init(); - let args = NetworkArgs::parse(); - - let force_mac_address = if let Some(mac_str) = args.force_mac_address { - Some(MacAddr6::from_str(&mac_str)?) - } else { - None - }; - - let mut service = NetworkService::new(args.ipv4_network, args.ipv6_network, force_mac_address)?; - service.watch().await?; - Ok(()) + let _ = NetworkArgs::parse(); + let mut service = NetworkService::new()?; + service.watch().await } diff --git a/network/examples/autonet.rs b/network/examples/autonet.rs new file mode 100644 index 0000000..ef5fe26 --- /dev/null +++ b/network/examples/autonet.rs @@ -0,0 +1,13 @@ +use std::{thread::sleep, time::Duration}; + +use anyhow::Result; +use hyphanet::autonet::AutoNetworkCollector; + +fn main() -> Result<()> { + let mut collector = AutoNetworkCollector::new()?; + loop { + let changeset = collector.read_changes()?; + println!("{:?}", changeset); + sleep(Duration::from_secs(2)); + } +} diff --git a/network/src/autonet.rs b/network/src/autonet.rs new file mode 100644 index 0000000..ef0bf83 --- /dev/null +++ b/network/src/autonet.rs @@ -0,0 +1,181 @@ +use anyhow::{anyhow, Result}; +use smoltcp::wire::{EthernetAddress, Ipv4Cidr, Ipv6Cidr}; +use std::{collections::HashMap, str::FromStr}; +use uuid::Uuid; +use xenstore::client::{XsdClient, XsdInterface, XsdTransaction}; + +pub struct AutoNetworkCollector { + client: XsdClient, + known: HashMap, +} + +#[derive(Debug, Clone)] +pub struct NetworkSide { + pub ipv4: Ipv4Cidr, + pub ipv6: Ipv6Cidr, + pub mac: EthernetAddress, +} + +#[derive(Debug, Clone)] +pub struct NetworkMetadata { + pub domid: u32, + pub uuid: Uuid, + pub guest: NetworkSide, + pub gateway: NetworkSide, +} + +impl NetworkMetadata { + pub fn interface(&self) -> String { + format!("vif{}.20", self.domid) + } +} + +#[derive(Debug, Clone)] +pub struct AutoNetworkChangeset { + pub added: Vec, + pub removed: Vec, +} + +impl AutoNetworkCollector { + pub fn new() -> Result { + Ok(AutoNetworkCollector { + client: XsdClient::open()?, + known: HashMap::new(), + }) + } + + pub fn read(&mut self) -> Result> { + let mut networks = Vec::new(); + let mut tx = self.client.transaction()?; + for domid_string in tx.list_any("/local/domain")? { + let Ok(domid) = domid_string.parse::() else { + continue; + }; + + let dom_path = format!("/local/domain/{}", domid_string); + let Some(uuid_string) = tx.read_string_optional(&format!("{}/hypha/uuid", dom_path))? + else { + continue; + }; + + let Ok(uuid) = uuid_string.parse::() else { + continue; + }; + + let Ok(guest) = + AutoNetworkCollector::read_network_side(uuid, &mut tx, &dom_path, "guest") + else { + continue; + }; + + let Ok(gateway) = + AutoNetworkCollector::read_network_side(uuid, &mut tx, &dom_path, "gateway") + else { + continue; + }; + + networks.push(NetworkMetadata { + domid, + uuid, + guest, + gateway, + }); + } + tx.commit()?; + Ok(networks) + } + + fn read_network_side( + uuid: Uuid, + tx: &mut XsdTransaction<'_>, + dom_path: &str, + side: &str, + ) -> Result { + let side_path = format!("{}/hypha/network/{}", dom_path, side); + let Some(ipv4) = tx.read_string_optional(&format!("{}/ipv4", side_path))? else { + return Err(anyhow!( + "hypha domain {} is missing {} ipv4 network entry", + uuid, + side + )); + }; + + let Some(ipv6) = tx.read_string_optional(&format!("{}/ipv6", side_path))? else { + return Err(anyhow!( + "hypha domain {} is missing {} ipv6 network entry", + uuid, + side + )); + }; + + let Some(mac) = tx.read_string_optional(&format!("{}/mac", side_path))? else { + return Err(anyhow!( + "hypha domain {} is missing {} mac address entry", + uuid, + side + )); + }; + + let Ok(ipv4) = Ipv4Cidr::from_str(&ipv4) else { + return Err(anyhow!( + "hypha domain {} has invalid {} ipv4 network cidr entry: {}", + uuid, + side, + ipv4 + )); + }; + + let Ok(ipv6) = Ipv6Cidr::from_str(&ipv6) else { + return Err(anyhow!( + "hypha domain {} has invalid {} ipv6 network cidr entry: {}", + uuid, + side, + ipv6 + )); + }; + + let Ok(mac) = EthernetAddress::from_str(&mac) else { + return Err(anyhow!( + "hypha domain {} has invalid {} mac address entry: {}", + uuid, + side, + mac + )); + }; + + Ok(NetworkSide { ipv4, ipv6, mac }) + } + + pub fn read_changes(&mut self) -> Result { + let mut seen: Vec = Vec::new(); + let mut added: Vec = Vec::new(); + let mut removed: Vec = Vec::new(); + + for network in self.read()? { + seen.push(network.uuid); + if self.known.contains_key(&network.uuid) { + continue; + } + let _ = self.known.insert(network.uuid, network.clone()); + added.push(network); + } + + let mut gone: Vec = Vec::new(); + for uuid in self.known.keys() { + if seen.contains(uuid) { + continue; + } + gone.push(*uuid); + } + + for uuid in &gone { + let Some(network) = self.known.remove(uuid) else { + continue; + }; + + removed.push(network); + } + + Ok(AutoNetworkChangeset { added, removed }) + } +} diff --git a/network/src/backend.rs b/network/src/backend.rs index f849d96..b44a430 100644 --- a/network/src/backend.rs +++ b/network/src/backend.rs @@ -1,18 +1,18 @@ +use crate::autonet::NetworkMetadata; use crate::chandev::ChannelDevice; use crate::nat::NatRouter; use crate::pkt::RecvPacket; use crate::proxynat::ProxyNatHandlerFactory; use crate::raw_socket::{AsyncRawSocket, RawSocketProtocol}; -use advmac::MacAddr6; +use crate::vbridge::{BridgeJoinHandle, VirtualBridge}; use anyhow::{anyhow, Result}; use etherparse::SlicedPacket; use futures::TryStreamExt; -use log::debug; +use log::{debug, info, warn}; use smoltcp::iface::{Config, Interface, SocketSet}; use smoltcp::phy::Medium; use smoltcp::time::Instant; use smoltcp::wire::{HardwareAddress, IpCidr}; -use std::str::FromStr; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::select; @@ -20,14 +20,13 @@ use tokio::sync::mpsc::{channel, Receiver}; #[derive(Clone)] pub struct NetworkBackend { - ipv4: String, - ipv6: String, - force_mac_address: Option, - interface: String, + metadata: NetworkMetadata, + bridge: VirtualBridge, } enum NetworkStackSelect<'a> { Receive(&'a [u8]), + BridgeSend(Option>), Send(Option>), Reclaim, } @@ -40,18 +39,25 @@ struct NetworkStack<'a> { interface: Interface, sockets: SocketSet<'a>, router: NatRouter, + bridge: BridgeJoinHandle, } impl NetworkStack<'_> { async fn poll(&mut self, buffer: &mut [u8]) -> Result<()> { let what = select! { x = self.kdev.read(buffer) => NetworkStackSelect::Receive(&buffer[0..x?]), + x = self.bridge.bridge_rx_receiver.recv() => NetworkStackSelect::BridgeSend(x), + x = self.bridge.broadcast_rx_receiver.recv() => NetworkStackSelect::BridgeSend(x.ok()), x = self.tx.recv() => NetworkStackSelect::Send(x), _ = self.router.process_reclaim() => NetworkStackSelect::Reclaim, }; match what { NetworkStackSelect::Receive(packet) => { + if let Err(error) = self.bridge.bridge_tx_sender.try_send(packet.to_vec()) { + warn!("failed to send guest packet to bridge: {}", error); + } + let slice = SlicedPacket::from_ethernet(packet)?; let packet = RecvPacket::new(packet, &slice)?; if let Err(error) = self.router.process(&packet).await { @@ -63,6 +69,14 @@ impl NetworkStack<'_> { .poll(Instant::now(), &mut self.udev, &mut self.sockets); } + NetworkStackSelect::BridgeSend(Some(packet)) => { + if let Err(error) = self.udev.tx.try_send(packet) { + warn!("failed to send bridge packet to guest: {}", error); + } + } + + NetworkStackSelect::BridgeSend(None) => {} + NetworkStackSelect::Send(packet) => { if let Some(packet) = packet { self.kdev.write_all(&packet).await? @@ -77,34 +91,21 @@ impl NetworkStack<'_> { } impl NetworkBackend { - pub fn new( - ipv4: &str, - ipv6: &str, - force_mac_address: &Option, - interface: &str, - ) -> Result { - Ok(Self { - ipv4: ipv4.to_string(), - ipv6: ipv6.to_string(), - force_mac_address: *force_mac_address, - interface: interface.to_string(), - }) + pub fn new(metadata: NetworkMetadata, bridge: VirtualBridge) -> Result { + Ok(Self { metadata, bridge }) } pub async fn init(&mut self) -> Result<()> { + let interface = self.metadata.interface(); let (connection, handle, _) = rtnetlink::new_connection()?; tokio::spawn(connection); - let mut links = handle - .link() - .get() - .match_name(self.interface.to_string()) - .execute(); + let mut links = handle.link().get().match_name(interface.clone()).execute(); let link = links.try_next().await?; if link.is_none() { return Err(anyhow!( "unable to find network interface named {}", - self.interface + interface )); } let link = link.unwrap(); @@ -114,35 +115,28 @@ impl NetworkBackend { } pub async fn run(&self) -> Result<()> { - let mut stack = self.create_network_stack()?; + let mut stack = self.create_network_stack().await?; let mut buffer = vec![0u8; stack.mtu]; loop { stack.poll(&mut buffer).await?; } } - fn create_network_stack(&self) -> Result { + async fn create_network_stack(&self) -> Result { + let interface = self.metadata.interface(); let proxy = Box::new(ProxyNatHandlerFactory::new()); - let ipv4 = IpCidr::from_str(&self.ipv4) - .map_err(|_| anyhow!("failed to parse ipv4 cidr: {}", self.ipv4))?; - let ipv6 = IpCidr::from_str(&self.ipv6) - .map_err(|_| anyhow!("failed to parse ipv6 cidr: {}", self.ipv6))?; - let addresses: Vec = vec![ipv4, ipv6]; - let mut kdev = - AsyncRawSocket::bound_to_interface(&self.interface, RawSocketProtocol::Ethernet)?; - let mtu = kdev.mtu_of_interface(&self.interface)?; + let addresses: Vec = vec![ + self.metadata.gateway.ipv4.into(), + self.metadata.gateway.ipv6.into(), + ]; + let mut kdev = AsyncRawSocket::bound_to_interface(&interface, RawSocketProtocol::Ethernet)?; + let mtu = kdev.mtu_of_interface(&interface)?; let (tx_sender, tx_receiver) = channel::>(100); let mut udev = ChannelDevice::new(mtu, Medium::Ethernet, tx_sender.clone()); - let mac = self.force_mac_address.unwrap_or_else(|| { - let mut mac = MacAddr6::random(); - mac.set_local(true); - mac.set_multicast(false); - mac - }); - let mac = smoltcp::wire::EthernetAddress(mac.to_array()); + let mac = self.metadata.gateway.mac; let nat = NatRouter::new(mtu, proxy, mac, addresses.clone(), tx_sender.clone()); - let mac = HardwareAddress::Ethernet(mac); - let config = Config::new(mac); + let hardware_addr = HardwareAddress::Ethernet(mac); + let config = Config::new(hardware_addr); let mut iface = Interface::new(config, &mut udev, Instant::now()); iface.update_ip_addrs(|addrs| { addrs @@ -150,6 +144,7 @@ impl NetworkBackend { .expect("failed to set ip addresses"); }); let sockets = SocketSet::new(vec![]); + let handle = self.bridge.join(self.metadata.guest.mac).await?; Ok(NetworkStack { mtu, tx: tx_receiver, @@ -158,6 +153,23 @@ impl NetworkBackend { interface: iface, sockets, router: nat, + bridge: handle, }) } + + pub async fn launch(self) -> Result<()> { + tokio::task::spawn(async move { + info!( + "lauched network backend for hypha guest {}", + self.metadata.uuid + ); + if let Err(error) = self.run().await { + warn!( + "network backend for hypha guest {} failed: {}", + self.metadata.uuid, error + ); + } + }); + Ok(()) + } } diff --git a/network/src/lib.rs b/network/src/lib.rs index c346758..e9eea23 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -1,14 +1,13 @@ -use advmac::MacAddr6; -use anyhow::Result; -use futures::TryStreamExt; -use log::{error, info, warn}; -use netlink_packet_route::link::LinkAttribute; -use std::sync::{Arc, Mutex}; use std::time::Duration; + +use anyhow::Result; +use autonet::{AutoNetworkChangeset, AutoNetworkCollector, NetworkMetadata}; use tokio::time::sleep; +use vbridge::VirtualBridge; use crate::backend::NetworkBackend; +pub mod autonet; pub mod backend; pub mod chandev; pub mod icmp; @@ -16,99 +15,45 @@ pub mod nat; pub mod pkt; pub mod proxynat; pub mod raw_socket; +pub mod vbridge; pub struct NetworkService { - pub ipv4: String, - pub ipv6: String, - pub force_mac_address: Option, + pub bridge: VirtualBridge, } impl NetworkService { - pub fn new( - ipv4: String, - ipv6: String, - force_mac_address: Option, - ) -> Result { + pub fn new() -> Result { Ok(NetworkService { - ipv4, - ipv6, - force_mac_address, + bridge: VirtualBridge::new()?, }) } } impl NetworkService { pub async fn watch(&mut self) -> Result<()> { - let spawned: Arc>> = Arc::new(Mutex::new(Vec::new())); - let (connection, handle, _) = rtnetlink::new_connection()?; - tokio::spawn(connection); + let mut collector = AutoNetworkCollector::new()?; loop { - let mut stream = handle.link().get().execute(); - while let Some(message) = stream.try_next().await? { - let mut name: Option = None; - for attribute in &message.attributes { - if let LinkAttribute::IfName(if_name) = attribute { - name = Some(if_name.clone()); - } - } - - if name.is_none() { - continue; - } - - let name = name.unwrap(); - if !name.starts_with("vif") { - continue; - } - - if let Ok(spawns) = spawned.lock() { - if spawns.contains(&name) { - continue; - } - } - - if let Err(error) = self.add_network_backend(&name, spawned.clone()).await { - warn!( - "failed to initialize network backend for interface {}: {}", - name, error - ); - } - - if let Ok(mut spawns) = spawned.lock() { - spawns.push(name.clone()); - } - } - + let changeset = collector.read_changes()?; + self.process_network_changeset(changeset)?; sleep(Duration::from_secs(2)).await; } } - async fn add_network_backend( - &mut self, - interface: &str, - spawned: Arc>>, - ) -> Result<()> { - let interface = interface.to_string(); - let mut network = - NetworkBackend::new(&self.ipv4, &self.ipv6, &self.force_mac_address, &interface)?; - info!("initializing network backend for interface {}", interface); + fn process_network_changeset(&mut self, changeset: AutoNetworkChangeset) -> Result<()> { + for metadata in &changeset.added { + futures::executor::block_on(async { + self.add_network_backend(metadata.clone()).await + })?; + } + + Ok(()) + } + + async fn add_network_backend(&mut self, metadata: NetworkMetadata) -> Result<()> { + let mut network = NetworkBackend::new(metadata, self.bridge.clone())?; network.init().await?; tokio::time::sleep(Duration::from_secs(1)).await; - info!("spawning network backend for interface {}", interface); - tokio::spawn(async move { - if let Err(error) = network.run().await { - error!( - "network backend for interface {} has been stopped: {}", - interface, error - ); - } - - if let Ok(mut spawns) = spawned.lock() { - if let Some(position) = spawns.iter().position(|x| *x == interface) { - spawns.remove(position); - } - } - }); + network.launch().await?; Ok(()) } } diff --git a/network/src/vbridge.rs b/network/src/vbridge.rs new file mode 100644 index 0000000..ba52f17 --- /dev/null +++ b/network/src/vbridge.rs @@ -0,0 +1,164 @@ +use anyhow::{anyhow, Result}; +use etherparse::Ethernet2Header; +use log::{debug, trace, warn}; +use smoltcp::wire::EthernetAddress; +use std::{ + collections::{hash_map::Entry, HashMap}, + sync::Arc, +}; +use tokio::sync::broadcast::{ + channel as broadcast_channel, Receiver as BroadcastReceiver, Sender as BroadcastSender, +}; +use tokio::{ + select, + sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, + }, + task::JoinHandle, +}; + +const BROADCAST_MAC_ADDR: &[u8; 6] = &[0xff; 6]; + +const BRIDGE_TX_QUEUE_LEN: usize = 4; +const BRIDGE_RX_QUEUE_LEN: usize = 4; +const BROADCAST_RX_QUEUE_LEN: usize = 4; + +#[derive(Debug)] +struct BridgeMember { + pub bridge_rx_sender: Sender>, +} + +pub struct BridgeJoinHandle { + pub bridge_tx_sender: Sender>, + pub bridge_rx_receiver: Receiver>, + pub broadcast_rx_receiver: BroadcastReceiver>, +} + +type VirtualBridgeMemberMap = Arc>>; + +#[derive(Clone)] +pub struct VirtualBridge { + bridge_tx_sender: Sender>, + members: VirtualBridgeMemberMap, + broadcast_rx_sender: BroadcastSender>, + _task: Arc>, +} + +enum VirtualBridgeSelect { + BroadcastSent(Option>), + PacketReceived(Option>), +} + +impl VirtualBridge { + pub fn new() -> Result { + let (bridge_tx_sender, bridge_tx_receiver) = channel::>(BRIDGE_TX_QUEUE_LEN); + let (broadcast_rx_sender, broadcast_rx_receiver) = + broadcast_channel(BROADCAST_RX_QUEUE_LEN); + + let members = Arc::new(Mutex::new(HashMap::new())); + let handle = { + let members = members.clone(); + let broadcast_rx_sender = broadcast_rx_sender.clone(); + tokio::task::spawn(async move { + if let Err(error) = VirtualBridge::process( + members, + bridge_tx_receiver, + broadcast_rx_sender, + broadcast_rx_receiver, + ) + .await + { + warn!("virtual bridge processing task failed: {}", error); + } + }) + }; + + Ok(VirtualBridge { + bridge_tx_sender, + members, + broadcast_rx_sender, + _task: Arc::new(handle), + }) + } + + pub async fn join(&self, mac: EthernetAddress) -> Result { + let (bridge_rx_sender, bridge_rx_receiver) = channel::>(BRIDGE_RX_QUEUE_LEN); + let member = BridgeMember { bridge_rx_sender }; + + match self.members.lock().await.entry(mac.0) { + Entry::Occupied(_) => { + return Err(anyhow!( + "virtual bridge already has a member with address {}", + mac + )); + } + Entry::Vacant(entry) => { + entry.insert(member); + } + }; + debug!("virtual bridge member has joined: {}", mac); + Ok(BridgeJoinHandle { + bridge_rx_receiver, + broadcast_rx_receiver: self.broadcast_rx_sender.subscribe(), + bridge_tx_sender: self.bridge_tx_sender.clone(), + }) + } + + async fn process( + members: VirtualBridgeMemberMap, + mut bridge_tx_receiver: Receiver>, + broadcast_rx_sender: BroadcastSender>, + mut broadcast_rx_receiver: BroadcastReceiver>, + ) -> Result<()> { + loop { + let selection = select! { + biased; + x = bridge_tx_receiver.recv() => VirtualBridgeSelect::PacketReceived(x), + x = broadcast_rx_receiver.recv() => VirtualBridgeSelect::BroadcastSent(x.ok()), + }; + + match selection { + VirtualBridgeSelect::PacketReceived(Some(packet)) => { + let header = match Ethernet2Header::from_slice(&packet) { + Ok((header, _)) => header, + Err(error) => { + debug!("virtual bridge failed to parse ethernet header: {}", error); + continue; + } + }; + + let destination = &header.destination; + if destination == BROADCAST_MAC_ADDR { + trace!( + "broadcasting bridged packet from {}", + EthernetAddress(header.source) + ); + broadcast_rx_sender.send(packet)?; + continue; + } + match members.lock().await.get(destination) { + Some(member) => { + member.bridge_rx_sender.try_send(packet)?; + trace!( + "sending bridged packet from {} to {}", + EthernetAddress(header.source), + EthernetAddress(header.destination) + ); + } + None => { + trace!( + "no bridge member with address: {}", + EthernetAddress(*destination) + ); + } + } + } + + VirtualBridgeSelect::PacketReceived(None) => break, + VirtualBridgeSelect::BroadcastSent(_) => {} + } + } + Ok(()) + } +}