use std::{ collections::HashMap, ffi::CString, io::Read, os::{ fd::{AsRawFd, FromRawFd, IntoRawFd}, unix::fs::FileTypeExt, }, sync::Arc, }; use log::{debug, warn}; use tokio::{ fs::{metadata, File}, io::AsyncWriteExt, net::UnixStream, select, sync::{ mpsc::{channel, Receiver, Sender}, oneshot::{self, channel as oneshot_channel}, Mutex, }, task::JoinHandle, }; use crate::{ error::{Error, Result}, sys::{XsdMessageHeader, XSD_ERROR, XSD_UNWATCH, XSD_WATCH_EVENT}, }; const XEN_BUS_PATHS: &[&str] = &["/var/run/xenstored/socket", "/dev/xen/xenbus"]; const XEN_BUS_MAX_PAYLOAD_SIZE: usize = 4096; const XEN_BUS_MAX_PACKET_SIZE: usize = XsdMessageHeader::SIZE + XEN_BUS_MAX_PAYLOAD_SIZE; async fn find_bus_path() -> Option<(&'static str, bool)> { for path in XEN_BUS_PATHS { match metadata(path).await { Ok(metadata) => { return Some((path, metadata.file_type().is_socket())); } Err(_) => continue, } } None } struct WatchState { sender: Sender, } struct ReplyState { sender: oneshot::Sender, } type ReplyMap = Arc>>; type WatchMap = Arc>>; #[derive(Clone)] pub struct XsdSocket { tx_sender: Sender, replies: ReplyMap, watches: WatchMap, next_request_id: Arc>, next_watch_id: Arc>, processor_task: Arc>, unwatch_sender: Sender<(u32, String)>, _rx_task: Arc>, } impl XsdSocket { pub async fn open() -> Result { let (path, socket) = match find_bus_path().await { Some(path) => path, None => return Err(Error::BusNotFound), }; let file = if socket { let stream = UnixStream::connect(path).await?; let stream = stream.into_std()?; stream.set_nonblocking(false)?; unsafe { File::from_raw_fd(stream.into_raw_fd()) } } else { File::options().read(true).write(true).open(path).await? }; XsdSocket::from_handle(file).await } pub async fn from_handle(handle: File) -> Result { let replies: ReplyMap = Arc::new(Mutex::new(HashMap::new())); let watches: WatchMap = Arc::new(Mutex::new(HashMap::new())); let next_request_id = Arc::new(Mutex::new(0u32)); let (rx_sender, rx_receiver) = channel::(10); let (tx_sender, tx_receiver) = channel::(10); let (unwatch_sender, unwatch_receiver) = channel::<(u32, String)>(1000); let read: std::fs::File = unsafe { std::fs::File::from_raw_fd(handle.as_raw_fd()) }; let mut processor = XsdSocketProcessor { handle, replies: replies.clone(), watches: watches.clone(), next_request_id: next_request_id.clone(), tx_receiver, rx_receiver, unwatch_receiver, }; let processor_task = tokio::task::spawn(async move { if let Err(error) = processor.process().await { warn!("failed to process xen store messages: {}", error); } }); let rx_task = std::thread::Builder::new() .name("xenstore-reader".to_string()) .spawn(move || { let mut read = read; if let Err(error) = XsdSocketProcessor::process_rx(&mut read, rx_sender) { debug!("failed to process xen store bus: {}", error); } std::mem::forget(read); })?; Ok(XsdSocket { tx_sender, replies, watches, next_request_id, next_watch_id: Arc::new(Mutex::new(0u32)), processor_task: Arc::new(processor_task), unwatch_sender, _rx_task: Arc::new(rx_task), }) } pub async fn send_buf(&self, tx: u32, typ: u32, payload: &[u8]) -> Result { let req = { let mut guard = self.next_request_id.lock().await; let req = *guard; *guard = req.wrapping_add(1); req }; let (sender, receiver) = oneshot_channel::(); self.replies.lock().await.insert(req, ReplyState { sender }); let header = XsdMessageHeader { typ, req, tx, len: payload.len() as u32, }; let message = XsdMessage { header, payload: payload.to_vec(), }; if let Err(error) = self.tx_sender.try_send(message) { return Err(error.into()); } let reply = receiver.await?; if reply.header.typ == XSD_ERROR { let error = CString::from_vec_with_nul(reply.payload)?; return Err(Error::ResponseError(error.into_string()?)); } Ok(reply) } pub async fn send(&self, tx: u32, typ: u32, payload: &[&str]) -> Result { let mut buf: Vec = Vec::new(); for item in payload { buf.extend_from_slice(item.as_bytes()); buf.push(0); } self.send_buf(tx, typ, &buf).await } pub async fn add_watch(&self) -> Result<(u32, Receiver, Sender<(u32, String)>)> { let id = { let mut guard = self.next_watch_id.lock().await; let watch = *guard; *guard = watch.wrapping_add(1); watch }; let (sender, receiver) = channel(10); self.watches.lock().await.insert(id, WatchState { sender }); Ok((id, receiver, self.unwatch_sender.clone())) } } struct XsdSocketProcessor { handle: File, replies: ReplyMap, watches: WatchMap, next_request_id: Arc>, tx_receiver: Receiver, rx_receiver: Receiver, unwatch_receiver: Receiver<(u32, String)>, } impl XsdSocketProcessor { fn process_rx(read: &mut std::fs::File, rx_sender: Sender) -> Result<()> { let mut header_buffer: Vec = vec![0u8; XsdMessageHeader::SIZE]; let mut buffer: Vec = vec![0u8; XEN_BUS_MAX_PACKET_SIZE - XsdMessageHeader::SIZE]; loop { let message = XsdSocketProcessor::read_message(&mut header_buffer, &mut buffer, read)?; rx_sender.blocking_send(message)?; } } fn read_message( header_buffer: &mut [u8], buffer: &mut [u8], read: &mut std::fs::File, ) -> Result { read.read_exact(header_buffer)?; let header = XsdMessageHeader::decode(header_buffer)?; if header.len as usize > buffer.len() { return Err(Error::InvalidBusData); } let payload_buffer = &mut buffer[0..header.len as usize]; read.read_exact(payload_buffer)?; Ok(XsdMessage { header, payload: payload_buffer.to_vec(), }) } async fn process(&mut self) -> Result<()> { loop { select! { x = self.tx_receiver.recv() => match x { Some(message) => { let mut composed: Vec = Vec::new(); message.header.encode_to(&mut composed)?; composed.extend_from_slice(&message.payload); self.handle.write_all(&composed).await?; } None => { break; } }, x = self.rx_receiver.recv() => match x { Some(message) => { if message.header.typ == XSD_WATCH_EVENT && message.header.req == 0 && message.header.tx == 0 { let strings = message.parse_string_vec()?; let Some(path) = strings.first() else { return Ok(()); }; let Some(token) = strings.get(1) else { return Ok(()); }; let Ok(id) = token.parse::() else { return Ok(()); }; if let Some(state) = self.watches.lock().await.get(&id) { let _ = state.sender.try_send(path.clone()); } } else if let Some(state) = self.replies.lock().await.remove(&message.header.req) { let _ = state.sender.send(message); } } None => { break; } }, x = self.unwatch_receiver.recv() => match x { Some((id, path)) => { let req = { let mut guard = self.next_request_id.lock().await; let req = *guard; *guard = req.wrapping_add(1); req }; let mut payload = id.to_string().as_bytes().to_vec(); payload.push(0); payload.extend_from_slice(path.to_string().as_bytes()); payload.push(0); let header = XsdMessageHeader { typ: XSD_UNWATCH, req, tx: 0, len: payload.len() as u32, }; let mut data = header.encode()?; data.extend_from_slice(&payload); self.handle.write_all(&data).await?; }, None => { break; } } } } Ok(()) } } #[derive(Debug)] pub struct XsdMessage { pub header: XsdMessageHeader, pub payload: Vec, } impl XsdMessage { pub fn parse_string(&self) -> Result { Ok(CString::from_vec_with_nul(self.payload.clone())?.into_string()?) } pub fn parse_string_vec(&self) -> Result> { let mut strings: Vec = Vec::new(); let mut buffer: Vec = Vec::new(); for b in &self.payload { if *b == 0 { let string = String::from_utf8(buffer.clone())?; strings.push(string); buffer.clear(); continue; } buffer.push(*b); } Ok(strings) } pub fn parse_bool(&self) -> Result { Ok(true) } } impl Drop for XsdSocket { fn drop(&mut self) { if Arc::strong_count(&self.processor_task) <= 1 { self.processor_task.abort(); } } }