feat: implement guest exec (#107)

This commit is contained in:
Alex Zenla 2024-04-22 13:13:43 -07:00 committed by GitHub
parent 82576df7b7
commit 284ed8f17b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 755 additions and 53 deletions

View File

@ -0,0 +1,70 @@
use std::collections::HashMap;
use anyhow::Result;
use clap::Parser;
use krata::v1::{
common::{GuestTaskSpec, GuestTaskSpecEnvVar},
control::{control_service_client::ControlServiceClient, ExecGuestRequest},
};
use tonic::{transport::Channel, Request};
use crate::console::StdioConsoleStream;
use super::resolve_guest;
#[derive(Parser)]
#[command(about = "Execute a command inside the guest")]
pub struct ExecCommand {
#[arg[short, long, help = "Environment variables"]]
env: Option<Vec<String>>,
#[arg(short = 'w', long, help = "Working directory")]
working_directory: Option<String>,
#[arg(help = "Guest to exec inside, either the name or the uuid")]
guest: String,
#[arg(
allow_hyphen_values = true,
trailing_var_arg = true,
help = "Command to run inside the guest"
)]
command: Vec<String>,
}
impl ExecCommand {
pub async fn run(self, mut client: ControlServiceClient<Channel>) -> Result<()> {
let guest_id: String = resolve_guest(&mut client, &self.guest).await?;
let initial = ExecGuestRequest {
guest_id,
task: Some(GuestTaskSpec {
environment: env_map(&self.env.unwrap_or_default())
.iter()
.map(|(key, value)| GuestTaskSpecEnvVar {
key: key.clone(),
value: value.clone(),
})
.collect(),
command: self.command,
working_directory: self.working_directory.unwrap_or_default(),
}),
data: vec![],
};
let stream = StdioConsoleStream::stdin_stream_exec(initial).await;
let response = client.exec_guest(Request::new(stream)).await?.into_inner();
let code = StdioConsoleStream::exec_output(response).await?;
std::process::exit(code);
}
}
fn env_map(env: &[String]) -> HashMap<String, String> {
let mut map = HashMap::<String, String>::new();
for item in env {
if let Some((key, value)) = item.split_once('=') {
map.insert(key.to_string(), value.to_string());
}
}
map
}

View File

@ -106,13 +106,19 @@ pub fn convert_idm_snoop(reply: SnoopIdmReply) -> Option<IdmSnoopLine> {
.ok()
.and_then(|event| proto2dynamic(event).ok()),
IdmTransportPacketForm::Request => internal::Request::decode(&packet.data)
.ok()
.and_then(|event| proto2dynamic(event).ok()),
IdmTransportPacketForm::Request
| IdmTransportPacketForm::StreamRequest
| IdmTransportPacketForm::StreamRequestUpdate => {
internal::Request::decode(&packet.data)
.ok()
.and_then(|event| proto2dynamic(event).ok())
}
IdmTransportPacketForm::Response => internal::Response::decode(&packet.data)
.ok()
.and_then(|event| proto2dynamic(event).ok()),
IdmTransportPacketForm::Response | IdmTransportPacketForm::StreamResponseUpdate => {
internal::Response::decode(&packet.data)
.ok()
.and_then(|event| proto2dynamic(event).ok())
}
_ => None,
}
@ -132,6 +138,11 @@ pub fn convert_idm_snoop(reply: SnoopIdmReply) -> Option<IdmSnoopLine> {
IdmTransportPacketForm::Event => "event".to_string(),
IdmTransportPacketForm::Request => "request".to_string(),
IdmTransportPacketForm::Response => "response".to_string(),
IdmTransportPacketForm::StreamRequest => "stream-request".to_string(),
IdmTransportPacketForm::StreamRequestUpdate => "stream-request-update".to_string(),
IdmTransportPacketForm::StreamRequestClosed => "stream-request-closed".to_string(),
IdmTransportPacketForm::StreamResponseUpdate => "stream-response-update".to_string(),
IdmTransportPacketForm::StreamResponseClosed => "stream-response-closed".to_string(),
_ => format!("unknown-{}", packet.form),
},
data: base64::prelude::BASE64_STANDARD.encode(&packet.data),

View File

