diff --git a/crates/ctl/src/cli/exec.rs b/crates/ctl/src/cli/exec.rs new file mode 100644 index 0000000..8b56223 --- /dev/null +++ b/crates/ctl/src/cli/exec.rs @@ -0,0 +1,70 @@ +use std::collections::HashMap; + +use anyhow::Result; + +use clap::Parser; +use krata::v1::{ + common::{GuestTaskSpec, GuestTaskSpecEnvVar}, + control::{control_service_client::ControlServiceClient, ExecGuestRequest}, +}; + +use tonic::{transport::Channel, Request}; + +use crate::console::StdioConsoleStream; + +use super::resolve_guest; + +#[derive(Parser)] +#[command(about = "Execute a command inside the guest")] +pub struct ExecCommand { + #[arg[short, long, help = "Environment variables"]] + env: Option>, + #[arg(short = 'w', long, help = "Working directory")] + working_directory: Option, + #[arg(help = "Guest to exec inside, either the name or the uuid")] + guest: String, + #[arg( + allow_hyphen_values = true, + trailing_var_arg = true, + help = "Command to run inside the guest" + )] + command: Vec, +} + +impl ExecCommand { + pub async fn run(self, mut client: ControlServiceClient) -> Result<()> { + let guest_id: String = resolve_guest(&mut client, &self.guest).await?; + let initial = ExecGuestRequest { + guest_id, + task: Some(GuestTaskSpec { + environment: env_map(&self.env.unwrap_or_default()) + .iter() + .map(|(key, value)| GuestTaskSpecEnvVar { + key: key.clone(), + value: value.clone(), + }) + .collect(), + command: self.command, + working_directory: self.working_directory.unwrap_or_default(), + }), + data: vec![], + }; + + let stream = StdioConsoleStream::stdin_stream_exec(initial).await; + + let response = client.exec_guest(Request::new(stream)).await?.into_inner(); + + let code = StdioConsoleStream::exec_output(response).await?; + std::process::exit(code); + } +} + +fn env_map(env: &[String]) -> HashMap { + let mut map = HashMap::::new(); + for item in env { + if let Some((key, value)) = item.split_once('=') { + map.insert(key.to_string(), value.to_string()); + } + } + map +} diff --git a/crates/ctl/src/cli/idm_snoop.rs b/crates/ctl/src/cli/idm_snoop.rs index 228e350..b66c7c2 100644 --- a/crates/ctl/src/cli/idm_snoop.rs +++ b/crates/ctl/src/cli/idm_snoop.rs @@ -106,13 +106,19 @@ pub fn convert_idm_snoop(reply: SnoopIdmReply) -> Option { .ok() .and_then(|event| proto2dynamic(event).ok()), - IdmTransportPacketForm::Request => internal::Request::decode(&packet.data) - .ok() - .and_then(|event| proto2dynamic(event).ok()), + IdmTransportPacketForm::Request + | IdmTransportPacketForm::StreamRequest + | IdmTransportPacketForm::StreamRequestUpdate => { + internal::Request::decode(&packet.data) + .ok() + .and_then(|event| proto2dynamic(event).ok()) + } - IdmTransportPacketForm::Response => internal::Response::decode(&packet.data) - .ok() - .and_then(|event| proto2dynamic(event).ok()), + IdmTransportPacketForm::Response | IdmTransportPacketForm::StreamResponseUpdate => { + internal::Response::decode(&packet.data) + .ok() + .and_then(|event| proto2dynamic(event).ok()) + } _ => None, } @@ -132,6 +138,11 @@ pub fn convert_idm_snoop(reply: SnoopIdmReply) -> Option { IdmTransportPacketForm::Event => "event".to_string(), IdmTransportPacketForm::Request => "request".to_string(), IdmTransportPacketForm::Response => "response".to_string(), + IdmTransportPacketForm::StreamRequest => "stream-request".to_string(), + IdmTransportPacketForm::StreamRequestUpdate => "stream-request-update".to_string(), + IdmTransportPacketForm::StreamRequestClosed => "stream-request-closed".to_string(), + IdmTransportPacketForm::StreamResponseUpdate => "stream-response-update".to_string(), + IdmTransportPacketForm::StreamResponseClosed => "stream-response-closed".to_string(), _ => format!("unknown-{}", packet.form), }, data: base64::prelude::BASE64_STANDARD.encode(&packet.data), diff --git a/crates/ctl/src/cli/launch.rs b/crates/ctl/src/cli/launch.rs index 01bfd8f..aa52141 100644 --- a/crates/ctl/src/cli/launch.rs +++ b/crates/ctl/src/cli/launch.rs @@ -29,7 +29,7 @@ pub enum LaunchImageFormat { #[derive(Parser)] #[command(about = "Launch a new guest")] -pub struct LauchCommand { +pub struct LaunchCommand { #[arg(long, default_value = "squashfs", help = "Image format")] image_format: LaunchImageFormat, #[arg(long, help = "Overwrite image cache on pull")] @@ -68,6 +68,8 @@ pub struct LauchCommand { kernel: Option, #[arg(short = 'I', long, help = "OCI initrd image for guest to use")] initrd: Option, + #[arg(short = 'w', long, help = "Working directory")] + working_directory: Option, #[arg(help = "Container image for guest to use")] oci: String, #[arg( @@ -78,7 +80,7 @@ pub struct LauchCommand { command: Vec, } -impl LauchCommand { +impl LaunchCommand { pub async fn run( self, mut client: ControlServiceClient, @@ -130,6 +132,7 @@ impl LauchCommand { }) .collect(), command: self.command, + working_directory: self.working_directory.unwrap_or_default(), }), annotations: vec![], }), diff --git a/crates/ctl/src/cli/mod.rs b/crates/ctl/src/cli/mod.rs index 4653dd3..cc549cf 100644 --- a/crates/ctl/src/cli/mod.rs +++ b/crates/ctl/src/cli/mod.rs @@ -1,5 +1,6 @@ pub mod attach; pub mod destroy; +pub mod exec; pub mod identify_host; pub mod idm_snoop; pub mod launch; @@ -21,10 +22,10 @@ use krata::{ use tonic::{transport::Channel, Request}; use self::{ - attach::AttachCommand, destroy::DestroyCommand, identify_host::IdentifyHostCommand, - idm_snoop::IdmSnoopCommand, launch::LauchCommand, list::ListCommand, logs::LogsCommand, - metrics::MetricsCommand, pull::PullCommand, resolve::ResolveCommand, top::TopCommand, - watch::WatchCommand, + attach::AttachCommand, destroy::DestroyCommand, exec::ExecCommand, + identify_host::IdentifyHostCommand, idm_snoop::IdmSnoopCommand, launch::LaunchCommand, + list::ListCommand, logs::LogsCommand, metrics::MetricsCommand, pull::PullCommand, + resolve::ResolveCommand, top::TopCommand, watch::WatchCommand, }; #[derive(Parser)] @@ -47,7 +48,7 @@ pub struct ControlCommand { #[derive(Subcommand)] pub enum Commands { - Launch(LauchCommand), + Launch(LaunchCommand), Destroy(DestroyCommand), List(ListCommand), Attach(AttachCommand), @@ -59,6 +60,7 @@ pub enum Commands { IdmSnoop(IdmSnoopCommand), Top(TopCommand), IdentifyHost(IdentifyHostCommand), + Exec(ExecCommand), } impl ControlCommand { @@ -114,6 +116,10 @@ impl ControlCommand { Commands::IdentifyHost(identify) => { identify.run(client).await?; } + + Commands::Exec(exec) => { + exec.run(client).await?; + } } Ok(()) } diff --git a/crates/ctl/src/console.rs b/crates/ctl/src/console.rs index a57edd8..57efb67 100644 --- a/crates/ctl/src/console.rs +++ b/crates/ctl/src/console.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use async_stream::stream; use crossterm::{ terminal::{disable_raw_mode, enable_raw_mode, is_raw_mode_enabled}, @@ -8,12 +8,15 @@ use krata::{ events::EventStream, v1::{ common::GuestStatus, - control::{watch_events_reply::Event, ConsoleDataReply, ConsoleDataRequest}, + control::{ + watch_events_reply::Event, ConsoleDataReply, ConsoleDataRequest, ExecGuestReply, + ExecGuestRequest, + }, }, }; use log::debug; use tokio::{ - io::{stdin, stdout, AsyncReadExt, AsyncWriteExt}, + io::{stderr, stdin, stdout, AsyncReadExt, AsyncWriteExt}, task::JoinHandle, }; use tokio_stream::{Stream, StreamExt}; @@ -45,6 +48,31 @@ impl StdioConsoleStream { } } + pub async fn stdin_stream_exec( + initial: ExecGuestRequest, + ) -> impl Stream { + let mut stdin = stdin(); + stream! { + yield initial; + + let mut buffer = vec![0u8; 60]; + loop { + let size = match stdin.read(&mut buffer).await { + Ok(size) => size, + Err(error) => { + debug!("failed to read stdin: {}", error); + break; + } + }; + let data = buffer[0..size].to_vec(); + if size == 1 && buffer[0] == 0x1d { + break; + } + yield ExecGuestRequest { guest_id: String::default(), task: None, data }; + } + } + } + pub async fn stdout(mut stream: Streaming) -> Result<()> { if stdin().is_tty() { enable_raw_mode()?; @@ -62,6 +90,32 @@ impl StdioConsoleStream { Ok(()) } + pub async fn exec_output(mut stream: Streaming) -> Result { + let mut stdout = stdout(); + let mut stderr = stderr(); + while let Some(reply) = stream.next().await { + let reply = reply?; + if !reply.stdout.is_empty() { + stdout.write_all(&reply.stdout).await?; + stdout.flush().await?; + } + + if !reply.stderr.is_empty() { + stderr.write_all(&reply.stderr).await?; + stderr.flush().await?; + } + + if reply.exited { + if reply.error.is_empty() { + return Ok(reply.exit_code); + } else { + return Err(anyhow!("exec failed: {}", reply.error)); + } + } + } + Ok(-1) + } + pub async fn guest_exit_hook( id: String, events: EventStream, diff --git a/crates/daemon/src/control.rs b/crates/daemon/src/control.rs index f56bfda..486adfd 100644 --- a/crates/daemon/src/control.rs +++ b/crates/daemon/src/control.rs @@ -2,18 +2,19 @@ use async_stream::try_stream; use futures::Stream; use krata::{ idm::internal::{ - request::Request as IdmRequestType, response::Response as IdmResponseType, MetricsRequest, - Request as IdmRequest, + exec_stream_request_update::Update, request::Request as IdmRequestType, + response::Response as IdmResponseType, ExecEnvVar, ExecStreamRequestStart, + ExecStreamRequestStdin, ExecStreamRequestUpdate, MetricsRequest, Request as IdmRequest, }, v1::{ common::{Guest, GuestState, GuestStatus, OciImageFormat}, control::{ control_service_server::ControlService, ConsoleDataReply, ConsoleDataRequest, CreateGuestReply, CreateGuestRequest, DestroyGuestReply, DestroyGuestRequest, - IdentifyHostReply, IdentifyHostRequest, ListGuestsReply, ListGuestsRequest, - PullImageReply, PullImageRequest, ReadGuestMetricsReply, ReadGuestMetricsRequest, - ResolveGuestReply, ResolveGuestRequest, SnoopIdmReply, SnoopIdmRequest, - WatchEventsReply, WatchEventsRequest, + ExecGuestReply, ExecGuestRequest, IdentifyHostReply, IdentifyHostRequest, + ListGuestsReply, ListGuestsRequest, PullImageReply, PullImageRequest, + ReadGuestMetricsReply, ReadGuestMetricsRequest, ResolveGuestReply, ResolveGuestRequest, + SnoopIdmReply, SnoopIdmRequest, WatchEventsReply, WatchEventsRequest, }, }, }; @@ -101,6 +102,9 @@ enum PullImageSelect { #[tonic::async_trait] impl ControlService for DaemonControlService { + type ExecGuestStream = + Pin> + Send + 'static>>; + type ConsoleDataStream = Pin> + Send + 'static>>; @@ -166,6 +170,98 @@ impl ControlService for DaemonControlService { })) } + async fn exec_guest( + &self, + request: Request>, + ) -> Result, Status> { + let mut input = request.into_inner(); + let Some(request) = input.next().await else { + return Err(ApiError { + message: "expected to have at least one request".to_string(), + } + .into()); + }; + let request = request?; + + let Some(task) = request.task else { + return Err(ApiError { + message: "task is missing".to_string(), + } + .into()); + }; + + let uuid = Uuid::from_str(&request.guest_id).map_err(|error| ApiError { + message: error.to_string(), + })?; + let idm = self.idm.client(uuid).await.map_err(|error| ApiError { + message: error.to_string(), + })?; + + let idm_request = IdmRequest { + request: Some(IdmRequestType::ExecStream(ExecStreamRequestUpdate { + update: Some(Update::Start(ExecStreamRequestStart { + environment: task + .environment + .into_iter() + .map(|x| ExecEnvVar { + key: x.key, + value: x.value, + }) + .collect(), + command: task.command, + working_directory: task.working_directory, + })), + })), + }; + + let output = try_stream! { + let mut handle = idm.send_stream(idm_request).await.map_err(|x| ApiError { + message: x.to_string(), + })?; + + loop { + select! { + x = input.next() => if let Some(update) = x { + let update: Result = update.map_err(|error| ApiError { + message: error.to_string() + }.into()); + + if let Ok(update) = update { + if !update.data.is_empty() { + let _ = handle.update(IdmRequest { + request: Some(IdmRequestType::ExecStream(ExecStreamRequestUpdate { + update: Some(Update::Stdin(ExecStreamRequestStdin { + data: update.data, + })), + }))}).await; + } + } + }, + x = handle.receiver.recv() => match x { + Some(response) => { + let Some(IdmResponseType::ExecStream(update)) = response.response else { + break; + }; + let reply = ExecGuestReply { + exited: update.exited, + error: update.error, + exit_code: update.exit_code, + stdout: update.stdout, + stderr: update.stderr + }; + yield reply; + }, + None => { + break; + } + } + }; + } + }; + + Ok(Response::new(Box::pin(output) as Self::ExecGuestStream)) + } + async fn destroy_guest( &self, request: Request, diff --git a/crates/guest/src/background.rs b/crates/guest/src/background.rs index d981343..ddd21b5 100644 --- a/crates/guest/src/background.rs +++ b/crates/guest/src/background.rs @@ -1,16 +1,17 @@ use crate::{ childwait::{ChildEvent, ChildWait}, death, + exec::GuestExecTask, metrics::MetricsCollector, }; use anyhow::Result; use cgroups_rs::Cgroup; use krata::idm::{ - client::IdmInternalClient, + client::{IdmClientStreamResponseHandle, IdmInternalClient}, internal::{ event::Event as EventType, request::Request as RequestType, - response::Response as ResponseType, Event, ExitEvent, MetricsResponse, PingResponse, - Request, Response, + response::Response as ResponseType, Event, ExecStreamResponseUpdate, ExitEvent, + MetricsResponse, PingResponse, Request, Response, }, }; use log::debug; @@ -41,11 +42,11 @@ impl GuestBackground { pub async fn run(&mut self) -> Result<()> { let mut event_subscription = self.idm.subscribe().await?; let mut requests_subscription = self.idm.requests().await?; + let mut request_streams_subscription = self.idm.request_streams().await?; loop { select! { x = event_subscription.recv() => match x { Ok(_event) => { - }, Err(broadcast::error::RecvError::Closed) => { @@ -73,6 +74,21 @@ impl GuestBackground { } }, + x = request_streams_subscription.recv() => match x { + Ok(handle) => { + self.handle_idm_stream_request(handle).await?; + }, + + Err(broadcast::error::RecvError::Closed) => { + debug!("idm packet channel closed"); + break; + }, + + _ => { + continue; + } + }, + event = self.wait.recv() => match event { Some(event) => self.child_event(event).await?, None => { @@ -107,7 +123,33 @@ impl GuestBackground { self.idm.respond(id, response).await?; } - None => {} + _ => {} + } + Ok(()) + } + + async fn handle_idm_stream_request( + &mut self, + handle: IdmClientStreamResponseHandle, + ) -> Result<()> { + if let Some(RequestType::ExecStream(_)) = &handle.initial.request { + tokio::task::spawn(async move { + let exec = GuestExecTask { handle }; + if let Err(error) = exec.run().await { + let _ = exec + .handle + .respond(Response { + response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate { + exited: true, + error: error.to_string(), + exit_code: -1, + stdout: vec![], + stderr: vec![], + })), + }) + .await; + } + }); } Ok(()) } diff --git a/crates/guest/src/exec.rs b/crates/guest/src/exec.rs new file mode 100644 index 0000000..82fb360 --- /dev/null +++ b/crates/guest/src/exec.rs @@ -0,0 +1,172 @@ +use std::{collections::HashMap, process::Stdio}; + +use anyhow::{anyhow, Result}; +use krata::idm::{ + client::IdmClientStreamResponseHandle, + internal::{ + exec_stream_request_update::Update, request::Request as RequestType, + ExecStreamResponseUpdate, + }, + internal::{response::Response as ResponseType, Request, Response}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + join, + process::Command, +}; + +pub struct GuestExecTask { + pub handle: IdmClientStreamResponseHandle, +} + +impl GuestExecTask { + pub async fn run(&self) -> Result<()> { + let mut receiver = self.handle.take().await?; + + let Some(ref request) = self.handle.initial.request else { + return Err(anyhow!("request was empty")); + }; + + let RequestType::ExecStream(update) = request else { + return Err(anyhow!("request was not an exec update")); + }; + + let Some(Update::Start(ref start)) = update.update else { + return Err(anyhow!("first request did not contain a start update")); + }; + + let mut cmd = start.command.clone(); + if cmd.is_empty() { + return Err(anyhow!("command line was empty")); + } + let exe = cmd.remove(0); + let mut env = HashMap::new(); + for entry in &start.environment { + env.insert(entry.key.clone(), entry.value.clone()); + } + + if !env.contains_key("PATH") { + env.insert( + "PATH".to_string(), + "/bin:/usr/bin:/usr/local/bin".to_string(), + ); + } + + let dir = if start.working_directory.is_empty() { + "/".to_string() + } else { + start.working_directory.clone() + }; + + let mut child = Command::new(exe) + .args(cmd) + .envs(env) + .current_dir(dir) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true) + .spawn() + .map_err(|error| anyhow!("failed to spawn: {}", error))?; + + let mut stdin = child + .stdin + .take() + .ok_or_else(|| anyhow!("stdin was missing"))?; + let mut stdout = child + .stdout + .take() + .ok_or_else(|| anyhow!("stdout was missing"))?; + let mut stderr = child + .stderr + .take() + .ok_or_else(|| anyhow!("stderr was missing"))?; + + let stdout_handle = self.handle.clone(); + let stdout_task = tokio::task::spawn(async move { + let mut stdout_buffer = vec![0u8; 8 * 1024]; + loop { + let Ok(size) = stdout.read(&mut stdout_buffer).await else { + break; + }; + if size > 0 { + let response = Response { + response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate { + exited: false, + exit_code: 0, + error: String::new(), + stdout: stdout_buffer[0..size].to_vec(), + stderr: vec![], + })), + }; + let _ = stdout_handle.respond(response).await; + } else { + break; + } + } + }); + + let stderr_handle = self.handle.clone(); + let stderr_task = tokio::task::spawn(async move { + let mut stderr_buffer = vec![0u8; 8 * 1024]; + loop { + let Ok(size) = stderr.read(&mut stderr_buffer).await else { + break; + }; + if size > 0 { + let response = Response { + response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate { + exited: false, + exit_code: 0, + error: String::new(), + stdout: vec![], + stderr: stderr_buffer[0..size].to_vec(), + })), + }; + let _ = stderr_handle.respond(response).await; + } else { + break; + } + } + }); + + let stdin_task = tokio::task::spawn(async move { + loop { + let Some(request) = receiver.recv().await else { + break; + }; + + let Some(RequestType::ExecStream(update)) = request.request else { + continue; + }; + + let Some(Update::Stdin(update)) = update.update else { + continue; + }; + + if stdin.write_all(&update.data).await.is_err() { + break; + } + } + }); + + let exit = child.wait().await?; + let code = exit.code().unwrap_or(-1); + + let _ = join!(stdout_task, stderr_task); + stdin_task.abort(); + + let response = Response { + response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate { + exited: true, + exit_code: code, + error: String::new(), + stdout: vec![], + stderr: vec![], + })), + }; + self.handle.respond(response).await?; + + Ok(()) + } +} diff --git a/crates/guest/src/init.rs b/crates/guest/src/init.rs index 1f1189f..7ef9854 100644 --- a/crates/guest/src/init.rs +++ b/crates/guest/src/init.rs @@ -479,7 +479,7 @@ impl GuestInit { env.insert("TERM".to_string(), "xterm".to_string()); } - let path = GuestInit::resolve_executable(&env, path.into())?; + let path = resolve_executable(&env, path.into())?; let Some(file_name) = path.file_name() else { return Err(anyhow!("cannot get file name of command path")); }; @@ -537,27 +537,6 @@ impl GuestInit { map } - fn resolve_executable(env: &HashMap, path: PathBuf) -> Result { - if path.is_absolute() { - return Ok(path); - } - - if path.is_file() { - return Ok(path.absolutize()?.to_path_buf()); - } - - if let Some(path_var) = env.get("PATH") { - for item in path_var.split(':') { - let mut exe_path: PathBuf = item.into(); - exe_path.push(&path); - if exe_path.is_file() { - return Ok(exe_path); - } - } - } - Ok(path) - } - fn env_list(env: HashMap) -> Vec { env.iter() .map(|(key, value)| format!("{}={}", key, value)) @@ -613,3 +592,24 @@ impl GuestInit { Ok(()) } } + +pub fn resolve_executable(env: &HashMap, path: PathBuf) -> Result { + if path.is_absolute() { + return Ok(path); + } + + if path.is_file() { + return Ok(path.absolutize()?.to_path_buf()); + } + + if let Some(path_var) = env.get("PATH") { + for item in path_var.split(':') { + let mut exe_path: PathBuf = item.into(); + exe_path.push(&path); + if exe_path.is_file() { + return Ok(exe_path); + } + } + } + Ok(path) +} diff --git a/crates/guest/src/lib.rs b/crates/guest/src/lib.rs index b88e833..a5afc46 100644 --- a/crates/guest/src/lib.rs +++ b/crates/guest/src/lib.rs @@ -6,6 +6,7 @@ use xenstore::{XsdClient, XsdInterface}; pub mod background; pub mod childwait; +pub mod exec; pub mod init; pub mod metrics; diff --git a/crates/krata/proto/krata/idm/internal.proto b/crates/krata/proto/krata/idm/internal.proto index 634643f..5c96a12 100644 --- a/crates/krata/proto/krata/idm/internal.proto +++ b/crates/krata/proto/krata/idm/internal.proto @@ -36,6 +36,36 @@ enum MetricFormat { METRIC_FORMAT_DURATION_SECONDS = 3; } +message ExecEnvVar { + string key = 1; + string value = 2; +} + +message ExecStreamRequestStart { + repeated ExecEnvVar environment = 1; + repeated string command = 2; + string working_directory = 3; +} + +message ExecStreamRequestStdin { + bytes data = 1; +} + +message ExecStreamRequestUpdate { + oneof update { + ExecStreamRequestStart start = 1; + ExecStreamRequestStdin stdin = 2; + } +} + +message ExecStreamResponseUpdate { + bool exited = 1; + string error = 2; + int32 exit_code = 3; + bytes stdout = 4; + bytes stderr = 5; +} + message Event { oneof event { ExitEvent exit = 1; @@ -46,6 +76,7 @@ message Request { oneof request { PingRequest ping = 1; MetricsRequest metrics = 2; + ExecStreamRequestUpdate exec_stream = 3; } } @@ -53,5 +84,6 @@ message Response { oneof response { PingResponse ping = 1; MetricsResponse metrics = 2; + ExecStreamResponseUpdate exec_stream = 3; } } diff --git a/crates/krata/proto/krata/idm/transport.proto b/crates/krata/proto/krata/idm/transport.proto index 37d9283..5fc7b86 100644 --- a/crates/krata/proto/krata/idm/transport.proto +++ b/crates/krata/proto/krata/idm/transport.proto @@ -19,4 +19,9 @@ enum IdmTransportPacketForm { IDM_TRANSPORT_PACKET_FORM_EVENT = 2; IDM_TRANSPORT_PACKET_FORM_REQUEST = 3; IDM_TRANSPORT_PACKET_FORM_RESPONSE = 4; + IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST = 5; + IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_UPDATE = 6; + IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_UPDATE = 7; + IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_CLOSED = 8; + IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_CLOSED = 9; } diff --git a/crates/krata/proto/krata/v1/common.proto b/crates/krata/proto/krata/v1/common.proto index 0b07c46..d5e18c8 100644 --- a/crates/krata/proto/krata/v1/common.proto +++ b/crates/krata/proto/krata/v1/common.proto @@ -49,6 +49,7 @@ message GuestOciImageSpec { message GuestTaskSpec { repeated GuestTaskSpecEnvVar environment = 1; repeated string command = 2; + string working_directory = 3; } message GuestTaskSpecEnvVar { diff --git a/crates/krata/proto/krata/v1/control.proto b/crates/krata/proto/krata/v1/control.proto index 03a0881..ca49f5b 100644 --- a/crates/krata/proto/krata/v1/control.proto +++ b/crates/krata/proto/krata/v1/control.proto @@ -17,6 +17,8 @@ service ControlService { rpc ResolveGuest(ResolveGuestRequest) returns (ResolveGuestReply); rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply); + rpc ExecGuest(stream ExecGuestRequest) returns (stream ExecGuestReply); + rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply); rpc ReadGuestMetrics(ReadGuestMetricsRequest) returns (ReadGuestMetricsReply); @@ -62,6 +64,20 @@ message ListGuestsReply { repeated krata.v1.common.Guest guests = 1; } +message ExecGuestRequest { + string guest_id = 1; + krata.v1.common.GuestTaskSpec task = 2; + bytes data = 3; +} + +message ExecGuestReply { + bool exited = 1; + string error = 2; + int32 exit_code = 3; + bytes stdout = 4; + bytes stderr = 5; +} + message ConsoleDataRequest { string guest_id = 1; bytes data = 2; diff --git a/crates/krata/src/idm/client.rs b/crates/krata/src/idm/client.rs index f6a740a..a935250 100644 --- a/crates/krata/src/idm/client.rs +++ b/crates/krata/src/idm/client.rs @@ -31,7 +31,9 @@ use super::{ transport::{IdmTransportPacket, IdmTransportPacketForm}, }; -type RequestMap = Arc::Response>>>>; +type OneshotRequestMap = Arc::Response>>>>; +type StreamRequestMap = Arc::Response>>>>; +type StreamRequestUpdateMap = Arc>>>; pub type IdmInternalClient = IdmClient; const IDM_PACKET_QUEUE_LEN: usize = 100; @@ -106,10 +108,12 @@ impl IdmBackend for IdmFileBackend { pub struct IdmClient { channel: u64, request_backend_sender: broadcast::Sender<(u64, R)>, + request_stream_backend_sender: broadcast::Sender>, next_request_id: Arc>, event_receiver_sender: broadcast::Sender, tx_sender: Sender, - requests: RequestMap, + requests: OneshotRequestMap, + request_streams: StreamRequestMap, task: Arc>, } @@ -121,21 +125,122 @@ impl Drop for IdmClient { } } +pub struct IdmClientStreamRequestHandle { + pub id: u64, + pub receiver: Receiver, + pub client: IdmClient, +} + +impl IdmClientStreamRequestHandle { + pub async fn update(&self, request: R) -> Result<()> { + self.client + .tx_sender + .send(IdmTransportPacket { + id: self.id, + channel: self.client.channel, + form: IdmTransportPacketForm::StreamRequestUpdate.into(), + data: request.encode()?, + }) + .await?; + Ok(()) + } +} + +impl Drop for IdmClientStreamRequestHandle { + fn drop(&mut self) { + let id = self.id; + let client = self.client.clone(); + tokio::task::spawn(async move { + let _ = client + .tx_sender + .send(IdmTransportPacket { + id, + channel: client.channel, + form: IdmTransportPacketForm::StreamRequestClosed.into(), + data: vec![], + }) + .await; + }); + } +} + +#[derive(Clone)] +pub struct IdmClientStreamResponseHandle { + pub initial: R, + pub id: u64, + channel: u64, + tx_sender: Sender, + receiver: Arc>>>, +} + +impl IdmClientStreamResponseHandle { + pub async fn respond(&self, response: R::Response) -> Result<()> { + self.tx_sender + .send(IdmTransportPacket { + id: self.id, + channel: self.channel, + form: IdmTransportPacketForm::StreamResponseUpdate.into(), + data: response.encode()?, + }) + .await?; + Ok(()) + } + + pub async fn take(&self) -> Result> { + let mut guard = self.receiver.lock().await; + let Some(receiver) = (*guard).take() else { + return Err(anyhow!("request has already been claimed!")); + }; + Ok(receiver) + } +} + +impl Drop for IdmClientStreamResponseHandle { + fn drop(&mut self) { + if Arc::strong_count(&self.receiver) <= 1 { + let id = self.id; + let channel = self.channel; + let tx_sender = self.tx_sender.clone(); + tokio::task::spawn(async move { + let _ = tx_sender + .send(IdmTransportPacket { + id, + channel, + form: IdmTransportPacketForm::StreamResponseClosed.into(), + data: vec![], + }) + .await; + }); + } + } +} + impl IdmClient { pub async fn new(channel: u64, backend: Box) -> Result { let requests = Arc::new(Mutex::new(HashMap::new())); + let request_streams = Arc::new(Mutex::new(HashMap::new())); + let request_update_streams = Arc::new(Mutex::new(HashMap::new())); let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN); let (internal_request_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN); + let (internal_request_stream_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN); let (tx_sender, tx_receiver) = mpsc::channel(IDM_PACKET_QUEUE_LEN); let backend_event_sender = event_sender.clone(); let request_backend_sender = internal_request_backend_sender.clone(); + let request_stream_backend_sender = internal_request_stream_backend_sender.clone(); let requests_for_client = requests.clone(); + let request_streams_for_client = request_streams.clone(); + let tx_sender_for_client = tx_sender.clone(); let task = tokio::task::spawn(async move { if let Err(error) = IdmClient::process( backend, + channel, + tx_sender, backend_event_sender, requests, + request_streams, + request_update_streams, internal_request_backend_sender, + internal_request_stream_backend_sender, event_receiver, tx_receiver, ) @@ -149,8 +254,10 @@ impl IdmClient { next_request_id: Arc::new(Mutex::new(0)), event_receiver_sender: event_sender.clone(), request_backend_sender, + request_stream_backend_sender, requests: requests_for_client, - tx_sender, + request_streams: request_streams_for_client, + tx_sender: tx_sender_for_client, task: Arc::new(task), }) } @@ -194,6 +301,12 @@ impl IdmClient { Ok(self.request_backend_sender.subscribe()) } + pub async fn request_streams( + &self, + ) -> Result>> { + Ok(self.request_stream_backend_sender.subscribe()) + } + pub async fn respond(&self, id: u64, response: T) -> Result<()> { let packet = IdmTransportPacket { id, @@ -244,11 +357,43 @@ impl IdmClient { Ok(response) } + pub async fn send_stream(&self, request: R) -> Result> { + let (sender, receiver) = mpsc::channel::(100); + let req = { + let mut guard = self.next_request_id.lock().await; + let req = *guard; + *guard = req.wrapping_add(1); + req + }; + let mut requests = self.request_streams.lock().await; + requests.insert(req, sender); + drop(requests); + self.tx_sender + .send(IdmTransportPacket { + id: req, + channel: self.channel, + form: IdmTransportPacketForm::StreamRequest.into(), + data: request.encode()?, + }) + .await?; + Ok(IdmClientStreamRequestHandle { + id: req, + receiver, + client: self.clone(), + }) + } + + #[allow(clippy::too_many_arguments)] async fn process( mut backend: Box, + channel: u64, + tx_sender: Sender, event_sender: broadcast::Sender, - requests: RequestMap, + requests: OneshotRequestMap, + request_streams: StreamRequestMap, + request_update_streams: StreamRequestUpdateMap, request_backend_sender: broadcast::Sender<(u64, R)>, + request_stream_backend_sender: broadcast::Sender>, _event_receiver: broadcast::Receiver, mut receiver: Receiver, ) -> Result<()> { @@ -256,6 +401,10 @@ impl IdmClient { select! { x = backend.recv() => match x { Ok(packet) => { + if packet.channel != channel { + continue; + } + match packet.form() { IdmTransportPacketForm::Event => { if let Ok(event) = E::decode(&packet.data) { @@ -280,6 +429,50 @@ impl IdmClient { } }, + IdmTransportPacketForm::StreamRequest => { + if let Ok(request) = R::decode(&packet.data) { + let mut update_streams = request_update_streams.lock().await; + let (sender, receiver) = mpsc::channel(100); + update_streams.insert(packet.id, sender.clone()); + let handle = IdmClientStreamResponseHandle { + initial: request, + id: packet.id, + channel, + tx_sender: tx_sender.clone(), + receiver: Arc::new(Mutex::new(Some(receiver))), + }; + let _ = request_stream_backend_sender.send(handle); + } + } + + IdmTransportPacketForm::StreamRequestUpdate => { + if let Ok(request) = R::decode(&packet.data) { + let mut update_streams = request_update_streams.lock().await; + if let Some(stream) = update_streams.get_mut(&packet.id) { + let _ = stream.try_send(request); + } + } + } + + IdmTransportPacketForm::StreamRequestClosed => { + let mut update_streams = request_update_streams.lock().await; + update_streams.remove(&packet.id); + } + + IdmTransportPacketForm::StreamResponseUpdate => { + let requests = request_streams.lock().await; + if let Some(sender) = requests.get(&packet.id) { + if let Ok(response) = R::Response::decode(&packet.data) { + let _ = sender.try_send(response); + } + } + } + + IdmTransportPacketForm::StreamResponseClosed => { + let mut requests = request_streams.lock().await; + requests.remove(&packet.id); + } + _ => {}, } },