From 7b2de223205072815ab66a69a5d714cfbc374f45 Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Sat, 23 Mar 2024 08:46:20 +0000 Subject: [PATCH] krata: rework xenstore watch for reliability --- crates/kratart/src/lib.rs | 3 +- crates/xen/xenstore/examples/watch.rs | 3 +- crates/xen/xenstore/src/bus.rs | 77 +++++++++++++++++---------- crates/xen/xenstore/src/lib.rs | 16 +++--- 4 files changed, 64 insertions(+), 35 deletions(-) diff --git a/crates/kratart/src/lib.rs b/crates/kratart/src/lib.rs index e7115b4..9c1aab8 100644 --- a/crates/kratart/src/lib.rs +++ b/crates/kratart/src/lib.rs @@ -247,7 +247,8 @@ impl Runtime { .await? .ok_or_else(|| anyhow!("unable to resolve guest: {}", uuid))?; let path = format!("/local/domain/{}/krata/guest/exit-code", info.domid); - let handle = context.xen.store.watch(&path).await?; + let handle = context.xen.store.create_watch().await?; + context.xen.store.bind_watch(&handle, &path).await?; let watch = ExitCodeWatch { handle, sender, diff --git a/crates/xen/xenstore/examples/watch.rs b/crates/xen/xenstore/examples/watch.rs index c3379c0..989265f 100644 --- a/crates/xen/xenstore/examples/watch.rs +++ b/crates/xen/xenstore/examples/watch.rs @@ -7,7 +7,8 @@ async fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); let path = args().nth(1).unwrap_or("/local/domain".to_string()); let client = XsdClient::open().await?; - let mut handle = client.watch(&path).await?; + let mut handle = client.create_watch().await?; + client.bind_watch(&handle, path).await?; let mut count = 0; loop { let Some(event) = handle.receiver.recv().await else { diff --git a/crates/xen/xenstore/src/bus.rs b/crates/xen/xenstore/src/bus.rs index 2f1a19d..e6de48f 100644 --- a/crates/xen/xenstore/src/bus.rs +++ b/crates/xen/xenstore/src/bus.rs @@ -1,10 +1,10 @@ -use std::{collections::HashMap, ffi::CString, io::ErrorKind, sync::Arc, time::Duration}; +use std::{collections::HashMap, ffi::CString, io::ErrorKind, os::fd::AsRawFd, sync::Arc}; -use libc::O_NONBLOCK; +use libc::{fcntl, F_GETFL, F_SETFL, O_NONBLOCK}; use log::warn; use tokio::{ fs::{metadata, File}, - io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWriteExt}, select, sync::{ mpsc::{channel, Receiver, Sender}, @@ -12,7 +12,6 @@ use tokio::{ Mutex, }, task::JoinHandle, - time::timeout, }; use crate::{ @@ -182,22 +181,14 @@ struct XsdSocketProcessor { } impl XsdSocketProcessor { - async fn process_rx(read: File, rx_sender: Sender) -> Result<()> { - let mut buffer: Vec = vec![0u8; XEN_BUS_MAX_PACKET_SIZE]; - let mut fd = AsyncFd::new(read)?; + async fn process_rx(mut read: File, rx_sender: Sender) -> Result<()> { + let mut header_buffer: Vec = vec![0u8; XsdMessageHeader::SIZE]; + let mut buffer: Vec = vec![0u8; XEN_BUS_MAX_PACKET_SIZE - XsdMessageHeader::SIZE]; loop { select! { - x = fd.readable_mut() => match x { - Ok(mut guard) => { - let future = XsdSocketProcessor::read_message(&mut buffer, guard.get_inner_mut()); - if let Ok(message) = timeout(Duration::from_secs(1), future).await { - rx_sender.send(message?).await?; - } - }, - - Err(error) => { - return Err(error.into()); - } + message = XsdSocketProcessor::read_message(&mut header_buffer, &mut buffer, &mut read) => { + let message = message?; + rx_sender.send(message).await?; }, _ = rx_sender.closed() => { @@ -208,9 +199,37 @@ impl XsdSocketProcessor { Ok(()) } - async fn read_message(buffer: &mut [u8], read: &mut File) -> Result { - let size = loop { - match read.read(buffer).await { + fn set_nonblocking(fd: i32, nonblock: bool) -> Result<()> { + let mut flags = unsafe { fcntl(fd, F_GETFL) }; + if flags == -1 { + return Err(Error::Io(std::io::Error::new( + ErrorKind::Unsupported, + "failed to get fd flags", + ))); + } + if nonblock { + flags |= O_NONBLOCK; + } else { + flags &= !O_NONBLOCK; + } + let result = unsafe { fcntl(fd, F_SETFL, flags) }; + if result == -1 { + return Err(Error::Io(std::io::Error::new( + ErrorKind::Unsupported, + "failed to set fd flags", + ))); + } + Ok(()) + } + + async fn read_message( + header_buffer: &mut [u8], + buffer: &mut [u8], + read: &mut File, + ) -> Result { + XsdSocketProcessor::set_nonblocking(read.as_raw_fd(), true)?; + let header_size = loop { + match read.read_exact(header_buffer).await { Ok(size) => break size, Err(error) => { if error.kind() == ErrorKind::WouldBlock { @@ -222,19 +241,23 @@ impl XsdSocketProcessor { }; }; - if size < XsdMessageHeader::SIZE { + if header_size < XsdMessageHeader::SIZE { return Err(Error::InvalidBusData); } - let header = XsdMessageHeader::decode(&buffer[0..XsdMessageHeader::SIZE])?; - if size < XsdMessageHeader::SIZE + header.len as usize { + let header = XsdMessageHeader::decode(header_buffer)?; + if header.len as usize > buffer.len() { + return Err(Error::InvalidBusData); + } + let payload_buffer = &mut buffer[0..header.len as usize]; + XsdSocketProcessor::set_nonblocking(read.as_raw_fd(), false)?; + let payload_size = read.read_exact(payload_buffer).await?; + if payload_size != header.len as usize { return Err(Error::InvalidBusData); } - let payload = - &mut buffer[XsdMessageHeader::SIZE..XsdMessageHeader::SIZE + header.len as usize]; Ok(XsdMessage { header, - payload: payload.to_vec(), + payload: payload_buffer.to_vec(), }) } diff --git a/crates/xen/xenstore/src/lib.rs b/crates/xen/xenstore/src/lib.rs index 5c962cf..fd9992a 100644 --- a/crates/xen/xenstore/src/lib.rs +++ b/crates/xen/xenstore/src/lib.rs @@ -192,19 +192,23 @@ impl XsdClient { response.parse_bool() } - pub async fn watch>(&self, path: P) -> Result { + pub async fn create_watch(&self) -> Result { let (id, receiver, unwatch_sender) = self.socket.add_watch().await?; - let id_string = id.to_string(); - let _ = self - .socket - .send(0, XSD_WATCH, &[path.as_ref(), &id_string]) - .await?; Ok(XsdWatchHandle { id, receiver, unwatch_sender, }) } + + pub async fn bind_watch>(&self, handle: &XsdWatchHandle, path: P) -> Result<()> { + let id_string = handle.id.to_string(); + let _ = self + .socket + .send(0, XSD_WATCH, &[path.as_ref(), &id_string]) + .await?; + Ok(()) + } } #[derive(Clone)]