@ -29,7 +29,7 @@ pub enum LaunchImageFormat {
#[derive(Parser)]
#[command(about = "Launch a new guest")]
pub struct LauchCommand {
pub struct LaunchCommand {
#[arg(long, default_value = "squashfs", help = "Image format")]
image_format: LaunchImageFormat,
#[arg(long, help = "Overwrite image cache on pull")]
@ -68,6 +68,8 @@ pub struct LauchCommand {
kernel: Option<String>,
#[arg(short = 'I', long, help = "OCI initrd image for guest to use")]
initrd: Option<String>,
#[arg(short = 'w', long, help = "Working directory")]
working_directory: Option<String>,
#[arg(help = "Container image for guest to use")]
oci: String,
#[arg(
@ -78,7 +80,7 @@ pub struct LauchCommand {
command: Vec<String>,
}
impl LauchCommand {
impl LaunchCommand {
pub async fn run(
self,
mut client: ControlServiceClient<Channel>,
@ -130,6 +132,7 @@ impl LauchCommand {
})
.collect(),
command: self.command,
working_directory: self.working_directory.unwrap_or_default(),
}),
annotations: vec![],
}),

View File

@ -1,5 +1,6 @@
pub mod attach;
pub mod destroy;
pub mod exec;
pub mod identify_host;
pub mod idm_snoop;
pub mod launch;
@ -21,10 +22,10 @@ use krata::{
use tonic::{transport::Channel, Request};
use self::{
attach::AttachCommand, destroy::DestroyCommand, identify_host::IdentifyHostCommand,
idm_snoop::IdmSnoopCommand, launch::LauchCommand, list::ListCommand, logs::LogsCommand,
metrics::MetricsCommand, pull::PullCommand, resolve::ResolveCommand, top::TopCommand,
watch::WatchCommand,
attach::AttachCommand, destroy::DestroyCommand, exec::ExecCommand,
identify_host::IdentifyHostCommand, idm_snoop::IdmSnoopCommand, launch::LaunchCommand,
list::ListCommand, logs::LogsCommand, metrics::MetricsCommand, pull::PullCommand,
resolve::ResolveCommand, top::TopCommand, watch::WatchCommand,
};
#[derive(Parser)]
@ -47,7 +48,7 @@ pub struct ControlCommand {
#[derive(Subcommand)]
pub enum Commands {
Launch(LauchCommand),
Launch(LaunchCommand),
Destroy(DestroyCommand),
List(ListCommand),
Attach(AttachCommand),
@ -59,6 +60,7 @@ pub enum Commands {
IdmSnoop(IdmSnoopCommand),
Top(TopCommand),
IdentifyHost(IdentifyHostCommand),
Exec(ExecCommand),
}
impl ControlCommand {
@ -114,6 +116,10 @@ impl ControlCommand {
Commands::IdentifyHost(identify) => {
identify.run(client).await?;
}
Commands::Exec(exec) => {
exec.run(client).await?;
}
}
Ok(())
}

View File

@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use async_stream::stream;
use crossterm::{
terminal::{disable_raw_mode, enable_raw_mode, is_raw_mode_enabled},
@ -8,12 +8,15 @@ use krata::{
events::EventStream,
v1::{
common::GuestStatus,
control::{watch_events_reply::Event, ConsoleDataReply, ConsoleDataRequest},
control::{
watch_events_reply::Event, ConsoleDataReply, ConsoleDataRequest, ExecGuestReply,
ExecGuestRequest,
},
},
};
use log::debug;
use tokio::{
io::{stdin, stdout, AsyncReadExt, AsyncWriteExt},
io::{stderr, stdin, stdout, AsyncReadExt, AsyncWriteExt},
task::JoinHandle,
};
use tokio_stream::{Stream, StreamExt};
@ -45,6 +48,31 @@ impl StdioConsoleStream {
}
}
pub async fn stdin_stream_exec(
initial: ExecGuestRequest,
) -> impl Stream<Item = ExecGuestRequest> {
let mut stdin = stdin();
stream! {
yield initial;
let mut buffer = vec![0u8; 60];
loop {
let size = match stdin.read(&mut buffer).await {
Ok(size) => size,
Err(error) => {
debug!("failed to read stdin: {}", error);
break;
}
};
let data = buffer[0..size].to_vec();
if size == 1 && buffer[0] == 0x1d {
break;
}
yield ExecGuestRequest { guest_id: String::default(), task: None, data };
}
}
}
pub async fn stdout(mut stream: Streaming<ConsoleDataReply>) -> Result<()> {
if stdin().is_tty() {
enable_raw_mode()?;
@ -62,6 +90,32 @@ impl StdioConsoleStream {
Ok(())
}
pub async fn exec_output(mut stream: Streaming<ExecGuestReply>) -> Result<i32> {
let mut stdout = stdout();
let mut stderr = stderr();
while let Some(reply) = stream.next().await {
let reply = reply?;
if !reply.stdout.is_empty() {
stdout.write_all(&reply.stdout).await?;
stdout.flush().await?;
}
if !reply.stderr.is_empty() {
stderr.write_all(&reply.stderr).await?;
stderr.flush().await?;
}
if reply.exited {
if reply.error.is_empty() {
return Ok(reply.exit_code);
} else {
return Err(anyhow!("exec failed: {}", reply.error));
}
}
}
Ok(-1)
}
pub async fn guest_exit_hook(
id: String,
events: EventStream,

View File

@ -2,18 +2,19 @@ use async_stream::try_stream;
use futures::Stream;
use krata::{
idm::internal::{
request::Request as IdmRequestType, response::Response as IdmResponseType, MetricsRequest,
Request as IdmRequest,
exec_stream_request_update::Update, request::Request as IdmRequestType,
response::Response as IdmResponseType, ExecEnvVar, ExecStreamRequestStart,
ExecStreamRequestStdin, ExecStreamRequestUpdate, MetricsRequest, Request as IdmRequest,
},
v1::{
common::{Guest, GuestState, GuestStatus, OciImageFormat},
control::{
control_service_server::ControlService, ConsoleDataReply, ConsoleDataRequest,
CreateGuestReply, CreateGuestRequest, DestroyGuestReply, DestroyGuestRequest,
IdentifyHostReply, IdentifyHostRequest, ListGuestsReply, ListGuestsRequest,
PullImageReply, PullImageRequest, ReadGuestMetricsReply, ReadGuestMetricsRequest,
ResolveGuestReply, ResolveGuestRequest, SnoopIdmReply, SnoopIdmRequest,
WatchEventsReply, WatchEventsRequest,
ExecGuestReply, ExecGuestRequest, IdentifyHostReply, IdentifyHostRequest,
ListGuestsReply, ListGuestsRequest, PullImageReply, PullImageRequest,
ReadGuestMetricsReply, ReadGuestMetricsRequest, ResolveGuestReply, ResolveGuestRequest,
SnoopIdmReply, SnoopIdmRequest, WatchEventsReply, WatchEventsRequest,
},
},
};
@ -101,6 +102,9 @@ enum PullImageSelect {
#[tonic::async_trait]
impl ControlService for DaemonControlService {
type ExecGuestStream =
Pin<Box<dyn Stream<Item = Result<ExecGuestReply, Status>> + Send + 'static>>;
type ConsoleDataStream =
Pin<Box<dyn Stream<Item = Result<ConsoleDataReply, Status>> + Send + 'static>>;
@ -166,6 +170,98 @@ impl ControlService for DaemonControlService {
}))
}
async fn exec_guest(
&self,
request: Request<Streaming<ExecGuestRequest>>,
) -> Result<Response<Self::ExecGuestStream>, Status> {
let mut input = request.into_inner();
let Some(request) = input.next().await else {
return Err(ApiError {
message: "expected to have at least one request".to_string(),
}
.into());
};
let request = request?;
let Some(task) = request.task else {
return Err(ApiError {
message: "task is missing".to_string(),
}
.into());
};
let uuid = Uuid::from_str(&request.guest_id).map_err(|error| ApiError {
message: error.to_string(),
})?;
let idm = self.idm.client(uuid).await.map_err(|error| ApiError {
message: error.to_string(),
})?;
let idm_request = IdmRequest {
request: Some(IdmRequestType::ExecStream(ExecStreamRequestUpdate {
update: Some(Update::Start(ExecStreamRequestStart {
environment: task
.environment
.into_iter()
.map(|x| ExecEnvVar {
key: x.key,
value: x.value,
})
.collect(),
command: task.command,
working_directory: task.working_directory,
})),
})),
};
let output = try_stream! {
let mut handle = idm.send_stream(idm_request).await.map_err(|x| ApiError {
message: x.to_string(),
})?;
loop {
select! {
x = input.next() => if let Some(update) = x {
let update: Result<ExecGuestRequest, Status> = update.map_err(|error| ApiError {
message: error.to_string()
}.into());
if let Ok(update) = update {
if !update.data.is_empty() {
let _ = handle.update(IdmRequest {
request: Some(IdmRequestType::ExecStream(ExecStreamRequestUpdate {
update: Some(Update::Stdin(ExecStreamRequestStdin {
data: update.data,
})),
}))}).await;
}
}
},
x = handle.receiver.recv() => match x {
Some(response) => {
let Some(IdmResponseType::ExecStream(update)) = response.response else {
break;
};
let reply = ExecGuestReply {
exited: update.exited,
error: update.error,
exit_code: update.exit_code,
stdout: update.stdout,
stderr: update.stderr
};
yield reply;
},
None => {
break;
}
}
};
}
};
Ok(Response::new(Box::pin(output) as Self::ExecGuestStream))
}
async fn destroy_guest(
&self,
request: Request<DestroyGuestRequest>,

View File

@ -1,16 +1,17 @@
use crate::{
childwait::{ChildEvent, ChildWait},
death,
exec::GuestExecTask,
metrics::MetricsCollector,
};
use anyhow::Result;
use cgroups_rs::Cgroup;
use krata::idm::{
client::IdmInternalClient,
client::{IdmClientStreamResponseHandle, IdmInternalClient},
internal::{
event::Event as EventType, request::Request as RequestType,
response::Response as ResponseType, Event, ExitEvent, MetricsResponse, PingResponse,
Request, Response,
response::Response as ResponseType, Event, ExecStreamResponseUpdate, ExitEvent,
MetricsResponse, PingResponse, Request, Response,
},
};
use log::debug;
@ -41,11 +42,11 @@ impl GuestBackground {
pub async fn run(&mut self) -> Result<()> {
let mut event_subscription = self.idm.subscribe().await?;
let mut requests_subscription = self.idm.requests().await?;
let mut request_streams_subscription = self.idm.request_streams().await?;
loop {
select! {
x = event_subscription.recv() => match x {
Ok(_event) => {
},
Err(broadcast::error::RecvError::Closed) => {
@ -73,6 +74,21 @@ impl GuestBackground {
}
},
x = request_streams_subscription.recv() => match x {
Ok(handle) => {
self.handle_idm_stream_request(handle).await?;
},
Err(broadcast::error::RecvError::Closed) => {
debug!("idm packet channel closed");
break;
},
_ => {
continue;
}
},
event = self.wait.recv() => match event {
Some(event) => self.child_event(event).await?,
None => {
@ -107,7 +123,33 @@ impl GuestBackground {
self.idm.respond(id, response).await?;
}
None => {}
_ => {}
}
Ok(())
}
async fn handle_idm_stream_request(
&mut self,
handle: IdmClientStreamResponseHandle<Request>,
) -> Result<()> {
if let Some(RequestType::ExecStream(_)) = &handle.initial.request {
tokio::task::spawn(async move {
let exec = GuestExecTask { handle };
if let Err(error) = exec.run().await {
let _ = exec
.handle
.respond(Response {
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
exited: true,
error: error.to_string(),
exit_code: -1,
stdout: vec![],
stderr: vec![],
})),
})
.await;
}
});
}
Ok(())
}

172
crates/guest/src/exec.rs Normal file
View File

@ -0,0 +1,172 @@
use std::{collections::HashMap, process::Stdio};
use anyhow::{anyhow, Result};
use krata::idm::{
client::IdmClientStreamResponseHandle,
internal::{
exec_stream_request_update::Update, request::Request as RequestType,
ExecStreamResponseUpdate,
},
internal::{response::Response as ResponseType, Request, Response},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
join,
process::Command,
};
pub struct GuestExecTask {
pub handle: IdmClientStreamResponseHandle<Request>,
}
impl GuestExecTask {
pub async fn run(&self) -> Result<()> {
let mut receiver = self.handle.take().await?;
let Some(ref request) = self.handle.initial.request else {
return Err(anyhow!("request was empty"));
};
let RequestType::ExecStream(update) = request else {
return Err(anyhow!("request was not an exec update"));
};
let Some(Update::Start(ref start)) = update.update else {
return Err(anyhow!("first request did not contain a start update"));
};
let mut cmd = start.command.clone();
if cmd.is_empty() {
return Err(anyhow!("command line was empty"));
}
let exe = cmd.remove(0);
let mut env = HashMap::new();
for entry in &start.environment {
env.insert(entry.key.clone(), entry.value.clone());
}
if !env.contains_key("PATH") {
env.insert(
"PATH".to_string(),
"/bin:/usr/bin:/usr/local/bin".to_string(),
);
}
let dir = if start.working_directory.is_empty() {
"/".to_string()
} else {
start.working_directory.clone()
};
let mut child = Command::new(exe)
.args(cmd)
.envs(env)
.current_dir(dir)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.map_err(|error| anyhow!("failed to spawn: {}", error))?;
let mut stdin = child
.stdin
.take()
.ok_or_else(|| anyhow!("stdin was missing"))?;
let mut stdout = child
.stdout
.take()
.ok_or_else(|| anyhow!("stdout was missing"))?;
let mut stderr = child
.stderr
.take()
.ok_or_else(|| anyhow!("stderr was missing"))?;
let stdout_handle = self.handle.clone();
let stdout_task = tokio::task::spawn(async move {
let mut stdout_buffer = vec![0u8; 8 * 1024];
loop {
let Ok(size) = stdout.read(&mut stdout_buffer).await else {
break;
};
if size > 0 {
let response = Response {
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
exited: false,
exit_code: 0,
error: String::new(),
stdout: stdout_buffer[0..size].to_vec(),
stderr: vec![],
})),
};
let _ = stdout_handle.respond(response).await;
} else {
break;
}
}
});
let stderr_handle = self.handle.clone();
let stderr_task = tokio::task::spawn(async move {
let mut stderr_buffer = vec![0u8; 8 * 1024];
loop {
let Ok(size) = stderr.read(&mut stderr_buffer).await else {
break;
};
if size > 0 {
let response = Response {
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
exited: false,
exit_code: 0,
error: String::new(),
stdout: vec![],
stderr: stderr_buffer[0..size].to_vec(),
})),
};
let _ = stderr_handle.respond(response).await;
} else {
break;
}
}
});
let stdin_task = tokio::task::spawn(async move {
loop {
let Some(request) = receiver.recv().await else {
break;
};
let Some(RequestType::ExecStream(update)) = request.request else {
continue;
};
let Some(Update::Stdin(update)) = update.update else {
continue;
};
if stdin.write_all(&update.data).await.is_err() {
break;
}
}
});
let exit = child.wait().await?;
let code = exit.code().unwrap_or(-1);
let _ = join!(stdout_task, stderr_task);
stdin_task.abort();
let response = Response {
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
exited: true,
exit_code: code,
error: String::new(),
stdout: vec![],
stderr: vec![],
})),
};
self.handle.respond(response).await?;
Ok(())
}
}

View File

@ -479,7 +479,7 @@ impl GuestInit {
env.insert("TERM".to_string(), "xterm".to_string());
}
let path = GuestInit::resolve_executable(&env, path.into())?;
let path = resolve_executable(&env, path.into())?;
let Some(file_name) = path.file_name() else {
return Err(anyhow!("cannot get file name of command path"));
};
@ -537,27 +537,6 @@ impl GuestInit {
map
}
fn resolve_executable(env: &HashMap<String, String>, path: PathBuf) -> Result<PathBuf> {
if path.is_absolute() {
return Ok(path);
}
if path.is_file() {
return Ok(path.absolutize()?.to_path_buf());
}
if let Some(path_var) = env.get("PATH") {
for item in path_var.split(':') {
let mut exe_path: PathBuf = item.into();
exe_path.push(&path);
if exe_path.is_file() {
return Ok(exe_path);
}
}
}
Ok(path)
}
fn env_list(env: HashMap<String, String>) -> Vec<String> {
env.iter()
.map(|(key, value)| format!("{}={}", key, value))
@ -613,3 +592,24 @@ impl GuestInit {
Ok(())
}
}
pub fn resolve_executable(env: &HashMap<String, String>, path: PathBuf) -> Result<PathBuf> {
if path.is_absolute() {
return Ok(path);
}
if path.is_file() {
return Ok(path.absolutize()?.to_path_buf());
}
if let Some(path_var) = env.get("PATH") {
for item in path_var.split(':') {
let mut exe_path: PathBuf = item.into();
exe_path.push(&path);
if exe_path.is_file() {
return Ok(exe_path);
}
}
}
Ok(path)
}

View File

@ -6,6 +6,7 @@ use xenstore::{XsdClient, XsdInterface};
pub mod background;
pub mod childwait;
pub mod exec;
pub mod init;
pub mod metrics;

View File

@ -36,6 +36,36 @@ enum MetricFormat {
METRIC_FORMAT_DURATION_SECONDS = 3;
}
message ExecEnvVar {
string key = 1;
string value = 2;
}
message ExecStreamRequestStart {
repeated ExecEnvVar environment = 1;
repeated string command = 2;
string working_directory = 3;
}
message ExecStreamRequestStdin {
bytes data = 1;
}
message ExecStreamRequestUpdate {
oneof update {
ExecStreamRequestStart start = 1;
ExecStreamRequestStdin stdin = 2;
}
}
message ExecStreamResponseUpdate {
bool exited = 1;
string error = 2;
int32 exit_code = 3;
bytes stdout = 4;
bytes stderr = 5;
}
message Event {
oneof event {
ExitEvent exit = 1;
@ -46,6 +76,7 @@ message Request {
oneof request {
PingRequest ping = 1;
MetricsRequest metrics = 2;
ExecStreamRequestUpdate exec_stream = 3;
}
}
@ -53,5 +84,6 @@ message Response {
oneof response {
PingResponse ping = 1;
MetricsResponse metrics = 2;
ExecStreamResponseUpdate exec_stream = 3;
}
}

View File

@ -19,4 +19,9 @@ enum IdmTransportPacketForm {
IDM_TRANSPORT_PACKET_FORM_EVENT = 2;
IDM_TRANSPORT_PACKET_FORM_REQUEST = 3;
IDM_TRANSPORT_PACKET_FORM_RESPONSE = 4;
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST = 5;
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_UPDATE = 6;
IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_UPDATE = 7;
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_CLOSED = 8;
IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_CLOSED = 9;
}

View File

@ -49,6 +49,7 @@ message GuestOciImageSpec {
message GuestTaskSpec {
repeated GuestTaskSpecEnvVar environment = 1;
repeated string command = 2;
string working_directory = 3;
}
message GuestTaskSpecEnvVar {

View File

@ -17,6 +17,8 @@ service ControlService {
rpc ResolveGuest(ResolveGuestRequest) returns (ResolveGuestReply);
rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply);
rpc ExecGuest(stream ExecGuestRequest) returns (stream ExecGuestReply);
rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply);
rpc ReadGuestMetrics(ReadGuestMetricsRequest) returns (ReadGuestMetricsReply);
@ -62,6 +64,20 @@ message ListGuestsReply {
repeated krata.v1.common.Guest guests = 1;
}
message ExecGuestRequest {
string guest_id = 1;
krata.v1.common.GuestTaskSpec task = 2;
bytes data = 3;
}
message ExecGuestReply {
bool exited = 1;
string error = 2;
int32 exit_code = 3;
bytes stdout = 4;
bytes stderr = 5;
}
message ConsoleDataRequest {
string guest_id = 1;
bytes data = 2;

View File

@ -31,7 +31,9 @@ use super::{
transport::{IdmTransportPacket, IdmTransportPacketForm},
};
type RequestMap<R> = Arc<Mutex<HashMap<u64, oneshot::Sender<<R as IdmRequest>::Response>>>>;
type OneshotRequestMap<R> = Arc<Mutex<HashMap<u64, oneshot::Sender<<R as IdmRequest>::Response>>>>;
type StreamRequestMap<R> = Arc<Mutex<HashMap<u64, Sender<<R as IdmRequest>::Response>>>>;
type StreamRequestUpdateMap<R> = Arc<Mutex<HashMap<u64, mpsc::Sender<R>>>>;
pub type IdmInternalClient = IdmClient<internal::Request, internal::Event>;
const IDM_PACKET_QUEUE_LEN: usize = 100;
@ -106,10 +108,12 @@ impl IdmBackend for IdmFileBackend {
pub struct IdmClient<R: IdmRequest, E: IdmSerializable> {
channel: u64,
request_backend_sender: broadcast::Sender<(u64, R)>,
request_stream_backend_sender: broadcast::Sender<IdmClientStreamResponseHandle<R>>,
next_request_id: Arc<Mutex<u64>>,
event_receiver_sender: broadcast::Sender<E>,
tx_sender: Sender<IdmTransportPacket>,
requests: RequestMap<R>,
requests: OneshotRequestMap<R>,
request_streams: StreamRequestMap<R>,
task: Arc<JoinHandle<()>>,
}
@ -121,21 +125,122 @@ impl<R: IdmRequest, E: IdmSerializable> Drop for IdmClient<R, E> {
}
}
pub struct IdmClientStreamRequestHandle<R: IdmRequest, E: IdmSerializable> {
pub id: u64,
pub receiver: Receiver<R::Response>,
pub client: IdmClient<R, E>,
}
impl<R: IdmRequest, E: IdmSerializable> IdmClientStreamRequestHandle<R, E> {
pub async fn update(&self, request: R) -> Result<()> {
self.client
.tx_sender
.send(IdmTransportPacket {
id: self.id,
channel: self.client.channel,
form: IdmTransportPacketForm::StreamRequestUpdate.into(),
data: request.encode()?,
})
.await?;
Ok(())
}
}
impl<R: IdmRequest, E: IdmSerializable> Drop for IdmClientStreamRequestHandle<R, E> {
fn drop(&mut self) {
let id = self.id;
let client = self.client.clone();
tokio::task::spawn(async move {
let _ = client
.tx_sender
.send(IdmTransportPacket {
id,
channel: client.channel,
form: IdmTransportPacketForm::StreamRequestClosed.into(),
data: vec![],
})
.await;
});
}
}
#[derive(Clone)]
pub struct IdmClientStreamResponseHandle<R: IdmRequest> {
pub initial: R,
pub id: u64,
channel: u64,
tx_sender: Sender<IdmTransportPacket>,
receiver: Arc<Mutex<Option<Receiver<R>>>>,
}
impl<R: IdmRequest> IdmClientStreamResponseHandle<R> {
pub async fn respond(&self, response: R::Response) -> Result<()> {
self.tx_sender
.send(IdmTransportPacket {
id: self.id,
channel: self.channel,
form: IdmTransportPacketForm::StreamResponseUpdate.into(),
data: response.encode()?,
})
.await?;
Ok(())
}
pub async fn take(&self) -> Result<Receiver<R>> {
let mut guard = self.receiver.lock().await;
let Some(receiver) = (*guard).take() else {
return Err(anyhow!("request has already been claimed!"));
};
Ok(receiver)
}
}
impl<R: IdmRequest> Drop for IdmClientStreamResponseHandle<R> {
fn drop(&mut self) {
if Arc::strong_count(&self.receiver) <= 1 {
let id = self.id;
let channel = self.channel;
let tx_sender = self.tx_sender.clone();
tokio::task::spawn(async move {
let _ = tx_sender
.send(IdmTransportPacket {
id,
channel,
form: IdmTransportPacketForm::StreamResponseClosed.into(),
data: vec![],
})
.await;
});
}
}
}
impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
pub async fn new(channel: u64, backend: Box<dyn IdmBackend>) -> Result<Self> {
let requests = Arc::new(Mutex::new(HashMap::new()));
let request_streams = Arc::new(Mutex::new(HashMap::new()));
let request_update_streams = Arc::new(Mutex::new(HashMap::new()));
let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
let (internal_request_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
let (internal_request_stream_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
let (tx_sender, tx_receiver) = mpsc::channel(IDM_PACKET_QUEUE_LEN);
let backend_event_sender = event_sender.clone();
let request_backend_sender = internal_request_backend_sender.clone();
let request_stream_backend_sender = internal_request_stream_backend_sender.clone();
let requests_for_client = requests.clone();
let request_streams_for_client = request_streams.clone();
let tx_sender_for_client = tx_sender.clone();
let task = tokio::task::spawn(async move {
if let Err(error) = IdmClient::process(
backend,
channel,
tx_sender,
backend_event_sender,
requests,
request_streams,
request_update_streams,
internal_request_backend_sender,
internal_request_stream_backend_sender,
event_receiver,
tx_receiver,
)
@ -149,8 +254,10 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
next_request_id: Arc::new(Mutex::new(0)),
event_receiver_sender: event_sender.clone(),
request_backend_sender,
request_stream_backend_sender,
requests: requests_for_client,
tx_sender,
request_streams: request_streams_for_client,
tx_sender: tx_sender_for_client,
task: Arc::new(task),
})
}
@ -194,6 +301,12 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
Ok(self.request_backend_sender.subscribe())
}
pub async fn request_streams(
&self,
) -> Result<broadcast::Receiver<IdmClientStreamResponseHandle<R>>> {
Ok(self.request_stream_backend_sender.subscribe())
}
pub async fn respond<T: IdmSerializable>(&self, id: u64, response: T) -> Result<()> {
let packet = IdmTransportPacket {
id,
@ -244,11 +357,43 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
Ok(response)
}
pub async fn send_stream(&self, request: R) -> Result<IdmClientStreamRequestHandle<R, E>> {
let (sender, receiver) = mpsc::channel::<R::Response>(100);
let req = {
let mut guard = self.next_request_id.lock().await;
let req = *guard;
*guard = req.wrapping_add(1);
req
};
let mut requests = self.request_streams.lock().await;
requests.insert(req, sender);
drop(requests);
self.tx_sender
.send(IdmTransportPacket {
id: req,
channel: self.channel,
form: IdmTransportPacketForm::StreamRequest.into(),
data: request.encode()?,
})
.await?;
Ok(IdmClientStreamRequestHandle {
id: req,
receiver,
client: self.clone(),
})
}
#[allow(clippy::too_many_arguments)]
async fn process(
mut backend: Box<dyn IdmBackend>,
channel: u64,
tx_sender: Sender<IdmTransportPacket>,
event_sender: broadcast::Sender<E>,
requests: RequestMap<R>,
requests: OneshotRequestMap<R>,
request_streams: StreamRequestMap<R>,
request_update_streams: StreamRequestUpdateMap<R>,
request_backend_sender: broadcast::Sender<(u64, R)>,
request_stream_backend_sender: broadcast::Sender<IdmClientStreamResponseHandle<R>>,
_event_receiver: broadcast::Receiver<E>,
mut receiver: Receiver<IdmTransportPacket>,
) -> Result<()> {
@ -256,6 +401,10 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
select! {
x = backend.recv() => match x {
Ok(packet) => {
if packet.channel != channel {
continue;
}
match packet.form() {
IdmTransportPacketForm::Event => {
if let Ok(event) = E::decode(&packet.data) {
@ -280,6 +429,50 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
}
},
IdmTransportPacketForm::StreamRequest => {
if let Ok(request) = R::decode(&packet.data) {
let mut update_streams = request_update_streams.lock().await;
let (sender, receiver) = mpsc::channel(100);
update_streams.insert(packet.id, sender.clone());
let handle = IdmClientStreamResponseHandle {
initial: request,
id: packet.id,
channel,
tx_sender: tx_sender.clone(),
receiver: Arc::new(Mutex::new(Some(receiver))),
};
let _ = request_stream_backend_sender.send(handle);
}
}
IdmTransportPacketForm::StreamRequestUpdate => {
if let Ok(request) = R::decode(&packet.data) {
let mut update_streams = request_update_streams.lock().await;
if let Some(stream) = update_streams.get_mut(&packet.id) {
let _ = stream.try_send(request);
}
}
}
IdmTransportPacketForm::StreamRequestClosed => {
let mut update_streams = request_update_streams.lock().await;
update_streams.remove(&packet.id);
}
IdmTransportPacketForm::StreamResponseUpdate => {
let requests = request_streams.lock().await;
if let Some(sender) = requests.get(&packet.id) {
if let Ok(response) = R::Response::decode(&packet.data) {
let _ = sender.try_send(response);
}
}
}
IdmTransportPacketForm::StreamResponseClosed => {
let mut requests = request_streams.lock().await;
requests.remove(&packet.id);
}
_ => {},
}
},