From ea9624955cac288b3b33f359dc01ba3c86d650d1 Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Sat, 14 Dec 2024 18:03:33 -0500 Subject: [PATCH] feat(evtchn): harden evtchn handling and improve api (#431) --- crates/xen/xenevtchn/src/error.rs | 2 + crates/xen/xenevtchn/src/lib.rs | 129 +++++++++++++++++------------- crates/xen/xenevtchn/src/raw.rs | 10 +-- crates/xen/xenevtchn/src/sys.rs | 20 ++--- 4 files changed, 92 insertions(+), 69 deletions(-) diff --git a/crates/xen/xenevtchn/src/error.rs b/crates/xen/xenevtchn/src/error.rs index ec9e729..e1faa98 100644 --- a/crates/xen/xenevtchn/src/error.rs +++ b/crates/xen/xenevtchn/src/error.rs @@ -12,6 +12,8 @@ pub enum Error { LockAcquireFailed, #[error("event port already in use")] PortInUse, + #[error("failed to join blocking task")] + BlockingTaskJoin, } pub type Result = std::result::Result; diff --git a/crates/xen/xenevtchn/src/lib.rs b/crates/xen/xenevtchn/src/lib.rs index 61a4e4f..f22e8d2 100644 --- a/crates/xen/xenevtchn/src/lib.rs +++ b/crates/xen/xenevtchn/src/lib.rs @@ -3,7 +3,10 @@ pub mod raw; pub mod sys; use crate::error::{Error, Result}; -use crate::sys::{BindInterdomain, BindUnboundPort, BindVirq, Notify, UnbindPort}; +use crate::sys::{ + BindInterdomainRequest, BindUnboundPortRequest, BindVirqRequest, NotifyRequest, + UnbindPortRequest, +}; use crate::raw::EVENT_CHANNEL_DEVICE; use byteorder::{LittleEndian, ReadBytesExt}; @@ -16,12 +19,9 @@ use std::os::raw::c_void; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::fs::{File, OpenOptions}; -use tokio::sync::mpsc::{channel, Receiver, Sender}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, Notify}; -const CHANNEL_QUEUE_LEN: usize = 30; - -type WakeMap = Arc>>>; +type WakeMap = Arc>>>; #[derive(Clone)] pub struct EventChannelService { @@ -32,7 +32,7 @@ pub struct EventChannelService { pub struct BoundEventChannel { pub local_port: u32, - pub receiver: Receiver, + pub receiver: Arc, pub service: EventChannelService, } @@ -59,7 +59,7 @@ impl EventChannelService { .write(true) .open(EVENT_CHANNEL_DEVICE) .await?; - let wakes = Arc::new(RwLock::new(HashMap::new())); + let wakes = Arc::new(Mutex::new(HashMap::new())); let flag = Arc::new(AtomicBool::new(false)); let processor = EventChannelProcessor { flag: flag.clone(), @@ -77,43 +77,52 @@ impl EventChannelService { pub async fn bind_virq(&self, virq: u32) -> Result { let handle = self.handle.lock().await; - unsafe { - let mut request = BindVirq { virq }; - Ok(sys::bind_virq(handle.as_raw_fd(), &mut request)? as u32) - } + let fd = handle.as_raw_fd(); + let mut request = BindVirqRequest { virq }; + let result = + tokio::task::spawn_blocking(move || unsafe { sys::bind_virq(fd, &mut request) }) + .await + .map_err(|_| Error::BlockingTaskJoin)?? as u32; + Ok(result) } pub async fn bind_interdomain(&self, domid: u32, port: u32) -> Result { let handle = self.handle.lock().await; - unsafe { - let mut request = BindInterdomain { - remote_domain: domid, - remote_port: port, - }; - Ok(sys::bind_interdomain(handle.as_raw_fd(), &mut request)? as u32) - } + let fd = handle.as_raw_fd(); + let mut request = BindInterdomainRequest { + remote_domain: domid, + remote_port: port, + }; + let result = + tokio::task::spawn_blocking(move || unsafe { sys::bind_interdomain(fd, &mut request) }) + .await + .map_err(|_| Error::BlockingTaskJoin)?? as u32; + Ok(result) } pub async fn bind_unbound_port(&self, domid: u32) -> Result { let handle = self.handle.lock().await; - unsafe { - let mut request = BindUnboundPort { - remote_domain: domid, - }; - Ok(sys::bind_unbound_port(handle.as_raw_fd(), &mut request)? as u32) - } + let fd = handle.as_raw_fd(); + let mut request = BindUnboundPortRequest { + remote_domain: domid, + }; + let result = tokio::task::spawn_blocking(move || unsafe { + sys::bind_unbound_port(fd, &mut request) + }) + .await + .map_err(|_| Error::BlockingTaskJoin)?? as u32; + Ok(result) } pub async fn unmask(&self, port: u32) -> Result<()> { let handle = self.handle.lock().await; let mut port = port; - let result = unsafe { - libc::write( - handle.as_raw_fd(), - &mut port as *mut u32 as *mut c_void, - size_of::(), - ) - }; + let fd = handle.as_raw_fd(); + let result = tokio::task::spawn_blocking(move || unsafe { + libc::write(fd, &mut port as *mut u32 as *mut c_void, size_of::()) + }) + .await + .map_err(|_| Error::BlockingTaskJoin)?; if result != size_of::() as isize { return Err(Error::Io(std::io::Error::from_raw_os_error(result as i32))); } @@ -122,25 +131,32 @@ impl EventChannelService { pub async fn unbind(&self, port: u32) -> Result { let handle = self.handle.lock().await; - unsafe { - let mut request = UnbindPort { port }; - let result = sys::unbind(handle.as_raw_fd(), &mut request)? as u32; - self.wakes.write().await.remove(&port); - Ok(result) - } + let mut request = UnbindPortRequest { port }; + let fd = handle.as_raw_fd(); + let result = tokio::task::spawn_blocking(move || unsafe { sys::unbind(fd, &mut request) }) + .await + .map_err(|_| Error::BlockingTaskJoin)?? as u32; + self.wakes.lock().await.remove(&port); + Ok(result) } pub async fn notify(&self, port: u32) -> Result { let handle = self.handle.lock().await; - unsafe { - let mut request = Notify { port }; - Ok(sys::notify(handle.as_raw_fd(), &mut request)? as u32) - } + let mut request = NotifyRequest { port }; + let fd = handle.as_raw_fd(); + let result = tokio::task::spawn_blocking(move || unsafe { sys::notify(fd, &mut request) }) + .await + .map_err(|_| Error::BlockingTaskJoin)?? as u32; + Ok(result) } pub async fn reset(&self) -> Result { let handle = self.handle.lock().await; - unsafe { Ok(sys::reset(handle.as_raw_fd())? as u32) } + let fd = handle.as_raw_fd(); + let result = tokio::task::spawn_blocking(move || unsafe { sys::reset(fd) }) + .await + .map_err(|_| Error::BlockingTaskJoin)?? as u32; + Ok(result) } pub async fn bind(&self, domid: u32, port: u32) -> Result { @@ -154,17 +170,15 @@ impl EventChannelService { Ok(bound) } - pub async fn subscribe(&self, port: u32) -> Result> { - let mut wakes = self.wakes.write().await; + pub async fn subscribe(&self, port: u32) -> Result> { + let mut wakes = self.wakes.lock().await; let receiver = match wakes.entry(port) { - Entry::Occupied(_) => { - return Err(Error::PortInUse); - } + Entry::Occupied(entry) => entry.get().clone(), Entry::Vacant(entry) => { - let (sender, receiver) = channel::(CHANNEL_QUEUE_LEN); - entry.insert(sender); - receiver + let notify = Arc::new(Notify::new()); + entry.insert(notify.clone()); + notify } }; Ok(receiver) @@ -194,9 +208,16 @@ impl EventChannelProcessor { pub fn process(&mut self) -> Result<()> { loop { let port = self.handle.read_u32::()?; - if let Some(wake) = self.wakes.blocking_read().get(&port) { - let _ = wake.try_send(port); - } + let receiver = match self.wakes.blocking_lock().entry(port) { + Entry::Occupied(entry) => entry.get().clone(), + + Entry::Vacant(entry) => { + let notify = Arc::new(Notify::new()); + entry.insert(notify.clone()); + notify + } + }; + receiver.notify_one(); } } } diff --git a/crates/xen/xenevtchn/src/raw.rs b/crates/xen/xenevtchn/src/raw.rs index 763b712..a2e727a 100644 --- a/crates/xen/xenevtchn/src/raw.rs +++ b/crates/xen/xenevtchn/src/raw.rs @@ -32,13 +32,13 @@ impl RawEventChannelService { pub fn bind_virq(&self, virq: u32) -> Result { let handle = self.handle.lock().map_err(|_| Error::LockAcquireFailed)?; - let mut request = sys::BindVirq { virq }; + let mut request = sys::BindVirqRequest { virq }; Ok(unsafe { sys::bind_virq(handle.as_raw_fd(), &mut request)? as u32 }) } pub fn bind_interdomain(&self, domid: u32, port: u32) -> Result { let handle = self.handle.lock().map_err(|_| Error::LockAcquireFailed)?; - let mut request = sys::BindInterdomain { + let mut request = sys::BindInterdomainRequest { remote_domain: domid, remote_port: port, }; @@ -47,7 +47,7 @@ impl RawEventChannelService { pub fn bind_unbound_port(&self, domid: u32) -> Result { let handle = self.handle.lock().map_err(|_| Error::LockAcquireFailed)?; - let mut request = sys::BindUnboundPort { + let mut request = sys::BindUnboundPortRequest { remote_domain: domid, }; Ok(unsafe { sys::bind_unbound_port(handle.as_raw_fd(), &mut request)? as u32 }) @@ -55,13 +55,13 @@ impl RawEventChannelService { pub fn unbind(&self, port: u32) -> Result { let handle = self.handle.lock().map_err(|_| Error::LockAcquireFailed)?; - let mut request = sys::UnbindPort { port }; + let mut request = sys::UnbindPortRequest { port }; Ok(unsafe { sys::unbind(handle.as_raw_fd(), &mut request)? as u32 }) } pub fn notify(&self, port: u32) -> Result { let handle = self.handle.lock().map_err(|_| Error::LockAcquireFailed)?; - let mut request = sys::Notify { port }; + let mut request = sys::NotifyRequest { port }; Ok(unsafe { sys::notify(handle.as_raw_fd(), &mut request)? as u32 }) } diff --git a/crates/xen/xenevtchn/src/sys.rs b/crates/xen/xenevtchn/src/sys.rs index 22154a3..67dbef1 100644 --- a/crates/xen/xenevtchn/src/sys.rs +++ b/crates/xen/xenevtchn/src/sys.rs @@ -2,34 +2,34 @@ use nix::{ioctl_none, ioctl_readwrite_bad}; use std::ffi::c_uint; #[repr(C)] -pub struct BindVirq { +pub struct BindVirqRequest { pub virq: c_uint, } #[repr(C)] -pub struct BindInterdomain { +pub struct BindInterdomainRequest { pub remote_domain: c_uint, pub remote_port: c_uint, } #[repr(C)] -pub struct BindUnboundPort { +pub struct BindUnboundPortRequest { pub remote_domain: c_uint, } #[repr(C)] -pub struct UnbindPort { +pub struct UnbindPortRequest { pub port: c_uint, } #[repr(C)] -pub struct Notify { +pub struct NotifyRequest { pub port: c_uint, } -ioctl_readwrite_bad!(bind_virq, 0x44500, BindVirq); -ioctl_readwrite_bad!(bind_interdomain, 0x84501, BindInterdomain); -ioctl_readwrite_bad!(bind_unbound_port, 0x44503, BindUnboundPort); -ioctl_readwrite_bad!(unbind, 0x44502, UnbindPort); -ioctl_readwrite_bad!(notify, 0x44504, Notify); +ioctl_readwrite_bad!(bind_virq, 0x44500, BindVirqRequest); +ioctl_readwrite_bad!(bind_interdomain, 0x84501, BindInterdomainRequest); +ioctl_readwrite_bad!(bind_unbound_port, 0x44503, BindUnboundPortRequest); +ioctl_readwrite_bad!(unbind, 0x44502, UnbindPortRequest); +ioctl_readwrite_bad!(notify, 0x44504, NotifyRequest); ioctl_none!(reset, 0x4505, 5);