use std::{collections::HashMap, ffi::CString, io::ErrorKind, sync::Arc, time::Duration}; use libc::O_NONBLOCK; use log::warn; use tokio::{ fs::{metadata, File}, io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt}, select, sync::{ mpsc::{channel, Receiver, Sender}, oneshot::{self, channel as oneshot_channel}, Mutex, }, task::JoinHandle, time::timeout, }; use crate::{ error::{Error, Result}, sys::{XsdMessageHeader, XSD_ERROR, XSD_UNWATCH, XSD_WATCH_EVENT}, }; const XEN_BUS_PATHS: &[&str] = &["/dev/xen/xenbus"]; const XEN_BUS_MAX_PAYLOAD_SIZE: usize = 4096; const XEN_BUS_MAX_PACKET_SIZE: usize = XsdMessageHeader::SIZE + XEN_BUS_MAX_PAYLOAD_SIZE; async fn find_bus_path() -> Option<&'static str> { for path in XEN_BUS_PATHS { match metadata(path).await { Ok(_) => return Some(path), Err(_) => continue, } } None } struct WatchState { sender: Sender, } struct ReplyState { sender: oneshot::Sender, } type ReplyMap = Arc>>; type WatchMap = Arc>>; #[derive(Clone)] pub struct XsdSocket { tx_sender: Sender, replies: ReplyMap, watches: WatchMap, next_request_id: Arc>, next_watch_id: Arc>, processor_task: Arc>, rx_task: Arc>, unwatch_sender: Sender, } impl XsdSocket { pub async fn open() -> Result { let path = match find_bus_path().await { Some(path) => path, None => return Err(Error::BusNotFound), }; let file = File::options() .read(true) .write(true) .custom_flags(O_NONBLOCK) .open(path) .await?; XsdSocket::from_handle(file).await } pub async fn from_handle(handle: File) -> Result { let replies: ReplyMap = Arc::new(Mutex::new(HashMap::new())); let watches: WatchMap = Arc::new(Mutex::new(HashMap::new())); let next_request_id = Arc::new(Mutex::new(0u32)); let (rx_sender, rx_receiver) = channel::(10); let (tx_sender, tx_receiver) = channel::(10); let (unwatch_sender, unwatch_receiver) = channel::(1000); let read: File = handle.try_clone().await?; let mut processor = XsdSocketProcessor { handle, replies: replies.clone(), watches: watches.clone(), next_request_id: next_request_id.clone(), tx_receiver, rx_receiver, unwatch_receiver, }; let processor_task = tokio::task::spawn(async move { if let Err(error) = processor.process().await { warn!("failed to process xen store messages: {}", error); } }); let rx_task = tokio::task::spawn(async move { if let Err(error) = XsdSocketProcessor::process_rx(read, rx_sender).await { warn!("failed to process xen store responses: {}", error); } }); Ok(XsdSocket { tx_sender, replies, watches, next_request_id, next_watch_id: Arc::new(Mutex::new(0u32)), processor_task: Arc::new(processor_task), rx_task: Arc::new(rx_task), unwatch_sender, }) } pub async fn send_buf(&self, tx: u32, typ: u32, payload: &[u8]) -> Result { let req = { let mut guard = self.next_request_id.lock().await; let req = *guard; *guard = req + 1; req }; let (sender, receiver) = oneshot_channel::(); self.replies.lock().await.insert(req, ReplyState { sender }); let header = XsdMessageHeader { typ, req, tx, len: payload.len() as u32, }; let message = XsdMessage { header, payload: payload.to_vec(), }; if let Err(error) = self.tx_sender.try_send(message) { return Err(error.into()); } let reply = receiver.await?; if reply.header.typ == XSD_ERROR { let error = CString::from_vec_with_nul(reply.payload)?; return Err(Error::ResponseError(error.into_string()?)); } Ok(reply) } pub async fn send(&self, tx: u32, typ: u32, payload: &[&str]) -> Result { let mut buf: Vec = Vec::new(); for item in payload { buf.extend_from_slice(item.as_bytes()); buf.push(0); } self.send_buf(tx, typ, &buf).await } pub async fn add_watch(&self) -> Result<(u32, Receiver, Sender)> { let id = { let mut guard = self.next_watch_id.lock().await; let req = *guard; *guard = req + 1; req }; let (sender, receiver) = channel(10); self.watches.lock().await.insert(id, WatchState { sender }); Ok((id, receiver, self.unwatch_sender.clone())) } } struct XsdSocketProcessor { handle: File, replies: ReplyMap, watches: WatchMap, next_request_id: Arc>, tx_receiver: Receiver, rx_receiver: Receiver, unwatch_receiver: Receiver, } impl XsdSocketProcessor { async fn process_rx(read: File, rx_sender: Sender) -> Result<()> { let mut buffer: Vec = vec![0u8; XEN_BUS_MAX_PACKET_SIZE]; let mut fd = AsyncFd::new(read)?; loop { select! { x = fd.readable_mut() => match x { Ok(mut guard) => { let future = XsdSocketProcessor::read_message(&mut buffer, guard.get_inner_mut()); if let Ok(message) = timeout(Duration::from_secs(1), future).await { rx_sender.send(message?).await?; } }, Err(error) => { return Err(error.into()); } }, _ = rx_sender.closed() => { break; } }; } Ok(()) } async fn read_message(buffer: &mut [u8], read: &mut File) -> Result { let size = loop { match read.read(buffer).await { Ok(size) => break size, Err(error) => { if error.kind() == ErrorKind::WouldBlock { tokio::task::yield_now().await; continue; } return Err(error.into()); } }; }; if size < XsdMessageHeader::SIZE { return Err(Error::InvalidBusData); } let header = XsdMessageHeader::decode(&buffer[0..XsdMessageHeader::SIZE])?; if size < XsdMessageHeader::SIZE + header.len as usize { return Err(Error::InvalidBusData); } let payload = &mut buffer[XsdMessageHeader::SIZE..XsdMessageHeader::SIZE + header.len as usize]; Ok(XsdMessage { header, payload: payload.to_vec(), }) } async fn process(&mut self) -> Result<()> { loop { select! { x = self.tx_receiver.recv() => match x { Some(message) => { let mut composed: Vec = Vec::new(); message.header.encode_to(&mut composed)?; composed.extend_from_slice(&message.payload); self.handle.write_all(&composed).await?; } None => { break; } }, x = self.rx_receiver.recv() => match x { Some(message) => { if message.header.typ == XSD_WATCH_EVENT && message.header.req == 0 && message.header.tx == 0 { let strings = message.parse_string_vec()?; let Some(path) = strings.first() else { return Ok(()); }; let Some(token) = strings.get(1) else { return Ok(()); }; let Ok(id) = token.parse::() else { return Ok(()); }; if let Some(state) = self.watches.lock().await.get(&id) { let _ = state.sender.try_send(path.clone()); } } else if let Some(state) = self.replies.lock().await.remove(&message.header.req) { let _ = state.sender.send(message); } } None => { break; } }, x = self.unwatch_receiver.recv() => match x { Some(id) => { let req = { let mut guard = self.next_request_id.lock().await; let req = *guard; *guard = req + 1; req }; let mut payload = id.to_string().as_bytes().to_vec(); payload.push(0); let header = XsdMessageHeader { typ: XSD_UNWATCH, req, tx: 0, len: payload.len() as u32, }; let mut data = header.encode()?; data.extend_from_slice(&payload); self.handle.write_all(&data).await?; }, None => { break; } } }; } Ok(()) } } #[derive(Debug)] pub struct XsdMessage { pub header: XsdMessageHeader, pub payload: Vec, } impl XsdMessage { pub fn parse_string(&self) -> Result { Ok(CString::from_vec_with_nul(self.payload.clone())?.into_string()?) } pub fn parse_string_vec(&self) -> Result> { let mut strings: Vec = Vec::new(); let mut buffer: Vec = Vec::new(); for b in &self.payload { if *b == 0 { let string = String::from_utf8(buffer.clone())?; strings.push(string); buffer.clear(); continue; } buffer.push(*b); } Ok(strings) } pub fn parse_bool(&self) -> Result { Ok(true) } } impl Drop for XsdSocket { fn drop(&mut self) { if Arc::strong_count(&self.rx_task) <= 1 { self.rx_task.abort(); } if Arc::strong_count(&self.processor_task) <= 1 { self.processor_task.abort(); } } }