From 183b4ddf732306a99150b8c818f8ca2d3ccda388 Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Wed, 10 Apr 2024 11:44:10 +0000 Subject: [PATCH] feat: implement request response idm system --- crates/guest/src/background.rs | 31 +++++++++++- crates/krata/src/idm/client.rs | 86 ++++++++++++++++++++++++++++++++-- 2 files changed, 111 insertions(+), 6 deletions(-) diff --git a/crates/guest/src/background.rs b/crates/guest/src/background.rs index 1195e65..a8fb732 100644 --- a/crates/guest/src/background.rs +++ b/crates/guest/src/background.rs @@ -6,7 +6,10 @@ use anyhow::Result; use cgroups_rs::Cgroup; use krata::idm::{ client::IdmClient, - protocol::{idm_event::Event, IdmEvent, IdmExitEvent}, + protocol::{ + idm_event::Event, idm_request::Request, idm_response::Response, IdmEvent, IdmExitEvent, + IdmPingResponse, IdmRequest, + }, }; use log::debug; use nix::unistd::Pid; @@ -31,6 +34,7 @@ impl GuestBackground { pub async fn run(&mut self) -> Result<()> { let mut event_subscription = self.idm.subscribe().await?; + let mut requests_subscription = self.idm.requests().await?; loop { select! { x = event_subscription.recv() => match x { @@ -48,6 +52,21 @@ impl GuestBackground { } }, + x = requests_subscription.recv() => match x { + Ok(request) => { + self.handle_idm_request(request).await?; + }, + + Err(broadcast::error::RecvError::Closed) => { + debug!("idm packet channel closed"); + break; + }, + + _ => { + continue; + } + }, + event = self.wait.recv() => match event { Some(event) => self.child_event(event).await?, None => { @@ -59,6 +78,16 @@ impl GuestBackground { Ok(()) } + async fn handle_idm_request(&mut self, packet: IdmRequest) -> Result<()> { + let id = packet.id; + if let Some(Request::Ping(_)) = packet.request { + self.idm + .respond(id, Response::Ping(IdmPingResponse {})) + .await?; + } + Ok(()) + } + async fn child_event(&mut self, event: ChildEvent) -> Result<()> { if event.pid == self.child { self.idm diff --git a/crates/krata/src/idm/client.rs b/crates/krata/src/idm/client.rs index df19dc9..2bb5482 100644 --- a/crates/krata/src/idm/client.rs +++ b/crates/krata/src/idm/client.rs @@ -1,8 +1,10 @@ -use std::{path::Path, sync::Arc}; +use std::{collections::HashMap, path::Path, sync::Arc}; use crate::idm::protocol::idm_packet::Content; -use super::protocol::{IdmEvent, IdmPacket}; +use super::protocol::{ + idm_request::Request, idm_response::Response, IdmEvent, IdmPacket, IdmRequest, IdmResponse, +}; use anyhow::{anyhow, Result}; use bytes::BytesMut; use log::{debug, error}; @@ -15,11 +17,13 @@ use tokio::{ sync::{ broadcast, mpsc::{channel, Receiver, Sender}, - Mutex, + oneshot, Mutex, }, task::JoinHandle, }; +type RequestMap = Arc>>>; + const IDM_PACKET_QUEUE_LEN: usize = 100; #[async_trait::async_trait] @@ -77,8 +81,11 @@ impl IdmBackend for IdmFileBackend { #[derive(Clone)] pub struct IdmClient { + request_backend_sender: broadcast::Sender, + next_request_id: Arc>, event_receiver_sender: broadcast::Sender, tx_sender: Sender, + requests: RequestMap, task: Arc>, } @@ -92,18 +99,31 @@ impl Drop for IdmClient { impl IdmClient { pub async fn new(backend: Box) -> Result { + let requests = Arc::new(Mutex::new(HashMap::new())); let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN); + let (internal_request_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN); let (tx_sender, tx_receiver) = channel(IDM_PACKET_QUEUE_LEN); let backend_event_sender = event_sender.clone(); + let request_backend_sender = internal_request_backend_sender.clone(); let task = tokio::task::spawn(async move { - if let Err(error) = - IdmClient::process(backend, backend_event_sender, event_receiver, tx_receiver).await + if let Err(error) = IdmClient::process( + backend, + backend_event_sender, + requests, + internal_request_backend_sender, + event_receiver, + tx_receiver, + ) + .await { debug!("failed to handle idm client processing: {}", error); } }); Ok(IdmClient { + next_request_id: Arc::new(Mutex::new(0)), event_receiver_sender: event_sender.clone(), + request_backend_sender, + requests: Arc::new(Mutex::new(HashMap::new())), tx_sender, task: Arc::new(task), }) @@ -129,13 +149,57 @@ impl IdmClient { Ok(()) } + pub async fn requests(&self) -> Result> { + Ok(self.request_backend_sender.subscribe()) + } + + pub async fn respond(&self, id: u64, response: Response) -> Result<()> { + let packet = IdmPacket { + content: Some(Content::Response(IdmResponse { + id, + response: Some(response), + })), + }; + self.tx_sender.send(packet).await?; + Ok(()) + } + pub async fn subscribe(&self) -> Result> { Ok(self.event_receiver_sender.subscribe()) } + pub async fn send(&self, request: Request) -> Result { + let (sender, receiver) = oneshot::channel(); + let mut requests = self.requests.lock().await; + let req = { + let mut guard = self.next_request_id.lock().await; + let req = *guard; + *guard = req.wrapping_add(1); + req + }; + requests.insert(req, sender); + drop(requests); + self.tx_sender + .send(IdmPacket { + content: Some(Content::Request(IdmRequest { + id: req, + request: Some(request), + })), + }) + .await?; + + if let Some(response) = receiver.await?.response { + Ok(response) + } else { + Err(anyhow!("response did not contain any content")) + } + } + async fn process( mut backend: Box, event_sender: broadcast::Sender, + requests: RequestMap, + request_backend_sender: broadcast::Sender, _event_receiver: broadcast::Receiver, mut receiver: Receiver, ) -> Result<()> { @@ -148,6 +212,18 @@ impl IdmClient { let _ = event_sender.send(event); }, + Some(Content::Request(request)) => { + let _ = request_backend_sender.send(request); + }, + + Some(Content::Response(response)) => { + let mut requests = requests.lock().await; + if let Some(sender) = requests.remove(&response.id) { + drop(requests); + let _ = sender.send(response); + } + }, + _ => {}, } },