xenstore: use read thread to avoid need for non-blocking I/O

This commit is contained in:
Alex Zenla 2024-04-02 03:02:00 +00:00
parent 7940eea588
commit 0fd6318c5f
No known key found for this signature in database
GPG Key ID: 067B238899B51269

View File

@ -1,7 +1,7 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
ffi::CString, ffi::CString,
io::ErrorKind, io::Read,
os::{ os::{
fd::{AsRawFd, FromRawFd, IntoRawFd}, fd::{AsRawFd, FromRawFd, IntoRawFd},
unix::fs::FileTypeExt, unix::fs::FileTypeExt,
@ -9,11 +9,10 @@ use std::{
sync::Arc, sync::Arc,
}; };
use libc::{fcntl, F_GETFL, F_SETFL, O_NONBLOCK}; use log::{debug, warn};
use log::warn;
use tokio::{ use tokio::{
fs::{metadata, File}, fs::{metadata, File},
io::{AsyncReadExt, AsyncWriteExt}, io::AsyncWriteExt,
net::UnixStream, net::UnixStream,
select, select,
sync::{ sync::{
@ -64,8 +63,8 @@ pub struct XsdSocket {
next_request_id: Arc<Mutex<u32>>, next_request_id: Arc<Mutex<u32>>,
next_watch_id: Arc<Mutex<u32>>, next_watch_id: Arc<Mutex<u32>>,
processor_task: Arc<JoinHandle<()>>, processor_task: Arc<JoinHandle<()>>,
rx_task: Arc<JoinHandle<()>>,
unwatch_sender: Sender<(u32, String)>, unwatch_sender: Sender<(u32, String)>,
_rx_task: Arc<std::thread::JoinHandle<()>>,
} }
impl XsdSocket { impl XsdSocket {
@ -78,15 +77,10 @@ impl XsdSocket {
let file = if socket { let file = if socket {
let stream = UnixStream::connect(path).await?; let stream = UnixStream::connect(path).await?;
let stream = stream.into_std()?; let stream = stream.into_std()?;
stream.set_nonblocking(true)?; stream.set_nonblocking(false)?;
unsafe { File::from_raw_fd(stream.into_raw_fd()) } unsafe { File::from_raw_fd(stream.into_raw_fd()) }
} else { } else {
File::options() File::options().read(true).write(true).open(path).await?
.read(true)
.write(true)
.custom_flags(O_NONBLOCK)
.open(path)
.await?
}; };
XsdSocket::from_handle(file).await XsdSocket::from_handle(file).await
@ -101,7 +95,7 @@ impl XsdSocket {
let (rx_sender, rx_receiver) = channel::<XsdMessage>(10); let (rx_sender, rx_receiver) = channel::<XsdMessage>(10);
let (tx_sender, tx_receiver) = channel::<XsdMessage>(10); let (tx_sender, tx_receiver) = channel::<XsdMessage>(10);
let (unwatch_sender, unwatch_receiver) = channel::<(u32, String)>(1000); let (unwatch_sender, unwatch_receiver) = channel::<(u32, String)>(1000);
let read: File = handle.try_clone().await?; let read: std::fs::File = unsafe { std::fs::File::from_raw_fd(handle.as_raw_fd()) };
let mut processor = XsdSocketProcessor { let mut processor = XsdSocketProcessor {
handle, handle,
@ -119,11 +113,13 @@ impl XsdSocket {
} }
}); });
let rx_task = tokio::task::spawn(async move { let rx_task = std::thread::Builder::new()
if let Err(error) = XsdSocketProcessor::process_rx(read, rx_sender).await { .name("xenstore-reader".to_string())
warn!("failed to process xen store responses: {}", error); .spawn(move || {
} if let Err(error) = XsdSocketProcessor::process_rx(read, rx_sender) {
}); debug!("failed to process xen store bus: {}", error);
}
})?;
Ok(XsdSocket { Ok(XsdSocket {
tx_sender, tx_sender,
@ -132,8 +128,8 @@ impl XsdSocket {
next_request_id, next_request_id,
next_watch_id: Arc::new(Mutex::new(0u32)), next_watch_id: Arc::new(Mutex::new(0u32)),
processor_task: Arc::new(processor_task), processor_task: Arc::new(processor_task),
rx_task: Arc::new(rx_task),
unwatch_sender, unwatch_sender,
_rx_task: Arc::new(rx_task),
}) })
} }
@ -201,80 +197,28 @@ struct XsdSocketProcessor {
} }
impl XsdSocketProcessor { impl XsdSocketProcessor {
async fn process_rx(mut read: File, rx_sender: Sender<XsdMessage>) -> Result<()> { fn process_rx(mut read: std::fs::File, rx_sender: Sender<XsdMessage>) -> Result<()> {
let mut header_buffer: Vec<u8> = vec![0u8; XsdMessageHeader::SIZE]; let mut header_buffer: Vec<u8> = vec![0u8; XsdMessageHeader::SIZE];
let mut buffer: Vec<u8> = vec![0u8; XEN_BUS_MAX_PACKET_SIZE - XsdMessageHeader::SIZE]; let mut buffer: Vec<u8> = vec![0u8; XEN_BUS_MAX_PACKET_SIZE - XsdMessageHeader::SIZE];
loop { loop {
select! { let message =
message = XsdSocketProcessor::read_message(&mut header_buffer, &mut buffer, &mut read) => { XsdSocketProcessor::read_message(&mut header_buffer, &mut buffer, &mut read)?;
let message = message?; rx_sender.blocking_send(message)?;
rx_sender.send(message).await?;
},
_ = rx_sender.closed() => {
break;
}
};
} }
Ok(())
} }
fn set_nonblocking(fd: i32, nonblock: bool) -> Result<()> { fn read_message(
let mut flags = unsafe { fcntl(fd, F_GETFL) };
if flags == -1 {
return Err(Error::Io(std::io::Error::new(
ErrorKind::Unsupported,
"failed to get fd flags",
)));
}
if nonblock {
flags |= O_NONBLOCK;
} else {
flags &= !O_NONBLOCK;
}
let result = unsafe { fcntl(fd, F_SETFL, flags) };
if result == -1 {
return Err(Error::Io(std::io::Error::new(
ErrorKind::Unsupported,
"failed to set fd flags",
)));
}
Ok(())
}
async fn read_message(
header_buffer: &mut [u8], header_buffer: &mut [u8],
buffer: &mut [u8], buffer: &mut [u8],
read: &mut File, read: &mut std::fs::File,
) -> Result<XsdMessage> { ) -> Result<XsdMessage> {
XsdSocketProcessor::set_nonblocking(read.as_raw_fd(), true)?; read.read_exact(header_buffer)?;
let header_size = loop {
match read.read_exact(header_buffer).await {
Ok(size) => break size,
Err(error) => {
if error.kind() == ErrorKind::WouldBlock {
tokio::task::yield_now().await;
continue;
}
return Err(error.into());
}
};
};
if header_size < XsdMessageHeader::SIZE {
return Err(Error::InvalidBusData);
}
let header = XsdMessageHeader::decode(header_buffer)?; let header = XsdMessageHeader::decode(header_buffer)?;
if header.len as usize > buffer.len() { if header.len as usize > buffer.len() {
return Err(Error::InvalidBusData); return Err(Error::InvalidBusData);
} }
let payload_buffer = &mut buffer[0..header.len as usize]; let payload_buffer = &mut buffer[0..header.len as usize];
XsdSocketProcessor::set_nonblocking(read.as_raw_fd(), false)?; read.read_exact(payload_buffer)?;
let payload_size = read.read_exact(payload_buffer).await?;
if payload_size != header.len as usize {
return Err(Error::InvalidBusData);
}
Ok(XsdMessage { Ok(XsdMessage {
header, header,
payload: payload_buffer.to_vec(), payload: payload_buffer.to_vec(),
@ -392,10 +336,6 @@ impl XsdMessage {
impl Drop for XsdSocket { impl Drop for XsdSocket {
fn drop(&mut self) { fn drop(&mut self) {
if Arc::strong_count(&self.rx_task) <= 1 {
self.rx_task.abort();
}
if Arc::strong_count(&self.processor_task) <= 1 { if Arc::strong_count(&self.processor_task) <= 1 {
self.processor_task.abort(); self.processor_task.abort();
} }