diff --git a/crates/daemon/src/event.rs b/crates/daemon/src/event.rs index 76257d1..55d9872 100644 --- a/crates/daemon/src/event.rs +++ b/crates/daemon/src/event.rs @@ -6,10 +6,10 @@ use std::{ use anyhow::Result; use krata::{ - idm::protocol::{idm_event::Event, idm_packet::Content, IdmPacket}, + idm::protocol::{idm_event::Event, IdmEvent}, v1::common::{GuestExitInfo, GuestState, GuestStatus}, }; -use log::error; +use log::{error, warn}; use tokio::{ select, sync::{ @@ -21,15 +21,12 @@ use tokio::{ }; use uuid::Uuid; -use crate::{ - db::GuestStore, - idm::{DaemonIdmHandle, DaemonIdmSubscribeHandle}, -}; +use crate::{db::GuestStore, idm::DaemonIdmHandle}; pub type DaemonEvent = krata::v1::control::watch_events_reply::Event; const EVENT_CHANNEL_QUEUE_LEN: usize = 1000; -const IDM_CHANNEL_QUEUE_LEN: usize = 1000; +const IDM_EVENT_CHANNEL_QUEUE_LEN: usize = 1000; #[derive(Clone)] pub struct DaemonEventContext { @@ -52,9 +49,9 @@ pub struct DaemonEventGenerator { guest_reconciler_notify: Sender, feed: broadcast::Receiver, idm: DaemonIdmHandle, - idms: HashMap, - idm_sender: Sender<(u32, IdmPacket)>, - idm_receiver: Receiver<(u32, IdmPacket)>, + idms: HashMap)>, + idm_sender: Sender<(u32, IdmEvent)>, + idm_receiver: Receiver<(u32, IdmEvent)>, _event_sender: broadcast::Sender, } @@ -65,7 +62,7 @@ impl DaemonEventGenerator { idm: DaemonIdmHandle, ) -> Result<(DaemonEventContext, DaemonEventGenerator)> { let (sender, _) = broadcast::channel(EVENT_CHANNEL_QUEUE_LEN); - let (idm_sender, idm_receiver) = channel(IDM_CHANNEL_QUEUE_LEN); + let (idm_sender, idm_receiver) = channel(IDM_EVENT_CHANNEL_QUEUE_LEN); let generator = DaemonEventGenerator { guests, guest_reconciler_notify, @@ -97,15 +94,27 @@ impl DaemonEventGenerator { match status { GuestStatus::Started => { if let Entry::Vacant(e) = self.idms.entry(domid) { - let subscribe = - self.idm.subscribe(domid, self.idm_sender.clone()).await?; - e.insert((id, subscribe)); + let client = self.idm.client(domid).await?; + let mut receiver = client.subscribe().await?; + let sender = self.idm_sender.clone(); + let task = tokio::task::spawn(async move { + loop { + let Ok(event) = receiver.recv().await else { + break; + }; + + if let Err(error) = sender.send((domid, event)).await { + warn!("unable to deliver idm event: {}", error); + } + } + }); + e.insert((id, task)); } } GuestStatus::Destroyed => { if let Some((_, handle)) = self.idms.remove(&domid) { - handle.unsubscribe().await?; + handle.abort(); } } @@ -116,14 +125,10 @@ impl DaemonEventGenerator { Ok(()) } - async fn handle_idm_packet(&mut self, id: Uuid, packet: IdmPacket) -> Result<()> { - match packet.content { - Some(Content::Event(event)) => match event.event { - Some(Event::Exit(exit)) => self.handle_exit_code(id, exit.code).await, - None => Ok(()), - }, - - _ => Ok(()), + async fn handle_idm_event(&mut self, id: Uuid, event: IdmEvent) -> Result<()> { + match event.event { + Some(Event::Exit(exit)) => self.handle_exit_code(id, exit.code).await, + None => Ok(()), } } @@ -146,9 +151,9 @@ impl DaemonEventGenerator { async fn evaluate(&mut self) -> Result<()> { select! { x = self.idm_receiver.recv() => match x { - Some((domid, packet)) => { + Some((domid, event)) => { if let Some((id, _)) = self.idms.get(&domid) { - self.handle_idm_packet(*id, packet).await?; + self.handle_idm_event(*id, event).await?; } Ok(()) }, diff --git a/crates/daemon/src/idm.rs b/crates/daemon/src/idm.rs index 18ba0f2..a2e04b8 100644 --- a/crates/daemon/src/idm.rs +++ b/crates/daemon/src/idm.rs @@ -1,8 +1,14 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{hash_map::Entry, HashMap}, + sync::Arc, +}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use bytes::{Buf, BytesMut}; -use krata::idm::protocol::IdmPacket; +use krata::idm::{ + client::{IdmBackend, IdmClient}, + protocol::IdmPacket, +}; use kratart::channel::ChannelService; use log::{error, warn}; use prost::Message; @@ -15,53 +21,20 @@ use tokio::{ task::JoinHandle, }; -type ListenerMap = Arc>>>; +type BackendFeedMap = Arc>>>; +type ClientMap = Arc>>; #[derive(Clone)] pub struct DaemonIdmHandle { - listeners: ListenerMap, + clients: ClientMap, + feeds: BackendFeedMap, tx_sender: Sender<(u32, IdmPacket)>, task: Arc>, } -#[derive(Clone)] -pub struct DaemonIdmSubscribeHandle { - domid: u32, - tx_sender: Sender<(u32, IdmPacket)>, - listeners: ListenerMap, -} - -impl DaemonIdmSubscribeHandle { - pub async fn send(&self, packet: IdmPacket) -> Result<()> { - self.tx_sender.send((self.domid, packet)).await?; - Ok(()) - } - - pub async fn unsubscribe(&self) -> Result<()> { - let mut guard = self.listeners.lock().await; - let _ = guard.remove(&self.domid); - Ok(()) - } -} - impl DaemonIdmHandle { - pub async fn send(&self, domid: u32, packet: IdmPacket) -> Result<()> { - self.tx_sender.send((domid, packet)).await?; - Ok(()) - } - - pub async fn subscribe( - &self, - domid: u32, - sender: Sender<(u32, IdmPacket)>, - ) -> Result { - let mut guard = self.listeners.lock().await; - guard.insert(domid, sender); - Ok(DaemonIdmSubscribeHandle { - domid, - tx_sender: self.tx_sender.clone(), - listeners: self.listeners.clone(), - }) + pub async fn client(&self, domid: u32) -> Result { + client_or_create(domid, &self.tx_sender, &self.clients, &self.feeds).await } } @@ -74,7 +47,8 @@ impl Drop for DaemonIdmHandle { } pub struct DaemonIdm { - listeners: ListenerMap, + clients: ClientMap, + feeds: BackendFeedMap, tx_sender: Sender<(u32, IdmPacket)>, tx_raw_sender: Sender<(u32, Vec)>, tx_receiver: Receiver<(u32, IdmPacket)>, @@ -88,19 +62,22 @@ impl DaemonIdm { ChannelService::new("krata-channel".to_string(), None).await?; let (tx_sender, tx_receiver) = channel(100); let task = service.launch().await?; - let listeners = Arc::new(Mutex::new(HashMap::new())); + let clients = Arc::new(Mutex::new(HashMap::new())); + let feeds = Arc::new(Mutex::new(HashMap::new())); Ok(DaemonIdm { rx_receiver, tx_receiver, tx_sender, tx_raw_sender, task, - listeners, + clients, + feeds, }) } pub async fn launch(mut self) -> Result { - let listeners = self.listeners.clone(); + let clients = self.clients.clone(); + let feeds = self.feeds.clone(); let tx_sender = self.tx_sender.clone(); let task = tokio::task::spawn(async move { let mut buffers: HashMap = HashMap::new(); @@ -109,7 +86,8 @@ impl DaemonIdm { } }); Ok(DaemonIdmHandle { - listeners, + clients, + feeds, tx_sender, task: Arc::new(task), }) @@ -134,11 +112,9 @@ impl DaemonIdm { packet.advance(2); match IdmPacket::decode(packet) { Ok(packet) => { - let guard = self.listeners.lock().await; - if let Some(sender) = guard.get(&domid) { - if let Err(error) = sender.try_send((domid, packet)) { - warn!("dropped idm packet from domain {}: {}", domid, error); - } + let guard = self.feeds.lock().await; + if let Some(feed) = guard.get(&domid) { + let _ = feed.try_send(packet); } } @@ -173,3 +149,50 @@ impl Drop for DaemonIdm { self.task.abort(); } } + +async fn client_or_create( + domid: u32, + tx_sender: &Sender<(u32, IdmPacket)>, + clients: &ClientMap, + feeds: &BackendFeedMap, +) -> Result { + let mut clients = clients.lock().await; + let mut feeds = feeds.lock().await; + match clients.entry(domid) { + Entry::Occupied(entry) => Ok(entry.get().clone()), + Entry::Vacant(entry) => { + let (rx_sender, rx_receiver) = channel(100); + feeds.insert(domid, rx_sender); + let backend = IdmDaemonBackend { + domid, + rx_receiver, + tx_sender: tx_sender.clone(), + }; + let client = IdmClient::new(Box::new(backend) as Box).await?; + entry.insert(client.clone()); + Ok(client) + } + } +} + +pub struct IdmDaemonBackend { + domid: u32, + rx_receiver: Receiver, + tx_sender: Sender<(u32, IdmPacket)>, +} + +#[async_trait::async_trait] +impl IdmBackend for IdmDaemonBackend { + async fn recv(&mut self) -> Result { + if let Some(packet) = self.rx_receiver.recv().await { + Ok(packet) + } else { + Err(anyhow!("idm receive channel closed")) + } + } + + async fn send(&mut self, packet: IdmPacket) -> Result<()> { + self.tx_sender.send((self.domid, packet)).await?; + Ok(()) + } +} diff --git a/crates/guest/src/background.rs b/crates/guest/src/background.rs index 9bd0faf..1195e65 100644 --- a/crates/guest/src/background.rs +++ b/crates/guest/src/background.rs @@ -6,11 +6,11 @@ use anyhow::Result; use cgroups_rs::Cgroup; use krata::idm::{ client::IdmClient, - protocol::{idm_event::Event, idm_packet::Content, IdmEvent, IdmExitEvent, IdmPacket}, + protocol::{idm_event::Event, IdmEvent, IdmExitEvent}, }; use log::debug; use nix::unistd::Pid; -use tokio::select; +use tokio::{select, sync::broadcast}; pub struct GuestBackground { idm: IdmClient, @@ -30,16 +30,21 @@ impl GuestBackground { } pub async fn run(&mut self) -> Result<()> { + let mut event_subscription = self.idm.subscribe().await?; loop { select! { - x = self.idm.receiver.recv() => match x { - Some(_packet) => { + x = event_subscription.recv() => match x { + Ok(_event) => { }, - None => { + Err(broadcast::error::RecvError::Closed) => { debug!("idm packet channel closed"); break; + }, + + _ => { + continue; } }, @@ -57,11 +62,8 @@ impl GuestBackground { async fn child_event(&mut self, event: ChildEvent) -> Result<()> { if event.pid == self.child { self.idm - .sender - .send(IdmPacket { - content: Some(Content::Event(IdmEvent { - event: Some(Event::Exit(IdmExitEvent { code: event.status })), - })), + .emit(IdmEvent { + event: Some(Event::Exit(IdmExitEvent { code: event.status })), }) .await?; death(event.status).await?; diff --git a/crates/krata/src/idm/client.rs b/crates/krata/src/idm/client.rs index 4be6335..df19dc9 100644 --- a/crates/krata/src/idm/client.rs +++ b/crates/krata/src/idm/client.rs @@ -1,6 +1,8 @@ use std::{path::Path, sync::Arc}; -use super::protocol::IdmPacket; +use crate::idm::protocol::idm_packet::Content; + +use super::protocol::{IdmEvent, IdmPacket}; use anyhow::{anyhow, Result}; use bytes::BytesMut; use log::{debug, error}; @@ -11,6 +13,7 @@ use tokio::{ io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt}, select, sync::{ + broadcast, mpsc::{channel, Receiver, Sender}, Mutex, }, @@ -72,31 +75,37 @@ impl IdmBackend for IdmFileBackend { } } +#[derive(Clone)] pub struct IdmClient { - pub receiver: Receiver, - pub sender: Sender, - task: JoinHandle<()>, + event_receiver_sender: broadcast::Sender, + tx_sender: Sender, + task: Arc>, } impl Drop for IdmClient { fn drop(&mut self) { - self.task.abort(); + if Arc::strong_count(&self.task) <= 1 { + self.task.abort(); + } } } impl IdmClient { - pub async fn new<'a>(backend: Box) -> Result { - let (rx_sender, rx_receiver) = channel(IDM_PACKET_QUEUE_LEN); + pub async fn new(backend: Box) -> Result { + let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN); let (tx_sender, tx_receiver) = channel(IDM_PACKET_QUEUE_LEN); + let backend_event_sender = event_sender.clone(); let task = tokio::task::spawn(async move { - if let Err(error) = IdmClient::process(backend, rx_sender, tx_receiver).await { + if let Err(error) = + IdmClient::process(backend, backend_event_sender, event_receiver, tx_receiver).await + { debug!("failed to handle idm client processing: {}", error); } }); Ok(IdmClient { - receiver: rx_receiver, - sender: tx_sender, - task, + event_receiver_sender: event_sender.clone(), + tx_sender, + task: Arc::new(task), }) } @@ -111,16 +120,36 @@ impl IdmClient { IdmClient::new(Box::new(backend) as Box).await } + pub async fn emit(&self, event: IdmEvent) -> Result<()> { + self.tx_sender + .send(IdmPacket { + content: Some(Content::Event(event)), + }) + .await?; + Ok(()) + } + + pub async fn subscribe(&self) -> Result> { + Ok(self.event_receiver_sender.subscribe()) + } + async fn process( mut backend: Box, - sender: Sender, + event_sender: broadcast::Sender, + _event_receiver: broadcast::Receiver, mut receiver: Receiver, ) -> Result<()> { loop { select! { x = backend.recv() => match x { Ok(packet) => { - sender.send(packet).await?; + match packet.content { + Some(Content::Event(event)) => { + let _ = event_sender.send(event); + }, + + _ => {}, + } }, Err(error) => {