mirror of
https://github.com/edera-dev/krata.git
synced 2025-08-05 14:11:32 +00:00
xenstore: watch support
This commit is contained in:
@ -1,54 +1,326 @@
|
||||
use crate::error::{Error, Result};
|
||||
use crate::sys::{XsdMessageHeader, XSD_ERROR};
|
||||
use std::ffi::CString;
|
||||
use std::fs::{metadata, File};
|
||||
use std::io::{Read, Write};
|
||||
use std::mem::size_of;
|
||||
use std::{collections::HashMap, ffi::CString, io::ErrorKind, sync::Arc, time::Duration};
|
||||
|
||||
use libc::O_NONBLOCK;
|
||||
use log::warn;
|
||||
use tokio::{
|
||||
fs::{metadata, File},
|
||||
io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt},
|
||||
select,
|
||||
sync::{
|
||||
mpsc::{channel, Receiver, Sender},
|
||||
oneshot::{self, channel as oneshot_channel},
|
||||
Mutex,
|
||||
},
|
||||
task::JoinHandle,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
sys::{XsdMessageHeader, XSD_ERROR, XSD_UNWATCH, XSD_WATCH_EVENT},
|
||||
};
|
||||
|
||||
const XEN_BUS_PATHS: &[&str] = &["/dev/xen/xenbus"];
|
||||
const XEN_BUS_MAX_PAYLOAD_SIZE: usize = 4096;
|
||||
const XEN_BUS_MAX_PACKET_SIZE: usize = XsdMessageHeader::SIZE + XEN_BUS_MAX_PAYLOAD_SIZE;
|
||||
|
||||
fn find_bus_path() -> Option<String> {
|
||||
async fn find_bus_path() -> Option<&'static str> {
|
||||
for path in XEN_BUS_PATHS {
|
||||
match metadata(path) {
|
||||
Ok(_) => return Some(String::from(*path)),
|
||||
match metadata(path).await {
|
||||
Ok(_) => return Some(path),
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub struct XsdFileTransport {
|
||||
handle: File,
|
||||
struct WatchState {
|
||||
sender: Sender<String>,
|
||||
}
|
||||
|
||||
impl XsdFileTransport {
|
||||
fn new(path: &str) -> Result<XsdFileTransport> {
|
||||
let handle = File::options().read(true).write(true).open(path)?;
|
||||
Ok(XsdFileTransport { handle })
|
||||
struct ReplyState {
|
||||
sender: oneshot::Sender<XsdMessage>,
|
||||
}
|
||||
|
||||
type ReplyMap = Arc<Mutex<HashMap<u32, ReplyState>>>;
|
||||
type WatchMap = Arc<Mutex<HashMap<u32, WatchState>>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct XsdSocket {
|
||||
tx_sender: Sender<XsdMessage>,
|
||||
replies: ReplyMap,
|
||||
watches: WatchMap,
|
||||
next_request_id: Arc<Mutex<u32>>,
|
||||
next_watch_id: Arc<Mutex<u32>>,
|
||||
processor_task: Arc<JoinHandle<()>>,
|
||||
rx_task: Arc<JoinHandle<()>>,
|
||||
unwatch_sender: Sender<u32>,
|
||||
}
|
||||
|
||||
impl XsdSocket {
|
||||
pub async fn open() -> Result<XsdSocket> {
|
||||
let path = match find_bus_path().await {
|
||||
Some(path) => path,
|
||||
None => return Err(Error::BusNotFound),
|
||||
};
|
||||
|
||||
let file = File::options()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.custom_flags(O_NONBLOCK)
|
||||
.open(path)
|
||||
.await?;
|
||||
XsdSocket::from_handle(file).await
|
||||
}
|
||||
|
||||
async fn xsd_read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
|
||||
Ok(self.handle.read_exact(buf)?)
|
||||
pub async fn from_handle(handle: File) -> Result<XsdSocket> {
|
||||
let replies: ReplyMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
let watches: WatchMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let next_request_id = Arc::new(Mutex::new(0u32));
|
||||
|
||||
let (rx_sender, rx_receiver) = channel::<XsdMessage>(10);
|
||||
let (tx_sender, tx_receiver) = channel::<XsdMessage>(10);
|
||||
let (unwatch_sender, unwatch_receiver) = channel::<u32>(1000);
|
||||
let read: File = handle.try_clone().await?;
|
||||
|
||||
let mut processor = XsdSocketProcessor {
|
||||
handle,
|
||||
replies: replies.clone(),
|
||||
watches: watches.clone(),
|
||||
next_request_id: next_request_id.clone(),
|
||||
tx_receiver,
|
||||
rx_receiver,
|
||||
unwatch_receiver,
|
||||
};
|
||||
|
||||
let processor_task = tokio::task::spawn(async move {
|
||||
if let Err(error) = processor.process().await {
|
||||
warn!("failed to process xen store messages: {}", error);
|
||||
}
|
||||
});
|
||||
|
||||
let rx_task = tokio::task::spawn(async move {
|
||||
if let Err(error) = XsdSocketProcessor::process_rx(read, rx_sender).await {
|
||||
warn!("failed to process xen store responses: {}", error);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(XsdSocket {
|
||||
tx_sender,
|
||||
replies,
|
||||
watches,
|
||||
next_request_id,
|
||||
next_watch_id: Arc::new(Mutex::new(0u32)),
|
||||
processor_task: Arc::new(processor_task),
|
||||
rx_task: Arc::new(rx_task),
|
||||
unwatch_sender,
|
||||
})
|
||||
}
|
||||
|
||||
async fn xsd_write_all(&mut self, buf: &[u8]) -> Result<()> {
|
||||
self.handle.write_all(buf)?;
|
||||
self.handle.flush()?;
|
||||
pub async fn send_buf(&self, tx: u32, typ: u32, payload: &[u8]) -> Result<XsdMessage> {
|
||||
let req = {
|
||||
let mut guard = self.next_request_id.lock().await;
|
||||
let req = *guard;
|
||||
*guard = req + 1;
|
||||
req
|
||||
};
|
||||
let (sender, receiver) = oneshot_channel::<XsdMessage>();
|
||||
self.replies.lock().await.insert(req, ReplyState { sender });
|
||||
|
||||
let header = XsdMessageHeader {
|
||||
typ,
|
||||
req,
|
||||
tx,
|
||||
len: payload.len() as u32,
|
||||
};
|
||||
let message = XsdMessage {
|
||||
header,
|
||||
payload: payload.to_vec(),
|
||||
};
|
||||
if let Err(error) = self.tx_sender.try_send(message) {
|
||||
return Err(error.into());
|
||||
}
|
||||
let reply = receiver.await?;
|
||||
if reply.header.typ == XSD_ERROR {
|
||||
let error = CString::from_vec_with_nul(reply.payload)?;
|
||||
return Err(Error::ResponseError(error.into_string()?));
|
||||
}
|
||||
Ok(reply)
|
||||
}
|
||||
|
||||
pub async fn send(&self, tx: u32, typ: u32, payload: &[&str]) -> Result<XsdMessage> {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
for item in payload {
|
||||
buf.extend_from_slice(item.as_bytes());
|
||||
buf.push(0);
|
||||
}
|
||||
self.send_buf(tx, typ, &buf).await
|
||||
}
|
||||
|
||||
pub async fn add_watch(&self) -> Result<(u32, Receiver<String>, Sender<u32>)> {
|
||||
let id = {
|
||||
let mut guard = self.next_watch_id.lock().await;
|
||||
let req = *guard;
|
||||
*guard = req + 1;
|
||||
req
|
||||
};
|
||||
let (sender, receiver) = channel(10);
|
||||
self.watches.lock().await.insert(id, WatchState { sender });
|
||||
Ok((id, receiver, self.unwatch_sender.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
struct XsdSocketProcessor {
|
||||
handle: File,
|
||||
replies: ReplyMap,
|
||||
watches: WatchMap,
|
||||
next_request_id: Arc<Mutex<u32>>,
|
||||
tx_receiver: Receiver<XsdMessage>,
|
||||
rx_receiver: Receiver<XsdMessage>,
|
||||
unwatch_receiver: Receiver<u32>,
|
||||
}
|
||||
|
||||
impl XsdSocketProcessor {
|
||||
async fn process_rx(read: File, rx_sender: Sender<XsdMessage>) -> Result<()> {
|
||||
let mut buffer: Vec<u8> = vec![0u8; XEN_BUS_MAX_PACKET_SIZE];
|
||||
let mut fd = AsyncFd::new(read)?;
|
||||
loop {
|
||||
select! {
|
||||
x = fd.readable_mut() => match x {
|
||||
Ok(mut guard) => {
|
||||
let future = XsdSocketProcessor::read_message(&mut buffer, guard.get_inner_mut());
|
||||
if let Ok(message) = timeout(Duration::from_secs(1), future).await {
|
||||
rx_sender.send(message?).await?;
|
||||
}
|
||||
},
|
||||
|
||||
Err(error) => {
|
||||
return Err(error.into());
|
||||
}
|
||||
},
|
||||
|
||||
_ = rx_sender.closed() => {
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_message(buffer: &mut [u8], read: &mut File) -> Result<XsdMessage> {
|
||||
let size = loop {
|
||||
match read.read(buffer).await {
|
||||
Ok(size) => break size,
|
||||
Err(error) => {
|
||||
if error.kind() == ErrorKind::WouldBlock {
|
||||
tokio::task::yield_now().await;
|
||||
continue;
|
||||
}
|
||||
return Err(error.into());
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
if size < XsdMessageHeader::SIZE {
|
||||
return Err(Error::InvalidBusData);
|
||||
}
|
||||
|
||||
let header = XsdMessageHeader::decode(&buffer[0..XsdMessageHeader::SIZE])?;
|
||||
if size < XsdMessageHeader::SIZE + header.len as usize {
|
||||
return Err(Error::InvalidBusData);
|
||||
}
|
||||
let payload =
|
||||
&mut buffer[XsdMessageHeader::SIZE..XsdMessageHeader::SIZE + header.len as usize];
|
||||
Ok(XsdMessage {
|
||||
header,
|
||||
payload: payload.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn process(&mut self) -> Result<()> {
|
||||
loop {
|
||||
select! {
|
||||
x = self.tx_receiver.recv() => match x {
|
||||
Some(message) => {
|
||||
let mut composed: Vec<u8> = Vec::new();
|
||||
message.header.encode_to(&mut composed)?;
|
||||
composed.extend_from_slice(&message.payload);
|
||||
self.handle.write_all(&composed).await?;
|
||||
}
|
||||
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
},
|
||||
|
||||
x = self.rx_receiver.recv() => match x {
|
||||
Some(message) => {
|
||||
if message.header.typ == XSD_WATCH_EVENT && message.header.req == 0 && message.header.tx == 0 {
|
||||
let strings = message.parse_string_vec()?;
|
||||
let Some(path) = strings.first() else {
|
||||
return Ok(());
|
||||
};
|
||||
let Some(token) = strings.get(1) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let Ok(id) = token.parse::<u32>() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let Some(state) = self.watches.lock().await.get(&id) {
|
||||
let _ = state.sender.try_send(path.clone());
|
||||
}
|
||||
} else if let Some(state) = self.replies.lock().await.remove(&message.header.req) {
|
||||
let _ = state.sender.send(message);
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
},
|
||||
|
||||
x = self.unwatch_receiver.recv() => match x {
|
||||
Some(id) => {
|
||||
let req = {
|
||||
let mut guard = self.next_request_id.lock().await;
|
||||
let req = *guard;
|
||||
*guard = req + 1;
|
||||
req
|
||||
};
|
||||
|
||||
let mut payload = id.to_string().as_bytes().to_vec();
|
||||
payload.push(0);
|
||||
let header = XsdMessageHeader {
|
||||
typ: XSD_UNWATCH,
|
||||
req,
|
||||
tx: 0,
|
||||
len: payload.len() as u32,
|
||||
};
|
||||
let mut data = header.encode()?;
|
||||
data.extend_from_slice(&payload);
|
||||
self.handle.write_all(&data).await?;
|
||||
},
|
||||
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct XsdSocket {
|
||||
handle: XsdFileTransport,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct XsdResponse {
|
||||
pub struct XsdMessage {
|
||||
pub header: XsdMessageHeader,
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
|
||||
impl XsdResponse {
|
||||
impl XsdMessage {
|
||||
pub fn parse_string(&self) -> Result<String> {
|
||||
Ok(CString::from_vec_with_nul(self.payload.clone())?.into_string()?)
|
||||
}
|
||||
@ -73,65 +345,14 @@ impl XsdResponse {
|
||||
}
|
||||
}
|
||||
|
||||
impl XsdSocket {
|
||||
pub async fn open() -> Result<XsdSocket> {
|
||||
let path = match find_bus_path() {
|
||||
Some(path) => path,
|
||||
None => return Err(Error::BusNotFound),
|
||||
};
|
||||
let transport = XsdFileTransport::new(&path)?;
|
||||
Ok(XsdSocket { handle: transport })
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, tx: u32, typ: u32, buf: &[u8]) -> Result<XsdResponse> {
|
||||
let header = XsdMessageHeader {
|
||||
typ,
|
||||
req: 0,
|
||||
tx,
|
||||
len: buf.len() as u32,
|
||||
};
|
||||
let header_bytes = bytemuck::bytes_of(&header);
|
||||
let mut composed: Vec<u8> = Vec::new();
|
||||
composed.extend_from_slice(header_bytes);
|
||||
composed.extend_from_slice(buf);
|
||||
self.handle.xsd_write_all(&composed).await?;
|
||||
let mut result_buf = vec![0u8; size_of::<XsdMessageHeader>()];
|
||||
match self.handle.xsd_read_exact(result_buf.as_mut_slice()).await {
|
||||
Ok(_) => {}
|
||||
Err(error) => {
|
||||
if result_buf.first().unwrap() == &0 {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
impl Drop for XsdSocket {
|
||||
fn drop(&mut self) {
|
||||
if Arc::strong_count(&self.rx_task) <= 1 {
|
||||
self.rx_task.abort();
|
||||
}
|
||||
let result_header = bytemuck::from_bytes::<XsdMessageHeader>(&result_buf);
|
||||
let mut payload = vec![0u8; result_header.len as usize];
|
||||
self.handle.xsd_read_exact(payload.as_mut_slice()).await?;
|
||||
if result_header.typ == XSD_ERROR {
|
||||
let error = CString::from_vec_with_nul(payload)?;
|
||||
return Err(Error::ResponseError(error.into_string()?));
|
||||
}
|
||||
let response = XsdResponse { header, payload };
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn send_single(&mut self, tx: u32, typ: u32, string: &str) -> Result<XsdResponse> {
|
||||
let text = CString::new(string)?;
|
||||
let buf = text.as_bytes_with_nul();
|
||||
self.send(tx, typ, buf).await
|
||||
}
|
||||
|
||||
pub async fn send_multiple(
|
||||
&mut self,
|
||||
tx: u32,
|
||||
typ: u32,
|
||||
array: &[&str],
|
||||
) -> Result<XsdResponse> {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
for item in array {
|
||||
buf.extend_from_slice(item.as_bytes());
|
||||
buf.push(0);
|
||||
if Arc::strong_count(&self.processor_task) <= 1 {
|
||||
self.processor_task.abort();
|
||||
}
|
||||
self.send(tx, typ, buf.as_slice()).await
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user