diff --git a/Cargo.lock b/Cargo.lock index 65b5d85..4104ae3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1228,6 +1228,7 @@ name = "krata" version = "0.0.8" dependencies = [ "anyhow", + "async-trait", "bytes", "libc", "log", diff --git a/crates/daemon/src/event.rs b/crates/daemon/src/event.rs index 903db8b..76257d1 100644 --- a/crates/daemon/src/event.rs +++ b/crates/daemon/src/event.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::Result; use krata::{ - idm::protocol::{idm_event::Event, IdmPacket}, + idm::protocol::{idm_event::Event, idm_packet::Content, IdmPacket}, v1::common::{GuestExitInfo, GuestState, GuestStatus}, }; use log::error; @@ -117,10 +117,14 @@ impl DaemonEventGenerator { } async fn handle_idm_packet(&mut self, id: Uuid, packet: IdmPacket) -> Result<()> { - if let Some(Event::Exit(exit)) = packet.event.and_then(|x| x.event) { - self.handle_exit_code(id, exit.code).await?; + match packet.content { + Some(Content::Event(event)) => match event.event { + Some(Event::Exit(exit)) => self.handle_exit_code(id, exit.code).await, + None => Ok(()), + }, + + _ => Ok(()), } - Ok(()) } async fn handle_exit_code(&mut self, id: Uuid, code: i32) -> Result<()> { diff --git a/crates/guest/src/background.rs b/crates/guest/src/background.rs index 264bd9e..9bd0faf 100644 --- a/crates/guest/src/background.rs +++ b/crates/guest/src/background.rs @@ -6,7 +6,7 @@ use anyhow::Result; use cgroups_rs::Cgroup; use krata::idm::{ client::IdmClient, - protocol::{idm_event::Event, IdmEvent, IdmExitEvent, IdmPacket}, + protocol::{idm_event::Event, idm_packet::Content, IdmEvent, IdmExitEvent, IdmPacket}, }; use log::debug; use nix::unistd::Pid; @@ -59,9 +59,9 @@ impl GuestBackground { self.idm .sender .send(IdmPacket { - event: Some(IdmEvent { + content: Some(Content::Event(IdmEvent { event: Some(Event::Exit(IdmExitEvent { code: event.status })), - }), + })), }) .await?; death(event.status).await?; diff --git a/crates/krata/Cargo.toml b/crates/krata/Cargo.toml index a845dcb..e0f46bc 100644 --- a/crates/krata/Cargo.toml +++ b/crates/krata/Cargo.toml @@ -10,6 +10,7 @@ resolver = "2" [dependencies] anyhow = { workspace = true } +async-trait = { workspace = true } bytes = { workspace = true } libc = { workspace = true } log = { workspace = true } diff --git a/crates/krata/proto/krata/internal/idm.proto b/crates/krata/proto/krata/internal/idm.proto index 015fe5a..0ec96de 100644 --- a/crates/krata/proto/krata/internal/idm.proto +++ b/crates/krata/proto/krata/internal/idm.proto @@ -6,8 +6,12 @@ option java_multiple_files = true; option java_package = "dev.krata.proto.internal.idm"; option java_outer_classname = "IdmProto"; -message IdmExitEvent { - int32 code = 1; +message IdmPacket { + oneof content { + IdmEvent event = 1; + IdmRequest request = 2; + IdmResponse response = 3; + } } message IdmEvent { @@ -16,6 +20,24 @@ message IdmEvent { } } -message IdmPacket { - IdmEvent event = 1; +message IdmExitEvent { + int32 code = 1; } + +message IdmRequest { + uint64 id = 1; + oneof request { + IdmPingRequest ping = 2; + } +} + +message IdmPingRequest {} + +message IdmResponse { + uint64 id = 1; + oneof response { + IdmPingResponse ping = 2; + } +} + +message IdmPingResponse {} diff --git a/crates/krata/src/idm/client.rs b/crates/krata/src/idm/client.rs index a3c7096..4be6335 100644 --- a/crates/krata/src/idm/client.rs +++ b/crates/krata/src/idm/client.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{path::Path, sync::Arc}; use super::protocol::IdmPacket; use anyhow::{anyhow, Result}; @@ -10,12 +10,68 @@ use tokio::{ fs::File, io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt}, select, - sync::mpsc::{channel, Receiver, Sender}, + sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, + }, task::JoinHandle, }; const IDM_PACKET_QUEUE_LEN: usize = 100; +#[async_trait::async_trait] +pub trait IdmBackend: Send { + async fn recv(&mut self) -> Result; + async fn send(&mut self, packet: IdmPacket) -> Result<()>; +} + +pub struct IdmFileBackend { + fd: Arc>>, +} + +impl IdmFileBackend { + pub async fn new(file: File) -> Result { + IdmFileBackend::set_raw_port(&file)?; + Ok(IdmFileBackend { + fd: Arc::new(Mutex::new(AsyncFd::new(file)?)), + }) + } + + fn set_raw_port(file: &File) -> Result<()> { + let mut termios = tcgetattr(file)?; + cfmakeraw(&mut termios); + tcsetattr(file, SetArg::TCSANOW, &termios)?; + Ok(()) + } +} + +#[async_trait::async_trait] +impl IdmBackend for IdmFileBackend { + async fn recv(&mut self) -> Result { + let mut fd = self.fd.lock().await; + let mut guard = fd.readable_mut().await?; + let size = guard.get_inner_mut().read_u16_le().await?; + if size == 0 { + return Ok(IdmPacket::default()); + } + let mut buffer = BytesMut::with_capacity(size as usize); + guard.get_inner_mut().read_exact(&mut buffer).await?; + match IdmPacket::decode(buffer) { + Ok(packet) => Ok(packet), + + Err(error) => Err(anyhow!("received invalid idm packet: {}", error)), + } + } + + async fn send(&mut self, packet: IdmPacket) -> Result<()> { + let mut fd = self.fd.lock().await; + let data = packet.encode_to_vec(); + fd.get_mut().write_u16_le(data.len() as u16).await?; + fd.get_mut().write_all(&data).await?; + Ok(()) + } +} + pub struct IdmClient { pub receiver: Receiver, pub sender: Sender, @@ -29,18 +85,11 @@ impl Drop for IdmClient { } impl IdmClient { - pub async fn open>(path: P) -> Result { - let file = File::options() - .read(true) - .write(true) - .create(false) - .open(path) - .await?; - IdmClient::set_raw_port(&file)?; + pub async fn new<'a>(backend: Box) -> Result { let (rx_sender, rx_receiver) = channel(IDM_PACKET_QUEUE_LEN); let (tx_sender, tx_receiver) = channel(IDM_PACKET_QUEUE_LEN); let task = tokio::task::spawn(async move { - if let Err(error) = IdmClient::process(file, rx_sender, tx_receiver).await { + if let Err(error) = IdmClient::process(backend, rx_sender, tx_receiver).await { debug!("failed to handle idm client processing: {}", error); } }); @@ -51,38 +100,27 @@ impl IdmClient { }) } - fn set_raw_port(file: &File) -> Result<()> { - let mut termios = tcgetattr(file)?; - cfmakeraw(&mut termios); - tcsetattr(file, SetArg::TCSANOW, &termios)?; - Ok(()) + pub async fn open>(path: P) -> Result { + let file = File::options() + .read(true) + .write(true) + .create(false) + .open(path) + .await?; + let backend = IdmFileBackend::new(file).await?; + IdmClient::new(Box::new(backend) as Box).await } async fn process( - file: File, + mut backend: Box, sender: Sender, mut receiver: Receiver, ) -> Result<()> { - let mut file = AsyncFd::new(file)?; loop { select! { - x = file.readable_mut() => match x { - Ok(mut guard) => { - let size = guard.get_inner_mut().read_u16_le().await?; - if size == 0 { - continue; - } - let mut buffer = BytesMut::with_capacity(size as usize); - guard.get_inner_mut().read_exact(&mut buffer).await?; - match IdmPacket::decode(buffer) { - Ok(packet) => { - sender.send(packet).await?; - }, - - Err(error) => { - error!("received invalid idm packet: {}", error); - } - } + x = backend.recv() => match x { + Ok(packet) => { + sender.send(packet).await?; }, Err(error) => { @@ -91,13 +129,12 @@ impl IdmClient { }, x = receiver.recv() => match x { Some(packet) => { - let data = packet.encode_to_vec(); - if data.len() > u16::MAX as usize { - error!("unable to send idm packet, packet size exceeded (tried to send {} bytes)", data.len()); + let length = packet.encoded_len(); + if length > u16::MAX as usize { + error!("unable to send idm packet, packet size exceeded (tried to send {} bytes)", length); continue; } - file.get_mut().write_u16_le(data.len() as u16).await?; - file.get_mut().write_all(&data).await?; + backend.send(packet).await?; }, None => {