use std::{ collections::{hash_map::Entry, HashMap}, sync::Arc, }; use anyhow::{anyhow, Result}; use bytes::{Buf, BytesMut}; use krata::idm::{ client::{IdmBackend, IdmInternalClient}, internal::INTERNAL_IDM_CHANNEL, transport::IdmTransportPacket, }; use kratart::channel::ChannelService; use log::{debug, error, warn}; use prost::Message; use tokio::{ select, sync::{ broadcast, mpsc::{channel, Receiver, Sender}, Mutex, }, task::JoinHandle, }; use uuid::Uuid; use crate::zlt::ZoneLookupTable; type BackendFeedMap = Arc>>>; type ClientMap = Arc>>; #[derive(Clone)] pub struct DaemonIdmHandle { zlt: ZoneLookupTable, clients: ClientMap, feeds: BackendFeedMap, tx_sender: Sender<(u32, IdmTransportPacket)>, task: Arc>, snoop_sender: broadcast::Sender, } impl DaemonIdmHandle { pub fn snoop(&self) -> broadcast::Receiver { self.snoop_sender.subscribe() } pub async fn client(&self, uuid: Uuid) -> Result { let Some(domid) = self.zlt.lookup_domid_by_uuid(&uuid).await else { return Err(anyhow!("unable to find domain {}", uuid)); }; self.client_by_domid(domid).await } pub async fn client_by_domid(&self, domid: u32) -> Result { client_or_create(domid, &self.tx_sender, &self.clients, &self.feeds).await } } impl Drop for DaemonIdmHandle { fn drop(&mut self) { if Arc::strong_count(&self.task) <= 1 { self.task.abort(); } } } #[derive(Clone)] pub struct DaemonIdmSnoopPacket { pub from: u32, pub to: u32, pub packet: IdmTransportPacket, } pub struct DaemonIdm { zlt: ZoneLookupTable, clients: ClientMap, feeds: BackendFeedMap, tx_sender: Sender<(u32, IdmTransportPacket)>, tx_raw_sender: Sender<(u32, Vec)>, tx_receiver: Receiver<(u32, IdmTransportPacket)>, rx_receiver: Receiver<(u32, Option>)>, snoop_sender: broadcast::Sender, task: JoinHandle<()>, } impl DaemonIdm { pub async fn new(zlt: ZoneLookupTable) -> Result { debug!("allocating channel service for idm"); let (service, tx_raw_sender, rx_receiver) = ChannelService::new("krata-channel".to_string(), None).await?; let (tx_sender, tx_receiver) = channel(100); let (snoop_sender, _) = broadcast::channel(100); debug!("starting idm channel service"); let task = service.launch().await?; let clients = Arc::new(Mutex::new(HashMap::new())); let feeds = Arc::new(Mutex::new(HashMap::new())); Ok(DaemonIdm { zlt, rx_receiver, tx_receiver, tx_sender, tx_raw_sender, snoop_sender, task, clients, feeds, }) } pub async fn launch(mut self) -> Result { let zlt = self.zlt.clone(); let clients = self.clients.clone(); let feeds = self.feeds.clone(); let tx_sender = self.tx_sender.clone(); let snoop_sender = self.snoop_sender.clone(); let task = tokio::task::spawn(async move { let mut buffers: HashMap = HashMap::new(); while let Err(error) = self.process(&mut buffers).await { error!("failed to process idm: {}", error); } }); Ok(DaemonIdmHandle { zlt, clients, feeds, tx_sender, snoop_sender, task: Arc::new(task), }) } async fn process_rx_packet( &mut self, domid: u32, data: Option>, buffers: &mut HashMap, ) -> Result<()> { // check if data is present, if it is not, that signals a closed channel. if let Some(data) = data { let buffer = buffers.entry(domid).or_insert_with_key(|_| BytesMut::new()); buffer.extend_from_slice(&data); loop { // check if the buffer is less than the header size, if so, wait for more data if buffer.len() < 6 { break; } // check for the magic bytes 0xff, 0xff at the start of the message, if that doesn't // exist, clear the buffer. this ensures that partial messages won't be processed. if buffer[0] != 0xff || buffer[1] != 0xff { buffer.clear(); return Ok(()); } // read the size from the buffer as a little endian u32 let size = (buffer[2] as u32 | (buffer[3] as u32) << 8 | (buffer[4] as u32) << 16 | (buffer[5] as u32) << 24) as usize; let needed = size + 6; if buffer.len() < needed { return Ok(()); } let mut packet = buffer.split_to(needed); // advance the buffer by the header, leaving only the raw data. packet.advance(6); match IdmTransportPacket::decode(packet) { Ok(packet) => { let _ = client_or_create(domid, &self.tx_sender, &self.clients, &self.feeds) .await?; let guard = self.feeds.lock().await; if let Some(feed) = guard.get(&domid) { let _ = feed.try_send(packet.clone()); } let _ = self.snoop_sender.send(DaemonIdmSnoopPacket { from: domid, to: 0, packet, }); } Err(packet) => { warn!("received invalid packet from domain {}: {}", domid, packet); } } } } else { let mut clients = self.clients.lock().await; let mut feeds = self.feeds.lock().await; clients.remove(&domid); feeds.remove(&domid); } Ok(()) } async fn tx_packet(&mut self, domid: u32, packet: IdmTransportPacket) -> Result<()> { let data = packet.encode_to_vec(); let mut buffer = vec![0u8; 6]; let length = data.len() as u32; // magic bytes buffer[0] = 0xff; buffer[1] = 0xff; // little endian u32 for message size buffer[2] = length as u8; buffer[3] = (length << 8) as u8; buffer[4] = (length << 16) as u8; buffer[5] = (length << 24) as u8; buffer.extend_from_slice(&data); self.tx_raw_sender.send((domid, buffer)).await?; let _ = self.snoop_sender.send(DaemonIdmSnoopPacket { from: 0, to: domid, packet, }); Ok(()) } async fn process(&mut self, buffers: &mut HashMap) -> Result<()> { loop { select! { x = self.rx_receiver.recv() => match x { Some((domid, data)) => { self.process_rx_packet(domid, data, buffers).await?; }, None => { break; } }, x = self.tx_receiver.recv() => match x { Some((domid, packet)) => { self.tx_packet(domid, packet).await?; }, None => { break; } } } } Ok(()) } } impl Drop for DaemonIdm { fn drop(&mut self) { self.task.abort(); } } async fn client_or_create( domid: u32, tx_sender: &Sender<(u32, IdmTransportPacket)>, clients: &ClientMap, feeds: &BackendFeedMap, ) -> Result { let mut clients = clients.lock().await; let mut feeds = feeds.lock().await; match clients.entry(domid) { Entry::Occupied(entry) => Ok(entry.get().clone()), Entry::Vacant(entry) => { let (rx_sender, rx_receiver) = channel(100); feeds.insert(domid, rx_sender); let backend = IdmDaemonBackend { domid, rx_receiver, tx_sender: tx_sender.clone(), }; let client = IdmInternalClient::new( INTERNAL_IDM_CHANNEL, Box::new(backend) as Box, ) .await?; entry.insert(client.clone()); Ok(client) } } } pub struct IdmDaemonBackend { domid: u32, rx_receiver: Receiver, tx_sender: Sender<(u32, IdmTransportPacket)>, } #[async_trait::async_trait] impl IdmBackend for IdmDaemonBackend { async fn recv(&mut self) -> Result> { if let Some(packet) = self.rx_receiver.recv().await { Ok(vec![packet]) } else { Err(anyhow!("idm receive channel closed")) } } async fn send(&mut self, packet: IdmTransportPacket) -> Result<()> { self.tx_sender.send((self.domid, packet)).await?; Ok(()) } }