diff --git a/crates/kratart/examples/channel.rs b/crates/kratart/examples/channel.rs index adc8538..fa920e9 100644 --- a/crates/kratart/examples/channel.rs +++ b/crates/kratart/examples/channel.rs @@ -2,14 +2,18 @@ use anyhow::Result; use env_logger::Env; use kratart::chan::KrataChannelService; use xenevtchn::EventChannel; +use xengnt::GrantTab; use xenstore::XsdClient; #[tokio::main] async fn main() -> Result<()> { env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); - let mut krata = - KrataChannelService::new(EventChannel::open().await?, XsdClient::open().await?)?; - krata.init().await?; + let mut krata = KrataChannelService::new( + EventChannel::open().await?, + XsdClient::open().await?, + GrantTab::open()?, + )?; + krata.watch().await?; Ok(()) } diff --git a/crates/kratart/src/chan.rs b/crates/kratart/src/chan.rs index 420939a..3584052 100644 --- a/crates/kratart/src/chan.rs +++ b/crates/kratart/src/chan.rs @@ -1,38 +1,118 @@ -use std::collections::HashMap; +use std::{ + collections::HashMap, + sync::atomic::{fence, Ordering}, + time::Duration, +}; -use anyhow::Result; +use anyhow::{anyhow, Result}; +use log::{error, info}; +use tokio::{ + select, + sync::{ + broadcast, + mpsc::{channel, Receiver, Sender}, + }, + task::JoinHandle, + time::sleep, +}; use xenevtchn::EventChannel; -use xengnt::{sys::GrantRef, GrantTab}; +use xengnt::{sys::GrantRef, GrantTab, MappedMemory}; use xenstore::{XsdClient, XsdInterface}; +const KRATA_SINGLE_CHANNEL_QUEUE_LEN: usize = 100; + #[repr(C)] struct XenConsoleInterface { - input: [u8; 1024], - output: [u8; 2048], + input: [u8; XenConsoleInterface::INPUT_SIZE], + output: [u8; XenConsoleInterface::OUTPUT_SIZE], in_cons: u32, in_prod: u32, out_cons: u32, out_prod: u32, } +unsafe impl Send for XenConsoleInterface {} + +impl XenConsoleInterface { + const INPUT_SIZE: usize = 1024; + const OUTPUT_SIZE: usize = 2048; +} + pub struct KrataChannelService { backends: HashMap<(u32, u32), KrataChannelBackend>, evtchn: EventChannel, store: XsdClient, + gnttab: GrantTab, } impl KrataChannelService { - pub fn new(evtchn: EventChannel, store: XsdClient) -> Result { + pub fn new( + evtchn: EventChannel, + store: XsdClient, + gnttab: GrantTab, + ) -> Result { Ok(KrataChannelService { backends: HashMap::new(), evtchn, store, + gnttab, }) } - pub async fn init(&mut self) -> Result<()> { + pub async fn watch(&mut self) -> Result<()> { + self.scan_all_backends().await?; + let mut watch_handle = self.store.create_watch().await?; + self.store + .bind_watch(&watch_handle, "/local/domain/0/backend/console".to_string()) + .await?; + loop { + let Some(_) = watch_handle.receiver.recv().await else { + break; + }; + + self.scan_all_backends().await?; + } + Ok(()) + } + + async fn ensure_backend_exists(&mut self, domid: u32, id: u32, path: String) -> Result<()> { + if self.backends.contains_key(&(domid, id)) { + return Ok(()); + } + let Some(frontend_path) = self.store.read_string(format!("{}/frontend", path)).await? + else { + return Ok(()); + }; + let Some(typ) = self + .store + .read_string(format!("{}/type", frontend_path)) + .await? + else { + return Ok(()); + }; + + if typ != "krata-channel" { + return Ok(()); + } + + let backend = KrataChannelBackend::new( + path.clone(), + frontend_path.clone(), + domid, + id, + self.store.clone(), + self.evtchn.clone(), + self.gnttab.clone(), + ) + .await?; + self.backends.insert((domid, id), backend); + Ok(()) + } + + async fn scan_all_backends(&mut self) -> Result<()> { let domains = self.store.list("/local/domain/0/backend/console").await?; - for domid_string in domains { + let mut seen: Vec<(u32, u32)> = Vec::new(); + for domid_string in &domains { let domid = domid_string.parse::()?; let domid_path = format!("/local/domain/0/backend/console/{}", domid); for id_string in self.store.list(&domid_path).await? { @@ -41,118 +121,335 @@ impl KrataChannelService { "/local/domain/0/backend/console/{}/{}", domid_string, id_string ); - let Some(frontend_path) = self - .store - .read_string(format!("{}/frontend", console_path)) - .await? - else { - continue; - }; - let Some(typ) = self - .store - .read_string(format!("{}/type", frontend_path)) - .await? - else { - continue; - }; - - if typ != "krata-channel" { - continue; - } - - let Some(ring_ref_string) = self - .store - .read_string(format!("{}/ring-ref", frontend_path)) - .await? - else { - continue; - }; - - let Some(port_string) = self - .store - .read_string(format!("{}/port", frontend_path)) - .await? - else { - continue; - }; - - let ring_ref = ring_ref_string.parse::()?; - let port = port_string.parse::()?; - let backend = KrataChannelBackend { - backend: console_path.clone(), - domid, - ring_ref, - port, - store: self.store.clone(), - evtchn: self.evtchn.clone(), - grant: GrantTab::open()?, - }; - - backend.init().await?; - self.backends.insert((domid, id), backend); + self.ensure_backend_exists(domid, id, console_path).await?; + seen.push((domid, id)); } } + + let mut gone: Vec<(u32, u32)> = Vec::new(); + for backend in self.backends.keys() { + if !seen.contains(backend) { + gone.push(*backend); + } + } + + for item in gone { + if let Some(backend) = self.backends.remove(&item) { + drop(backend); + } + } + Ok(()) } } +pub struct KrataChannelBackend { + pub domid: u32, + pub id: u32, + pub receiver: Receiver>, + pub sender: Sender>, + task: JoinHandle<()>, +} + +impl Drop for KrataChannelBackend { + fn drop(&mut self) { + self.task.abort(); + info!( + "destroyed channel backend for domain {} channel {}", + self.domid, self.id + ); + } +} + +impl KrataChannelBackend { + pub async fn new( + backend: String, + frontend: String, + domid: u32, + id: u32, + store: XsdClient, + evtchn: EventChannel, + gnttab: GrantTab, + ) -> Result { + let processor = KrataChannelBackendProcessor { + backend, + frontend, + domid, + id, + store, + evtchn, + gnttab, + }; + + let (output_sender, output_receiver) = channel(KRATA_SINGLE_CHANNEL_QUEUE_LEN); + let (input_sender, input_receiver) = channel(KRATA_SINGLE_CHANNEL_QUEUE_LEN); + + let task = processor.launch(output_sender, input_receiver).await?; + Ok(KrataChannelBackend { + domid, + id, + task, + receiver: output_receiver, + sender: input_sender, + }) + } +} + #[derive(Clone)] -pub struct KrataChannelBackend { +pub struct KrataChannelBackendProcessor { backend: String, + frontend: String, + id: u32, domid: u32, - ring_ref: u64, - port: u32, store: XsdClient, evtchn: EventChannel, - grant: GrantTab, + gnttab: GrantTab, } -impl KrataChannelBackend { - pub async fn init(&self) -> Result<()> { - self.store.write_string(&self.backend, "4").await?; +impl KrataChannelBackendProcessor { + async fn init(&self) -> Result<()> { + self.store + .write_string(format!("{}/state", self.backend), "3") + .await?; + info!( + "created channel backend for domain {} channel {}", + self.domid, self.id + ); Ok(()) } - pub async fn read(&self) -> Result<()> { - let memory = self.grant.map_grant_refs( + async fn on_frontend_state_change(&self) -> Result { + let state = self + .store + .read_string(format!("{}/state", self.backend)) + .await? + .unwrap_or("0".to_string()) + .parse::()?; + if state == 3 { + return Ok(true); + } + Ok(false) + } + + async fn on_self_state_change(&self) -> Result { + let state = self + .store + .read_string(format!("{}/state", self.backend)) + .await? + .unwrap_or("0".to_string()) + .parse::()?; + if state == 5 { + return Ok(true); + } + Ok(false) + } + + async fn launch( + &self, + output_sender: Sender>, + input_receiver: Receiver>, + ) -> Result> { + let owned = self.clone(); + Ok(tokio::task::spawn(async move { + if let Err(error) = owned.processor(output_sender, input_receiver).await { + error!("failed to process krata channel: {}", error); + } + let _ = owned + .store + .write_string(format!("{}/state", owned.backend), "6") + .await; + })) + } + + async fn processor( + &self, + sender: Sender>, + mut receiver: Receiver>, + ) -> Result<()> { + self.init().await?; + let mut frontend_state_change = self.store.create_watch().await?; + self.store + .bind_watch(&frontend_state_change, format!("{}/state", self.frontend)) + .await?; + + let (ring_ref, port) = loop { + match frontend_state_change.receiver.recv().await { + Some(_) => { + if self.on_frontend_state_change().await? { + let mut tries = 0; + let (ring_ref, port) = loop { + let ring_ref = self + .store + .read_string(format!("{}/ring-ref", self.frontend)) + .await?; + let port = self + .store + .read_string(format!("{}/port", self.frontend)) + .await?; + + if (ring_ref.is_none() || port.is_none()) && tries < 10 { + tries += 1; + self.store + .write_string(format!("{}/state", self.backend), "4") + .await?; + sleep(Duration::from_millis(250)).await; + continue; + } + break (ring_ref, port); + }; + + if ring_ref.is_none() || port.is_none() { + return Err(anyhow!("frontend did not give ring-ref and port")); + } + + let Ok(ring_ref) = ring_ref.unwrap().parse::() else { + return Err(anyhow!("frontend gave invalid ring-ref")); + }; + + let Ok(port) = port.unwrap().parse::() else { + return Err(anyhow!("frontend gave invalid port")); + }; + + break (ring_ref, port); + } + } + + None => { + return Ok(()); + } + } + }; + + self.store + .write_string(format!("{}/state", self.backend), "4") + .await?; + let memory = self.gnttab.map_grant_refs( vec![GrantRef { domid: self.domid, - reference: self.ring_ref as u32, + reference: ring_ref as u32, }], true, true, )?; - let interface = memory.ptr() as *mut XenConsoleInterface; - let mut channel = self.evtchn.bind(self.domid, self.port).await?; - unsafe { self.read_buffer(channel.local_port, interface).await? }; + let mut channel = self.evtchn.bind(self.domid, port).await?; + unsafe { + let buffer = self.read_output_buffer(channel.local_port, &memory).await?; + if !buffer.is_empty() { + sender.send(buffer).await?; + } + }; + + let mut self_state_change = self.store.create_watch().await?; + self.store + .bind_watch(&self_state_change, format!("{}/state", self.backend)) + .await?; loop { - channel.receiver.recv().await?; - unsafe { self.read_buffer(channel.local_port, interface).await? }; - channel.unmask_sender.send(channel.local_port).await?; + select! { + x = self_state_change.receiver.recv() => match x { + Some(_) => { + match self.on_self_state_change().await { + Err(error) => { + error!("failed to process state change for domain {} channel {}: {}", self.domid, self.id, error); + }, + + Ok(stop) => { + if stop { + break; + } + } + } + }, + + None => { + break; + } + }, + + x = receiver.recv() => match x { + Some(data) => { + let mut index = 0; + loop { + if index >= data.len() { + break; + } + let interface = memory.ptr() as *mut XenConsoleInterface; + let cons = unsafe { (*interface).in_cons }; + let mut prod = unsafe { (*interface).in_prod }; + fence(Ordering::Release); + let space = (prod - cons) as usize; + if space > XenConsoleInterface::INPUT_SIZE { + error!("channel for domid {} has an invalid input space of {}", self.domid, space); + } + let free = XenConsoleInterface::INPUT_SIZE - space; + let want = data.len().min(free); + let buffer = &data[index..want]; + for b in buffer { + unsafe { (*interface).input[prod as usize & (XenConsoleInterface::INPUT_SIZE - 1)] = *b; }; + prod += 1; + } + fence(Ordering::Release); + unsafe { (*interface).in_prod = prod; }; + self.evtchn.notify(channel.local_port).await?; + index += want; + } + }, + + None => { + break; + } + }, + + x = channel.receiver.recv() => match x { + Ok(_) => { + unsafe { + let buffer = self.read_output_buffer(channel.local_port, &memory).await?; + if !buffer.is_empty() { + sender.send(buffer).await?; + } + }; + channel.unmask_sender.send(channel.local_port).await?; + }, + + Err(error) => { + match error { + broadcast::error::RecvError::Closed => { + break; + }, + error => { + return Err(anyhow!("failed to receive event notification: {}", error)); + } + } + } + } + }; } + Ok(()) } - async unsafe fn read_buffer( + async unsafe fn read_output_buffer<'a>( &self, local_port: u32, - interface: *mut XenConsoleInterface, - ) -> Result<()> { + memory: &MappedMemory<'a>, + ) -> Result> { + let interface = memory.ptr() as *mut XenConsoleInterface; let mut cons = (*interface).out_cons; let prod = (*interface).out_prod; + fence(Ordering::Release); let size = prod - cons; - if size == 0 || size > 2048 { - return Ok(()); - } let mut data: Vec = Vec::new(); + if size == 0 || size as usize > XenConsoleInterface::OUTPUT_SIZE { + return Ok(data); + } loop { if cons == prod { break; } - data.push((*interface).output[cons as usize]); + data.push((*interface).output[cons as usize & (XenConsoleInterface::OUTPUT_SIZE - 1)]); cons += 1; } + fence(Ordering::AcqRel); (*interface).out_cons = cons; self.evtchn.notify(local_port).await?; - Ok(()) + Ok(data) } } diff --git a/crates/kratart/src/launch/mod.rs b/crates/kratart/src/launch.rs similarity index 97% rename from crates/kratart/src/launch/mod.rs rename to crates/kratart/src/launch.rs index 7945e4f..ef406e0 100644 --- a/crates/kratart/src/launch/mod.rs +++ b/crates/kratart/src/launch.rs @@ -9,7 +9,7 @@ use krata::launchcfg::{ LaunchInfo, LaunchNetwork, LaunchNetworkIpv4, LaunchNetworkIpv6, LaunchNetworkResolver, }; use uuid::Uuid; -use xenclient::{DomainChannel, DomainConfig, DomainDisk, DomainNetworkInterface}; +use xenclient::{DomainConfig, DomainDisk, DomainNetworkInterface}; use xenstore::XsdInterface; use crate::cfgblk::ConfigBlock; @@ -180,10 +180,11 @@ impl GuestLauncher { writable: false, }, ], - channels: vec![DomainChannel { - typ: "krata-channel".to_string(), - initialized: true, - }], + // channels: vec![DomainChannel { + // typ: "krata-channel".to_string(), + // initialized: false, + // }], + channels: vec![], vifs: vec![DomainNetworkInterface { mac: &guest_mac_string, mtu: 1500, diff --git a/crates/xen/xenclient/src/lib.rs b/crates/xen/xenclient/src/lib.rs index 72ac4a4..c1b8535 100644 --- a/crates/xen/xenclient/src/lib.rs +++ b/crates/xen/xenclient/src/lib.rs @@ -275,7 +275,7 @@ impl XenClient { initrd.as_slice(), config.max_vcpus, config.mem_mb, - 1 + config.channels.len(), + 1, )?; boot.boot(&mut arch, &mut state, config.cmdline)?; xenstore_evtchn = state.store_evtchn; diff --git a/crates/xen/xengnt/src/error.rs b/crates/xen/xengnt/src/error.rs index d4932a9..8e6ea8b 100644 --- a/crates/xen/xengnt/src/error.rs +++ b/crates/xen/xengnt/src/error.rs @@ -8,8 +8,8 @@ pub enum Error { Io(#[from] io::Error), #[error("failed to read structure")] StructureReadFailed, - #[error("mmap failed")] - MmapFailed, + #[error("mmap failed: {0}")] + MmapFailed(nix::errno::Errno), } pub type Result = std::result::Result; diff --git a/crates/xen/xengnt/src/lib.rs b/crates/xen/xengnt/src/lib.rs index 5ae66b5..0c35d7c 100644 --- a/crates/xen/xengnt/src/lib.rs +++ b/crates/xen/xengnt/src/lib.rs @@ -2,10 +2,14 @@ pub mod error; pub mod sys; use error::{Error, Result}; +use nix::errno::Errno; use std::{ fs::{File, OpenOptions}, + marker::PhantomData, os::{fd::AsRawFd, raw::c_void}, sync::Arc, + thread::sleep, + time::Duration, }; use sys::{ AllocGref, DeallocGref, GetOffsetForVaddr, GrantRef, MapGrantRef, SetMaxGrants, UnmapGrantRef, @@ -151,24 +155,28 @@ pub struct GrantTab { const PAGE_SIZE: usize = 4096; #[allow(clippy::len_without_is_empty)] -pub struct MappedMemory { +pub struct MappedMemory<'a> { + gnttab: GrantTab, length: usize, - addr: *mut c_void, + addr: u64, + _ptr: PhantomData<&'a c_void>, } -impl MappedMemory { +unsafe impl Send for MappedMemory<'_> {} + +impl MappedMemory<'_> { pub fn len(&self) -> usize { self.length } pub fn ptr(&self) -> *mut c_void { - self.addr + self.addr as *mut c_void } } -impl Drop for MappedMemory { +impl Drop for MappedMemory<'_> { fn drop(&mut self) { - let _ = unsafe { munmap(self.addr, self.length) }; + let _ = self.gnttab.unmap(self); } } @@ -179,12 +187,12 @@ impl GrantTab { }) } - pub fn map_grant_refs( + pub fn map_grant_refs<'a>( &self, refs: Vec, read: bool, write: bool, - ) -> Result { + ) -> Result> { let (index, refs) = self.device.map_grant_ref(refs)?; unsafe { let mut flags: i32 = 0; @@ -196,21 +204,39 @@ impl GrantTab { flags |= PROT_WRITE; } - let addr = mmap( - std::ptr::null_mut(), - PAGE_SIZE * refs.len(), - flags, - MAP_SHARED, - self.device.handle.as_raw_fd(), - index as i64, - ); - if addr == MAP_FAILED { - return Err(Error::MmapFailed); - } + let addr = loop { + let addr = mmap( + std::ptr::null_mut(), + PAGE_SIZE * refs.len(), + flags, + MAP_SHARED, + self.device.handle.as_raw_fd(), + index as i64, + ); + let errno = Errno::last(); + if addr == MAP_FAILED { + if errno == Errno::EAGAIN { + sleep(Duration::from_micros(1000)); + continue; + } + return Err(Error::MmapFailed(errno)); + } + break addr; + }; + Ok(MappedMemory { - addr, + gnttab: self.clone(), + addr: addr as u64, length: PAGE_SIZE * refs.len(), + _ptr: PhantomData, }) } } + + fn unmap(&self, memory: &MappedMemory<'_>) -> Result<()> { + let (offset, count) = self.device.get_offset_for_vaddr(memory.addr)?; + let _ = unsafe { munmap(memory.addr as *mut c_void, memory.length) }; + self.device.unmap_grant_ref(offset, count)?; + Ok(()) + } } diff --git a/crates/xen/xengnt/src/sys.rs b/crates/xen/xengnt/src/sys.rs index 398b648..b69c50c 100644 --- a/crates/xen/xengnt/src/sys.rs +++ b/crates/xen/xengnt/src/sys.rs @@ -3,7 +3,7 @@ use std::mem::size_of; use nix::{ioc, ioctl_readwrite_bad}; #[repr(C)] -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct GrantRef { pub domid: u32, pub reference: u32, @@ -33,7 +33,7 @@ impl MapGrantRef { return None; } - let index = (*data.get(2)? as u64) << 32 | *data.get(3)? as u64; + let index = (*data.get(2)? as u64) | (*data.get(3)? as u64) << 32; for i in (4..data.len()).step_by(2) { let Some(domid) = data.get(i) else { break; @@ -146,10 +146,10 @@ impl AllocGref { return None; } - let index = (*data.get(4)? as u64) << 48 - | (*data.get(5)? as u64) << 32 - | (*data.get(6)? as u64) << 16 - | *data.get(7)? as u64; + let index = (*data.get(4)? as u64) + | (*data.get(5)? as u64) << 16 + | (*data.get(6)? as u64) << 32 + | (*data.get(7)? as u64) << 48; for i in (8..data.len()).step_by(2) { let Some(bits_low) = data.get(i) else { break; @@ -157,7 +157,7 @@ impl AllocGref { let Some(bits_high) = data.get(i + 1) else { break; }; - refs.push((*bits_low as u32) << 16 | *bits_high as u32); + refs.push((*bits_low as u32) | (*bits_high as u32) << 16); } Some((index, refs)) } diff --git a/crates/xen/xenstore/src/bus.rs b/crates/xen/xenstore/src/bus.rs index e6de48f..1cc9fa0 100644 --- a/crates/xen/xenstore/src/bus.rs +++ b/crates/xen/xenstore/src/bus.rs @@ -1,10 +1,20 @@ -use std::{collections::HashMap, ffi::CString, io::ErrorKind, os::fd::AsRawFd, sync::Arc}; +use std::{ + collections::HashMap, + ffi::CString, + io::ErrorKind, + os::{ + fd::{AsRawFd, FromRawFd, IntoRawFd}, + unix::fs::FileTypeExt, + }, + sync::Arc, +}; use libc::{fcntl, F_GETFL, F_SETFL, O_NONBLOCK}; use log::warn; use tokio::{ fs::{metadata, File}, io::{AsyncReadExt, AsyncWriteExt}, + net::UnixStream, select, sync::{ mpsc::{channel, Receiver, Sender}, @@ -19,14 +29,16 @@ use crate::{ sys::{XsdMessageHeader, XSD_ERROR, XSD_UNWATCH, XSD_WATCH_EVENT}, }; -const XEN_BUS_PATHS: &[&str] = &["/dev/xen/xenbus"]; +const XEN_BUS_PATHS: &[&str] = &["/var/run/xenstored/socket", "/dev/xen/xenbus"]; const XEN_BUS_MAX_PAYLOAD_SIZE: usize = 4096; const XEN_BUS_MAX_PACKET_SIZE: usize = XsdMessageHeader::SIZE + XEN_BUS_MAX_PAYLOAD_SIZE; -async fn find_bus_path() -> Option<&'static str> { +async fn find_bus_path() -> Option<(&'static str, bool)> { for path in XEN_BUS_PATHS { match metadata(path).await { - Ok(_) => return Some(path), + Ok(metadata) => { + return Some((path, metadata.file_type().is_socket())); + } Err(_) => continue, } } @@ -58,17 +70,25 @@ pub struct XsdSocket { impl XsdSocket { pub async fn open() -> Result { - let path = match find_bus_path().await { + let (path, socket) = match find_bus_path().await { Some(path) => path, None => return Err(Error::BusNotFound), }; - let file = File::options() - .read(true) - .write(true) - .custom_flags(O_NONBLOCK) - .open(path) - .await?; + let file = if socket { + let stream = UnixStream::connect(path).await?; + let stream = stream.into_std()?; + stream.set_nonblocking(true)?; + unsafe { File::from_raw_fd(stream.into_raw_fd()) } + } else { + File::options() + .read(true) + .write(true) + .custom_flags(O_NONBLOCK) + .open(path) + .await? + }; + XsdSocket::from_handle(file).await } diff --git a/crates/xen/xenstore/src/lib.rs b/crates/xen/xenstore/src/lib.rs index fd9992a..badb1c0 100644 --- a/crates/xen/xenstore/src/lib.rs +++ b/crates/xen/xenstore/src/lib.rs @@ -43,7 +43,7 @@ impl XsPermission { } pub struct XsdWatchHandle { - id: u32, + pub id: u32, unwatch_sender: Sender, pub receiver: Receiver, } @@ -202,7 +202,11 @@ impl XsdClient { } pub async fn bind_watch>(&self, handle: &XsdWatchHandle, path: P) -> Result<()> { - let id_string = handle.id.to_string(); + self.bind_watch_id(handle.id, path).await + } + + pub async fn bind_watch_id>(&self, id: u32, path: P) -> Result<()> { + let id_string = id.to_string(); let _ = self .socket .send(0, XSD_WATCH, &[path.as_ref(), &id_string])