diff --git a/crates/daemon/src/console.rs b/crates/daemon/src/console.rs index a6e57bf..38e4792 100644 --- a/crates/daemon/src/console.rs +++ b/crates/daemon/src/console.rs @@ -79,7 +79,7 @@ impl Drop for DaemonConsoleHandle { pub struct DaemonConsole { listeners: ListenerMap, buffers: BufferMap, - receiver: Receiver<(u32, Vec)>, + receiver: Receiver<(u32, Option>)>, sender: Sender<(u32, Vec)>, task: JoinHandle<()>, } @@ -124,16 +124,22 @@ impl DaemonConsole { }; let mut buffers = self.buffers.lock().await; - let buffer = buffers - .entry(domid) - .or_insert_with_key(|_| RawConsoleBuffer::boxed()); - buffer.extend_from_slice(&data); - drop(buffers); - let mut listeners = self.listeners.lock().await; - if let Some(senders) = listeners.get_mut(&domid) { - senders.retain(|sender| { - !matches!(sender.try_send(data.to_vec()), Err(TrySendError::Closed(_))) - }); + if let Some(data) = data { + let buffer = buffers + .entry(domid) + .or_insert_with_key(|_| RawConsoleBuffer::boxed()); + buffer.extend_from_slice(&data); + drop(buffers); + let mut listeners = self.listeners.lock().await; + if let Some(senders) = listeners.get_mut(&domid) { + senders.retain(|sender| { + !matches!(sender.try_send(data.to_vec()), Err(TrySendError::Closed(_))) + }); + } + } else { + buffers.remove(&domid); + let mut listeners = self.listeners.lock().await; + listeners.remove(&domid); } } Ok(()) diff --git a/crates/daemon/src/idm.rs b/crates/daemon/src/idm.rs index a2e04b8..2a155cf 100644 --- a/crates/daemon/src/idm.rs +++ b/crates/daemon/src/idm.rs @@ -52,7 +52,7 @@ pub struct DaemonIdm { tx_sender: Sender<(u32, IdmPacket)>, tx_raw_sender: Sender<(u32, Vec)>, tx_receiver: Receiver<(u32, IdmPacket)>, - rx_receiver: Receiver<(u32, Vec)>, + rx_receiver: Receiver<(u32, Option>)>, task: JoinHandle<()>, } @@ -98,29 +98,37 @@ impl DaemonIdm { select! { x = self.rx_receiver.recv() => match x { Some((domid, data)) => { - let buffer = buffers.entry(domid).or_insert_with_key(|_| BytesMut::new()); - buffer.extend_from_slice(&data); - if buffer.len() < 2 { - continue; - } - let size = (buffer[0] as u16 | (buffer[1] as u16) << 8) as usize; - let needed = size + 2; - if buffer.len() < needed { - continue; - } - let mut packet = buffer.split_to(needed); - packet.advance(2); - match IdmPacket::decode(packet) { - Ok(packet) => { - let guard = self.feeds.lock().await; - if let Some(feed) = guard.get(&domid) { - let _ = feed.try_send(packet); + if let Some(data) = data { + let buffer = buffers.entry(domid).or_insert_with_key(|_| BytesMut::new()); + buffer.extend_from_slice(&data); + if buffer.len() < 2 { + continue; + } + let size = (buffer[0] as u16 | (buffer[1] as u16) << 8) as usize; + let needed = size + 2; + if buffer.len() < needed { + continue; + } + let mut packet = buffer.split_to(needed); + packet.advance(2); + match IdmPacket::decode(packet) { + Ok(packet) => { + let _ = client_or_create(domid, &self.tx_sender, &self.clients, &self.feeds).await?; + let guard = self.feeds.lock().await; + if let Some(feed) = guard.get(&domid) { + let _ = feed.try_send(packet); + } + } + + Err(packet) => { + warn!("received invalid packet from domain {}: {}", domid, packet); } } - - Err(packet) => { - warn!("received invalid packet from domain {}: {}", domid, packet); - } + } else { + let mut clients = self.clients.lock().await; + let mut feeds = self.feeds.lock().await; + clients.remove(&domid); + feeds.remove(&domid); } }, diff --git a/crates/runtime/src/channel.rs b/crates/runtime/src/channel.rs index c1d12a5..86ec0c7 100644 --- a/crates/runtime/src/channel.rs +++ b/crates/runtime/src/channel.rs @@ -48,7 +48,7 @@ pub struct ChannelService { gnttab: GrantTab, input_receiver: Receiver<(u32, Vec)>, pub input_sender: Sender<(u32, Vec)>, - output_sender: Sender<(u32, Vec)>, + output_sender: Sender<(u32, Option>)>, } impl ChannelService { @@ -58,7 +58,7 @@ impl ChannelService { ) -> Result<( ChannelService, Sender<(u32, Vec)>, - Receiver<(u32, Vec)>, + Receiver<(u32, Option>)>, )> { let (input_sender, input_receiver) = channel(GROUPED_CHANNEL_QUEUE_LEN); let (output_sender, output_receiver) = channel(GROUPED_CHANNEL_QUEUE_LEN); @@ -203,12 +203,14 @@ pub struct ChannelBackend { pub domid: u32, pub id: u32, pub sender: Sender>, + raw_sender: Sender<(u32, Option>)>, task: JoinHandle<()>, } impl Drop for ChannelBackend { fn drop(&mut self) { self.task.abort(); + let _ = self.raw_sender.try_send((self.domid, None)); debug!( "destroyed channel backend for domain {} channel {}", self.domid, self.id @@ -226,7 +228,7 @@ impl ChannelBackend { store: XsdClient, evtchn: EventChannel, gnttab: GrantTab, - output_sender: Sender<(u32, Vec)>, + output_sender: Sender<(u32, Option>)>, use_reserved_ref: Option, ) -> Result { let processor = KrataChannelBackendProcessor { @@ -242,11 +244,14 @@ impl ChannelBackend { let (input_sender, input_receiver) = channel(SINGLE_CHANNEL_QUEUE_LEN); - let task = processor.launch(output_sender, input_receiver).await?; + let task = processor + .launch(output_sender.clone(), input_receiver) + .await?; Ok(ChannelBackend { domid, id, task, + raw_sender: output_sender, sender: input_sender, }) } @@ -304,7 +309,7 @@ impl KrataChannelBackendProcessor { async fn launch( &self, - output_sender: Sender<(u32, Vec)>, + output_sender: Sender<(u32, Option>)>, input_receiver: Receiver>, ) -> Result> { let owned = self.clone(); @@ -321,7 +326,7 @@ impl KrataChannelBackendProcessor { async fn processor( &self, - sender: Sender<(u32, Vec)>, + sender: Sender<(u32, Option>)>, mut receiver: Receiver>, ) -> Result<()> { self.init().await?; @@ -396,7 +401,7 @@ impl KrataChannelBackendProcessor { unsafe { let buffer = self.read_output_buffer(channel.local_port, &memory).await?; if !buffer.is_empty() { - sender.send((self.domid, buffer)).await?; + sender.send((self.domid, Some(buffer))).await?; } }; @@ -466,7 +471,7 @@ impl KrataChannelBackendProcessor { unsafe { let buffer = self.read_output_buffer(channel.local_port, &memory).await?; if !buffer.is_empty() { - sender.send((self.domid, buffer)).await?; + sender.send((self.domid, Some(buffer))).await?; } }; channel.unmask_sender.send(channel.local_port).await?;