mirror of
https://github.com/edera-dev/krata.git
synced 2025-08-03 05:10:55 +00:00
feat: implement guest exec (#107)
This commit is contained in:
parent
82576df7b7
commit
284ed8f17b
70
crates/ctl/src/cli/exec.rs
Normal file
70
crates/ctl/src/cli/exec.rs
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
use krata::v1::{
|
||||||
|
common::{GuestTaskSpec, GuestTaskSpecEnvVar},
|
||||||
|
control::{control_service_client::ControlServiceClient, ExecGuestRequest},
|
||||||
|
};
|
||||||
|
|
||||||
|
use tonic::{transport::Channel, Request};
|
||||||
|
|
||||||
|
use crate::console::StdioConsoleStream;
|
||||||
|
|
||||||
|
use super::resolve_guest;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(about = "Execute a command inside the guest")]
|
||||||
|
pub struct ExecCommand {
|
||||||
|
#[arg[short, long, help = "Environment variables"]]
|
||||||
|
env: Option<Vec<String>>,
|
||||||
|
#[arg(short = 'w', long, help = "Working directory")]
|
||||||
|
working_directory: Option<String>,
|
||||||
|
#[arg(help = "Guest to exec inside, either the name or the uuid")]
|
||||||
|
guest: String,
|
||||||
|
#[arg(
|
||||||
|
allow_hyphen_values = true,
|
||||||
|
trailing_var_arg = true,
|
||||||
|
help = "Command to run inside the guest"
|
||||||
|
)]
|
||||||
|
command: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExecCommand {
|
||||||
|
pub async fn run(self, mut client: ControlServiceClient<Channel>) -> Result<()> {
|
||||||
|
let guest_id: String = resolve_guest(&mut client, &self.guest).await?;
|
||||||
|
let initial = ExecGuestRequest {
|
||||||
|
guest_id,
|
||||||
|
task: Some(GuestTaskSpec {
|
||||||
|
environment: env_map(&self.env.unwrap_or_default())
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| GuestTaskSpecEnvVar {
|
||||||
|
key: key.clone(),
|
||||||
|
value: value.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
command: self.command,
|
||||||
|
working_directory: self.working_directory.unwrap_or_default(),
|
||||||
|
}),
|
||||||
|
data: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = StdioConsoleStream::stdin_stream_exec(initial).await;
|
||||||
|
|
||||||
|
let response = client.exec_guest(Request::new(stream)).await?.into_inner();
|
||||||
|
|
||||||
|
let code = StdioConsoleStream::exec_output(response).await?;
|
||||||
|
std::process::exit(code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn env_map(env: &[String]) -> HashMap<String, String> {
|
||||||
|
let mut map = HashMap::<String, String>::new();
|
||||||
|
for item in env {
|
||||||
|
if let Some((key, value)) = item.split_once('=') {
|
||||||
|
map.insert(key.to_string(), value.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map
|
||||||
|
}
|
@ -106,13 +106,19 @@ pub fn convert_idm_snoop(reply: SnoopIdmReply) -> Option<IdmSnoopLine> {
|
|||||||
.ok()
|
.ok()
|
||||||
.and_then(|event| proto2dynamic(event).ok()),
|
.and_then(|event| proto2dynamic(event).ok()),
|
||||||
|
|
||||||
IdmTransportPacketForm::Request => internal::Request::decode(&packet.data)
|
IdmTransportPacketForm::Request
|
||||||
.ok()
|
| IdmTransportPacketForm::StreamRequest
|
||||||
.and_then(|event| proto2dynamic(event).ok()),
|
| IdmTransportPacketForm::StreamRequestUpdate => {
|
||||||
|
internal::Request::decode(&packet.data)
|
||||||
|
.ok()
|
||||||
|
.and_then(|event| proto2dynamic(event).ok())
|
||||||
|
}
|
||||||
|
|
||||||
IdmTransportPacketForm::Response => internal::Response::decode(&packet.data)
|
IdmTransportPacketForm::Response | IdmTransportPacketForm::StreamResponseUpdate => {
|
||||||
.ok()
|
internal::Response::decode(&packet.data)
|
||||||
.and_then(|event| proto2dynamic(event).ok()),
|
.ok()
|
||||||
|
.and_then(|event| proto2dynamic(event).ok())
|
||||||
|
}
|
||||||
|
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
@ -132,6 +138,11 @@ pub fn convert_idm_snoop(reply: SnoopIdmReply) -> Option<IdmSnoopLine> {
|
|||||||
IdmTransportPacketForm::Event => "event".to_string(),
|
IdmTransportPacketForm::Event => "event".to_string(),
|
||||||
IdmTransportPacketForm::Request => "request".to_string(),
|
IdmTransportPacketForm::Request => "request".to_string(),
|
||||||
IdmTransportPacketForm::Response => "response".to_string(),
|
IdmTransportPacketForm::Response => "response".to_string(),
|
||||||
|
IdmTransportPacketForm::StreamRequest => "stream-request".to_string(),
|
||||||
|
IdmTransportPacketForm::StreamRequestUpdate => "stream-request-update".to_string(),
|
||||||
|
IdmTransportPacketForm::StreamRequestClosed => "stream-request-closed".to_string(),
|
||||||
|
IdmTransportPacketForm::StreamResponseUpdate => "stream-response-update".to_string(),
|
||||||
|
IdmTransportPacketForm::StreamResponseClosed => "stream-response-closed".to_string(),
|
||||||
_ => format!("unknown-{}", packet.form),
|
_ => format!("unknown-{}", packet.form),
|
||||||
},
|
},
|
||||||
data: base64::prelude::BASE64_STANDARD.encode(&packet.data),
|
data: base64::prelude::BASE64_STANDARD.encode(&packet.data),
|
||||||
|
@ -29,7 +29,7 @@ pub enum LaunchImageFormat {
|
|||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(about = "Launch a new guest")]
|
#[command(about = "Launch a new guest")]
|
||||||
pub struct LauchCommand {
|
pub struct LaunchCommand {
|
||||||
#[arg(long, default_value = "squashfs", help = "Image format")]
|
#[arg(long, default_value = "squashfs", help = "Image format")]
|
||||||
image_format: LaunchImageFormat,
|
image_format: LaunchImageFormat,
|
||||||
#[arg(long, help = "Overwrite image cache on pull")]
|
#[arg(long, help = "Overwrite image cache on pull")]
|
||||||
@ -68,6 +68,8 @@ pub struct LauchCommand {
|
|||||||
kernel: Option<String>,
|
kernel: Option<String>,
|
||||||
#[arg(short = 'I', long, help = "OCI initrd image for guest to use")]
|
#[arg(short = 'I', long, help = "OCI initrd image for guest to use")]
|
||||||
initrd: Option<String>,
|
initrd: Option<String>,
|
||||||
|
#[arg(short = 'w', long, help = "Working directory")]
|
||||||
|
working_directory: Option<String>,
|
||||||
#[arg(help = "Container image for guest to use")]
|
#[arg(help = "Container image for guest to use")]
|
||||||
oci: String,
|
oci: String,
|
||||||
#[arg(
|
#[arg(
|
||||||
@ -78,7 +80,7 @@ pub struct LauchCommand {
|
|||||||
command: Vec<String>,
|
command: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LauchCommand {
|
impl LaunchCommand {
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
self,
|
self,
|
||||||
mut client: ControlServiceClient<Channel>,
|
mut client: ControlServiceClient<Channel>,
|
||||||
@ -130,6 +132,7 @@ impl LauchCommand {
|
|||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
command: self.command,
|
command: self.command,
|
||||||
|
working_directory: self.working_directory.unwrap_or_default(),
|
||||||
}),
|
}),
|
||||||
annotations: vec![],
|
annotations: vec![],
|
||||||
}),
|
}),
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
pub mod attach;
|
pub mod attach;
|
||||||
pub mod destroy;
|
pub mod destroy;
|
||||||
|
pub mod exec;
|
||||||
pub mod identify_host;
|
pub mod identify_host;
|
||||||
pub mod idm_snoop;
|
pub mod idm_snoop;
|
||||||
pub mod launch;
|
pub mod launch;
|
||||||
@ -21,10 +22,10 @@ use krata::{
|
|||||||
use tonic::{transport::Channel, Request};
|
use tonic::{transport::Channel, Request};
|
||||||
|
|
||||||
use self::{
|
use self::{
|
||||||
attach::AttachCommand, destroy::DestroyCommand, identify_host::IdentifyHostCommand,
|
attach::AttachCommand, destroy::DestroyCommand, exec::ExecCommand,
|
||||||
idm_snoop::IdmSnoopCommand, launch::LauchCommand, list::ListCommand, logs::LogsCommand,
|
identify_host::IdentifyHostCommand, idm_snoop::IdmSnoopCommand, launch::LaunchCommand,
|
||||||
metrics::MetricsCommand, pull::PullCommand, resolve::ResolveCommand, top::TopCommand,
|
list::ListCommand, logs::LogsCommand, metrics::MetricsCommand, pull::PullCommand,
|
||||||
watch::WatchCommand,
|
resolve::ResolveCommand, top::TopCommand, watch::WatchCommand,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
@ -47,7 +48,7 @@ pub struct ControlCommand {
|
|||||||
|
|
||||||
#[derive(Subcommand)]
|
#[derive(Subcommand)]
|
||||||
pub enum Commands {
|
pub enum Commands {
|
||||||
Launch(LauchCommand),
|
Launch(LaunchCommand),
|
||||||
Destroy(DestroyCommand),
|
Destroy(DestroyCommand),
|
||||||
List(ListCommand),
|
List(ListCommand),
|
||||||
Attach(AttachCommand),
|
Attach(AttachCommand),
|
||||||
@ -59,6 +60,7 @@ pub enum Commands {
|
|||||||
IdmSnoop(IdmSnoopCommand),
|
IdmSnoop(IdmSnoopCommand),
|
||||||
Top(TopCommand),
|
Top(TopCommand),
|
||||||
IdentifyHost(IdentifyHostCommand),
|
IdentifyHost(IdentifyHostCommand),
|
||||||
|
Exec(ExecCommand),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ControlCommand {
|
impl ControlCommand {
|
||||||
@ -114,6 +116,10 @@ impl ControlCommand {
|
|||||||
Commands::IdentifyHost(identify) => {
|
Commands::IdentifyHost(identify) => {
|
||||||
identify.run(client).await?;
|
identify.run(client).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Commands::Exec(exec) => {
|
||||||
|
exec.run(client).await?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use anyhow::Result;
|
use anyhow::{anyhow, Result};
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use crossterm::{
|
use crossterm::{
|
||||||
terminal::{disable_raw_mode, enable_raw_mode, is_raw_mode_enabled},
|
terminal::{disable_raw_mode, enable_raw_mode, is_raw_mode_enabled},
|
||||||
@ -8,12 +8,15 @@ use krata::{
|
|||||||
events::EventStream,
|
events::EventStream,
|
||||||
v1::{
|
v1::{
|
||||||
common::GuestStatus,
|
common::GuestStatus,
|
||||||
control::{watch_events_reply::Event, ConsoleDataReply, ConsoleDataRequest},
|
control::{
|
||||||
|
watch_events_reply::Event, ConsoleDataReply, ConsoleDataRequest, ExecGuestReply,
|
||||||
|
ExecGuestRequest,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{stdin, stdout, AsyncReadExt, AsyncWriteExt},
|
io::{stderr, stdin, stdout, AsyncReadExt, AsyncWriteExt},
|
||||||
task::JoinHandle,
|
task::JoinHandle,
|
||||||
};
|
};
|
||||||
use tokio_stream::{Stream, StreamExt};
|
use tokio_stream::{Stream, StreamExt};
|
||||||
@ -45,6 +48,31 @@ impl StdioConsoleStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn stdin_stream_exec(
|
||||||
|
initial: ExecGuestRequest,
|
||||||
|
) -> impl Stream<Item = ExecGuestRequest> {
|
||||||
|
let mut stdin = stdin();
|
||||||
|
stream! {
|
||||||
|
yield initial;
|
||||||
|
|
||||||
|
let mut buffer = vec![0u8; 60];
|
||||||
|
loop {
|
||||||
|
let size = match stdin.read(&mut buffer).await {
|
||||||
|
Ok(size) => size,
|
||||||
|
Err(error) => {
|
||||||
|
debug!("failed to read stdin: {}", error);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let data = buffer[0..size].to_vec();
|
||||||
|
if size == 1 && buffer[0] == 0x1d {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
yield ExecGuestRequest { guest_id: String::default(), task: None, data };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn stdout(mut stream: Streaming<ConsoleDataReply>) -> Result<()> {
|
pub async fn stdout(mut stream: Streaming<ConsoleDataReply>) -> Result<()> {
|
||||||
if stdin().is_tty() {
|
if stdin().is_tty() {
|
||||||
enable_raw_mode()?;
|
enable_raw_mode()?;
|
||||||
@ -62,6 +90,32 @@ impl StdioConsoleStream {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn exec_output(mut stream: Streaming<ExecGuestReply>) -> Result<i32> {
|
||||||
|
let mut stdout = stdout();
|
||||||
|
let mut stderr = stderr();
|
||||||
|
while let Some(reply) = stream.next().await {
|
||||||
|
let reply = reply?;
|
||||||
|
if !reply.stdout.is_empty() {
|
||||||
|
stdout.write_all(&reply.stdout).await?;
|
||||||
|
stdout.flush().await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reply.stderr.is_empty() {
|
||||||
|
stderr.write_all(&reply.stderr).await?;
|
||||||
|
stderr.flush().await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if reply.exited {
|
||||||
|
if reply.error.is_empty() {
|
||||||
|
return Ok(reply.exit_code);
|
||||||
|
} else {
|
||||||
|
return Err(anyhow!("exec failed: {}", reply.error));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(-1)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn guest_exit_hook(
|
pub async fn guest_exit_hook(
|
||||||
id: String,
|
id: String,
|
||||||
events: EventStream,
|
events: EventStream,
|
||||||
|
@ -2,18 +2,19 @@ use async_stream::try_stream;
|
|||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use krata::{
|
use krata::{
|
||||||
idm::internal::{
|
idm::internal::{
|
||||||
request::Request as IdmRequestType, response::Response as IdmResponseType, MetricsRequest,
|
exec_stream_request_update::Update, request::Request as IdmRequestType,
|
||||||
Request as IdmRequest,
|
response::Response as IdmResponseType, ExecEnvVar, ExecStreamRequestStart,
|
||||||
|
ExecStreamRequestStdin, ExecStreamRequestUpdate, MetricsRequest, Request as IdmRequest,
|
||||||
},
|
},
|
||||||
v1::{
|
v1::{
|
||||||
common::{Guest, GuestState, GuestStatus, OciImageFormat},
|
common::{Guest, GuestState, GuestStatus, OciImageFormat},
|
||||||
control::{
|
control::{
|
||||||
control_service_server::ControlService, ConsoleDataReply, ConsoleDataRequest,
|
control_service_server::ControlService, ConsoleDataReply, ConsoleDataRequest,
|
||||||
CreateGuestReply, CreateGuestRequest, DestroyGuestReply, DestroyGuestRequest,
|
CreateGuestReply, CreateGuestRequest, DestroyGuestReply, DestroyGuestRequest,
|
||||||
IdentifyHostReply, IdentifyHostRequest, ListGuestsReply, ListGuestsRequest,
|
ExecGuestReply, ExecGuestRequest, IdentifyHostReply, IdentifyHostRequest,
|
||||||
PullImageReply, PullImageRequest, ReadGuestMetricsReply, ReadGuestMetricsRequest,
|
ListGuestsReply, ListGuestsRequest, PullImageReply, PullImageRequest,
|
||||||
ResolveGuestReply, ResolveGuestRequest, SnoopIdmReply, SnoopIdmRequest,
|
ReadGuestMetricsReply, ReadGuestMetricsRequest, ResolveGuestReply, ResolveGuestRequest,
|
||||||
WatchEventsReply, WatchEventsRequest,
|
SnoopIdmReply, SnoopIdmRequest, WatchEventsReply, WatchEventsRequest,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -101,6 +102,9 @@ enum PullImageSelect {
|
|||||||
|
|
||||||
#[tonic::async_trait]
|
#[tonic::async_trait]
|
||||||
impl ControlService for DaemonControlService {
|
impl ControlService for DaemonControlService {
|
||||||
|
type ExecGuestStream =
|
||||||
|
Pin<Box<dyn Stream<Item = Result<ExecGuestReply, Status>> + Send + 'static>>;
|
||||||
|
|
||||||
type ConsoleDataStream =
|
type ConsoleDataStream =
|
||||||
Pin<Box<dyn Stream<Item = Result<ConsoleDataReply, Status>> + Send + 'static>>;
|
Pin<Box<dyn Stream<Item = Result<ConsoleDataReply, Status>> + Send + 'static>>;
|
||||||
|
|
||||||
@ -166,6 +170,98 @@ impl ControlService for DaemonControlService {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn exec_guest(
|
||||||
|
&self,
|
||||||
|
request: Request<Streaming<ExecGuestRequest>>,
|
||||||
|
) -> Result<Response<Self::ExecGuestStream>, Status> {
|
||||||
|
let mut input = request.into_inner();
|
||||||
|
let Some(request) = input.next().await else {
|
||||||
|
return Err(ApiError {
|
||||||
|
message: "expected to have at least one request".to_string(),
|
||||||
|
}
|
||||||
|
.into());
|
||||||
|
};
|
||||||
|
let request = request?;
|
||||||
|
|
||||||
|
let Some(task) = request.task else {
|
||||||
|
return Err(ApiError {
|
||||||
|
message: "task is missing".to_string(),
|
||||||
|
}
|
||||||
|
.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
let uuid = Uuid::from_str(&request.guest_id).map_err(|error| ApiError {
|
||||||
|
message: error.to_string(),
|
||||||
|
})?;
|
||||||
|
let idm = self.idm.client(uuid).await.map_err(|error| ApiError {
|
||||||
|
message: error.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let idm_request = IdmRequest {
|
||||||
|
request: Some(IdmRequestType::ExecStream(ExecStreamRequestUpdate {
|
||||||
|
update: Some(Update::Start(ExecStreamRequestStart {
|
||||||
|
environment: task
|
||||||
|
.environment
|
||||||
|
.into_iter()
|
||||||
|
.map(|x| ExecEnvVar {
|
||||||
|
key: x.key,
|
||||||
|
value: x.value,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
command: task.command,
|
||||||
|
working_directory: task.working_directory,
|
||||||
|
})),
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = try_stream! {
|
||||||
|
let mut handle = idm.send_stream(idm_request).await.map_err(|x| ApiError {
|
||||||
|
message: x.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
select! {
|
||||||
|
x = input.next() => if let Some(update) = x {
|
||||||
|
let update: Result<ExecGuestRequest, Status> = update.map_err(|error| ApiError {
|
||||||
|
message: error.to_string()
|
||||||
|
}.into());
|
||||||
|
|
||||||
|
if let Ok(update) = update {
|
||||||
|
if !update.data.is_empty() {
|
||||||
|
let _ = handle.update(IdmRequest {
|
||||||
|
request: Some(IdmRequestType::ExecStream(ExecStreamRequestUpdate {
|
||||||
|
update: Some(Update::Stdin(ExecStreamRequestStdin {
|
||||||
|
data: update.data,
|
||||||
|
})),
|
||||||
|
}))}).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
x = handle.receiver.recv() => match x {
|
||||||
|
Some(response) => {
|
||||||
|
let Some(IdmResponseType::ExecStream(update)) = response.response else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
let reply = ExecGuestReply {
|
||||||
|
exited: update.exited,
|
||||||
|
error: update.error,
|
||||||
|
exit_code: update.exit_code,
|
||||||
|
stdout: update.stdout,
|
||||||
|
stderr: update.stderr
|
||||||
|
};
|
||||||
|
yield reply;
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Response::new(Box::pin(output) as Self::ExecGuestStream))
|
||||||
|
}
|
||||||
|
|
||||||
async fn destroy_guest(
|
async fn destroy_guest(
|
||||||
&self,
|
&self,
|
||||||
request: Request<DestroyGuestRequest>,
|
request: Request<DestroyGuestRequest>,
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
childwait::{ChildEvent, ChildWait},
|
childwait::{ChildEvent, ChildWait},
|
||||||
death,
|
death,
|
||||||
|
exec::GuestExecTask,
|
||||||
metrics::MetricsCollector,
|
metrics::MetricsCollector,
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use cgroups_rs::Cgroup;
|
use cgroups_rs::Cgroup;
|
||||||
use krata::idm::{
|
use krata::idm::{
|
||||||
client::IdmInternalClient,
|
client::{IdmClientStreamResponseHandle, IdmInternalClient},
|
||||||
internal::{
|
internal::{
|
||||||
event::Event as EventType, request::Request as RequestType,
|
event::Event as EventType, request::Request as RequestType,
|
||||||
response::Response as ResponseType, Event, ExitEvent, MetricsResponse, PingResponse,
|
response::Response as ResponseType, Event, ExecStreamResponseUpdate, ExitEvent,
|
||||||
Request, Response,
|
MetricsResponse, PingResponse, Request, Response,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use log::debug;
|
use log::debug;
|
||||||
@ -41,11 +42,11 @@ impl GuestBackground {
|
|||||||
pub async fn run(&mut self) -> Result<()> {
|
pub async fn run(&mut self) -> Result<()> {
|
||||||
let mut event_subscription = self.idm.subscribe().await?;
|
let mut event_subscription = self.idm.subscribe().await?;
|
||||||
let mut requests_subscription = self.idm.requests().await?;
|
let mut requests_subscription = self.idm.requests().await?;
|
||||||
|
let mut request_streams_subscription = self.idm.request_streams().await?;
|
||||||
loop {
|
loop {
|
||||||
select! {
|
select! {
|
||||||
x = event_subscription.recv() => match x {
|
x = event_subscription.recv() => match x {
|
||||||
Ok(_event) => {
|
Ok(_event) => {
|
||||||
|
|
||||||
},
|
},
|
||||||
|
|
||||||
Err(broadcast::error::RecvError::Closed) => {
|
Err(broadcast::error::RecvError::Closed) => {
|
||||||
@ -73,6 +74,21 @@ impl GuestBackground {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
x = request_streams_subscription.recv() => match x {
|
||||||
|
Ok(handle) => {
|
||||||
|
self.handle_idm_stream_request(handle).await?;
|
||||||
|
},
|
||||||
|
|
||||||
|
Err(broadcast::error::RecvError::Closed) => {
|
||||||
|
debug!("idm packet channel closed");
|
||||||
|
break;
|
||||||
|
},
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
event = self.wait.recv() => match event {
|
event = self.wait.recv() => match event {
|
||||||
Some(event) => self.child_event(event).await?,
|
Some(event) => self.child_event(event).await?,
|
||||||
None => {
|
None => {
|
||||||
@ -107,7 +123,33 @@ impl GuestBackground {
|
|||||||
self.idm.respond(id, response).await?;
|
self.idm.respond(id, response).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
None => {}
|
_ => {}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_idm_stream_request(
|
||||||
|
&mut self,
|
||||||
|
handle: IdmClientStreamResponseHandle<Request>,
|
||||||
|
) -> Result<()> {
|
||||||
|
if let Some(RequestType::ExecStream(_)) = &handle.initial.request {
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let exec = GuestExecTask { handle };
|
||||||
|
if let Err(error) = exec.run().await {
|
||||||
|
let _ = exec
|
||||||
|
.handle
|
||||||
|
.respond(Response {
|
||||||
|
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
|
||||||
|
exited: true,
|
||||||
|
error: error.to_string(),
|
||||||
|
exit_code: -1,
|
||||||
|
stdout: vec![],
|
||||||
|
stderr: vec![],
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
172
crates/guest/src/exec.rs
Normal file
172
crates/guest/src/exec.rs
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
use std::{collections::HashMap, process::Stdio};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use krata::idm::{
|
||||||
|
client::IdmClientStreamResponseHandle,
|
||||||
|
internal::{
|
||||||
|
exec_stream_request_update::Update, request::Request as RequestType,
|
||||||
|
ExecStreamResponseUpdate,
|
||||||
|
},
|
||||||
|
internal::{response::Response as ResponseType, Request, Response},
|
||||||
|
};
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncReadExt, AsyncWriteExt},
|
||||||
|
join,
|
||||||
|
process::Command,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct GuestExecTask {
|
||||||
|
pub handle: IdmClientStreamResponseHandle<Request>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GuestExecTask {
|
||||||
|
pub async fn run(&self) -> Result<()> {
|
||||||
|
let mut receiver = self.handle.take().await?;
|
||||||
|
|
||||||
|
let Some(ref request) = self.handle.initial.request else {
|
||||||
|
return Err(anyhow!("request was empty"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let RequestType::ExecStream(update) = request else {
|
||||||
|
return Err(anyhow!("request was not an exec update"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(Update::Start(ref start)) = update.update else {
|
||||||
|
return Err(anyhow!("first request did not contain a start update"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut cmd = start.command.clone();
|
||||||
|
if cmd.is_empty() {
|
||||||
|
return Err(anyhow!("command line was empty"));
|
||||||
|
}
|
||||||
|
let exe = cmd.remove(0);
|
||||||
|
let mut env = HashMap::new();
|
||||||
|
for entry in &start.environment {
|
||||||
|
env.insert(entry.key.clone(), entry.value.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !env.contains_key("PATH") {
|
||||||
|
env.insert(
|
||||||
|
"PATH".to_string(),
|
||||||
|
"/bin:/usr/bin:/usr/local/bin".to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dir = if start.working_directory.is_empty() {
|
||||||
|
"/".to_string()
|
||||||
|
} else {
|
||||||
|
start.working_directory.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut child = Command::new(exe)
|
||||||
|
.args(cmd)
|
||||||
|
.envs(env)
|
||||||
|
.current_dir(dir)
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
.stderr(Stdio::piped())
|
||||||
|
.kill_on_drop(true)
|
||||||
|
.spawn()
|
||||||
|
.map_err(|error| anyhow!("failed to spawn: {}", error))?;
|
||||||
|
|
||||||
|
let mut stdin = child
|
||||||
|
.stdin
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow!("stdin was missing"))?;
|
||||||
|
let mut stdout = child
|
||||||
|
.stdout
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow!("stdout was missing"))?;
|
||||||
|
let mut stderr = child
|
||||||
|
.stderr
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| anyhow!("stderr was missing"))?;
|
||||||
|
|
||||||
|
let stdout_handle = self.handle.clone();
|
||||||
|
let stdout_task = tokio::task::spawn(async move {
|
||||||
|
let mut stdout_buffer = vec![0u8; 8 * 1024];
|
||||||
|
loop {
|
||||||
|
let Ok(size) = stdout.read(&mut stdout_buffer).await else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
if size > 0 {
|
||||||
|
let response = Response {
|
||||||
|
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
|
||||||
|
exited: false,
|
||||||
|
exit_code: 0,
|
||||||
|
error: String::new(),
|
||||||
|
stdout: stdout_buffer[0..size].to_vec(),
|
||||||
|
stderr: vec![],
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
let _ = stdout_handle.respond(response).await;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stderr_handle = self.handle.clone();
|
||||||
|
let stderr_task = tokio::task::spawn(async move {
|
||||||
|
let mut stderr_buffer = vec![0u8; 8 * 1024];
|
||||||
|
loop {
|
||||||
|
let Ok(size) = stderr.read(&mut stderr_buffer).await else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
if size > 0 {
|
||||||
|
let response = Response {
|
||||||
|
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
|
||||||
|
exited: false,
|
||||||
|
exit_code: 0,
|
||||||
|
error: String::new(),
|
||||||
|
stdout: vec![],
|
||||||
|
stderr: stderr_buffer[0..size].to_vec(),
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
let _ = stderr_handle.respond(response).await;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stdin_task = tokio::task::spawn(async move {
|
||||||
|
loop {
|
||||||
|
let Some(request) = receiver.recv().await else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(RequestType::ExecStream(update)) = request.request else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(Update::Stdin(update)) = update.update else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if stdin.write_all(&update.data).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let exit = child.wait().await?;
|
||||||
|
let code = exit.code().unwrap_or(-1);
|
||||||
|
|
||||||
|
let _ = join!(stdout_task, stderr_task);
|
||||||
|
stdin_task.abort();
|
||||||
|
|
||||||
|
let response = Response {
|
||||||
|
response: Some(ResponseType::ExecStream(ExecStreamResponseUpdate {
|
||||||
|
exited: true,
|
||||||
|
exit_code: code,
|
||||||
|
error: String::new(),
|
||||||
|
stdout: vec![],
|
||||||
|
stderr: vec![],
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
self.handle.respond(response).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
@ -479,7 +479,7 @@ impl GuestInit {
|
|||||||
env.insert("TERM".to_string(), "xterm".to_string());
|
env.insert("TERM".to_string(), "xterm".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
let path = GuestInit::resolve_executable(&env, path.into())?;
|
let path = resolve_executable(&env, path.into())?;
|
||||||
let Some(file_name) = path.file_name() else {
|
let Some(file_name) = path.file_name() else {
|
||||||
return Err(anyhow!("cannot get file name of command path"));
|
return Err(anyhow!("cannot get file name of command path"));
|
||||||
};
|
};
|
||||||
@ -537,27 +537,6 @@ impl GuestInit {
|
|||||||
map
|
map
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_executable(env: &HashMap<String, String>, path: PathBuf) -> Result<PathBuf> {
|
|
||||||
if path.is_absolute() {
|
|
||||||
return Ok(path);
|
|
||||||
}
|
|
||||||
|
|
||||||
if path.is_file() {
|
|
||||||
return Ok(path.absolutize()?.to_path_buf());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(path_var) = env.get("PATH") {
|
|
||||||
for item in path_var.split(':') {
|
|
||||||
let mut exe_path: PathBuf = item.into();
|
|
||||||
exe_path.push(&path);
|
|
||||||
if exe_path.is_file() {
|
|
||||||
return Ok(exe_path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn env_list(env: HashMap<String, String>) -> Vec<String> {
|
fn env_list(env: HashMap<String, String>) -> Vec<String> {
|
||||||
env.iter()
|
env.iter()
|
||||||
.map(|(key, value)| format!("{}={}", key, value))
|
.map(|(key, value)| format!("{}={}", key, value))
|
||||||
@ -613,3 +592,24 @@ impl GuestInit {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn resolve_executable(env: &HashMap<String, String>, path: PathBuf) -> Result<PathBuf> {
|
||||||
|
if path.is_absolute() {
|
||||||
|
return Ok(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
if path.is_file() {
|
||||||
|
return Ok(path.absolutize()?.to_path_buf());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path_var) = env.get("PATH") {
|
||||||
|
for item in path_var.split(':') {
|
||||||
|
let mut exe_path: PathBuf = item.into();
|
||||||
|
exe_path.push(&path);
|
||||||
|
if exe_path.is_file() {
|
||||||
|
return Ok(exe_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(path)
|
||||||
|
}
|
||||||
|
@ -6,6 +6,7 @@ use xenstore::{XsdClient, XsdInterface};
|
|||||||
|
|
||||||
pub mod background;
|
pub mod background;
|
||||||
pub mod childwait;
|
pub mod childwait;
|
||||||
|
pub mod exec;
|
||||||
pub mod init;
|
pub mod init;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
|
|
||||||
|
@ -36,6 +36,36 @@ enum MetricFormat {
|
|||||||
METRIC_FORMAT_DURATION_SECONDS = 3;
|
METRIC_FORMAT_DURATION_SECONDS = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message ExecEnvVar {
|
||||||
|
string key = 1;
|
||||||
|
string value = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExecStreamRequestStart {
|
||||||
|
repeated ExecEnvVar environment = 1;
|
||||||
|
repeated string command = 2;
|
||||||
|
string working_directory = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExecStreamRequestStdin {
|
||||||
|
bytes data = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExecStreamRequestUpdate {
|
||||||
|
oneof update {
|
||||||
|
ExecStreamRequestStart start = 1;
|
||||||
|
ExecStreamRequestStdin stdin = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExecStreamResponseUpdate {
|
||||||
|
bool exited = 1;
|
||||||
|
string error = 2;
|
||||||
|
int32 exit_code = 3;
|
||||||
|
bytes stdout = 4;
|
||||||
|
bytes stderr = 5;
|
||||||
|
}
|
||||||
|
|
||||||
message Event {
|
message Event {
|
||||||
oneof event {
|
oneof event {
|
||||||
ExitEvent exit = 1;
|
ExitEvent exit = 1;
|
||||||
@ -46,6 +76,7 @@ message Request {
|
|||||||
oneof request {
|
oneof request {
|
||||||
PingRequest ping = 1;
|
PingRequest ping = 1;
|
||||||
MetricsRequest metrics = 2;
|
MetricsRequest metrics = 2;
|
||||||
|
ExecStreamRequestUpdate exec_stream = 3;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,5 +84,6 @@ message Response {
|
|||||||
oneof response {
|
oneof response {
|
||||||
PingResponse ping = 1;
|
PingResponse ping = 1;
|
||||||
MetricsResponse metrics = 2;
|
MetricsResponse metrics = 2;
|
||||||
|
ExecStreamResponseUpdate exec_stream = 3;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,4 +19,9 @@ enum IdmTransportPacketForm {
|
|||||||
IDM_TRANSPORT_PACKET_FORM_EVENT = 2;
|
IDM_TRANSPORT_PACKET_FORM_EVENT = 2;
|
||||||
IDM_TRANSPORT_PACKET_FORM_REQUEST = 3;
|
IDM_TRANSPORT_PACKET_FORM_REQUEST = 3;
|
||||||
IDM_TRANSPORT_PACKET_FORM_RESPONSE = 4;
|
IDM_TRANSPORT_PACKET_FORM_RESPONSE = 4;
|
||||||
|
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST = 5;
|
||||||
|
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_UPDATE = 6;
|
||||||
|
IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_UPDATE = 7;
|
||||||
|
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_CLOSED = 8;
|
||||||
|
IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_CLOSED = 9;
|
||||||
}
|
}
|
||||||
|
@ -49,6 +49,7 @@ message GuestOciImageSpec {
|
|||||||
message GuestTaskSpec {
|
message GuestTaskSpec {
|
||||||
repeated GuestTaskSpecEnvVar environment = 1;
|
repeated GuestTaskSpecEnvVar environment = 1;
|
||||||
repeated string command = 2;
|
repeated string command = 2;
|
||||||
|
string working_directory = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GuestTaskSpecEnvVar {
|
message GuestTaskSpecEnvVar {
|
||||||
|
@ -17,6 +17,8 @@ service ControlService {
|
|||||||
rpc ResolveGuest(ResolveGuestRequest) returns (ResolveGuestReply);
|
rpc ResolveGuest(ResolveGuestRequest) returns (ResolveGuestReply);
|
||||||
rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply);
|
rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply);
|
||||||
|
|
||||||
|
rpc ExecGuest(stream ExecGuestRequest) returns (stream ExecGuestReply);
|
||||||
|
|
||||||
rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply);
|
rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply);
|
||||||
rpc ReadGuestMetrics(ReadGuestMetricsRequest) returns (ReadGuestMetricsReply);
|
rpc ReadGuestMetrics(ReadGuestMetricsRequest) returns (ReadGuestMetricsReply);
|
||||||
|
|
||||||
@ -62,6 +64,20 @@ message ListGuestsReply {
|
|||||||
repeated krata.v1.common.Guest guests = 1;
|
repeated krata.v1.common.Guest guests = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message ExecGuestRequest {
|
||||||
|
string guest_id = 1;
|
||||||
|
krata.v1.common.GuestTaskSpec task = 2;
|
||||||
|
bytes data = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExecGuestReply {
|
||||||
|
bool exited = 1;
|
||||||
|
string error = 2;
|
||||||
|
int32 exit_code = 3;
|
||||||
|
bytes stdout = 4;
|
||||||
|
bytes stderr = 5;
|
||||||
|
}
|
||||||
|
|
||||||
message ConsoleDataRequest {
|
message ConsoleDataRequest {
|
||||||
string guest_id = 1;
|
string guest_id = 1;
|
||||||
bytes data = 2;
|
bytes data = 2;
|
||||||
|
@ -31,7 +31,9 @@ use super::{
|
|||||||
transport::{IdmTransportPacket, IdmTransportPacketForm},
|
transport::{IdmTransportPacket, IdmTransportPacketForm},
|
||||||
};
|
};
|
||||||
|
|
||||||
type RequestMap<R> = Arc<Mutex<HashMap<u64, oneshot::Sender<<R as IdmRequest>::Response>>>>;
|
type OneshotRequestMap<R> = Arc<Mutex<HashMap<u64, oneshot::Sender<<R as IdmRequest>::Response>>>>;
|
||||||
|
type StreamRequestMap<R> = Arc<Mutex<HashMap<u64, Sender<<R as IdmRequest>::Response>>>>;
|
||||||
|
type StreamRequestUpdateMap<R> = Arc<Mutex<HashMap<u64, mpsc::Sender<R>>>>;
|
||||||
pub type IdmInternalClient = IdmClient<internal::Request, internal::Event>;
|
pub type IdmInternalClient = IdmClient<internal::Request, internal::Event>;
|
||||||
|
|
||||||
const IDM_PACKET_QUEUE_LEN: usize = 100;
|
const IDM_PACKET_QUEUE_LEN: usize = 100;
|
||||||
@ -106,10 +108,12 @@ impl IdmBackend for IdmFileBackend {
|
|||||||
pub struct IdmClient<R: IdmRequest, E: IdmSerializable> {
|
pub struct IdmClient<R: IdmRequest, E: IdmSerializable> {
|
||||||
channel: u64,
|
channel: u64,
|
||||||
request_backend_sender: broadcast::Sender<(u64, R)>,
|
request_backend_sender: broadcast::Sender<(u64, R)>,
|
||||||
|
request_stream_backend_sender: broadcast::Sender<IdmClientStreamResponseHandle<R>>,
|
||||||
next_request_id: Arc<Mutex<u64>>,
|
next_request_id: Arc<Mutex<u64>>,
|
||||||
event_receiver_sender: broadcast::Sender<E>,
|
event_receiver_sender: broadcast::Sender<E>,
|
||||||
tx_sender: Sender<IdmTransportPacket>,
|
tx_sender: Sender<IdmTransportPacket>,
|
||||||
requests: RequestMap<R>,
|
requests: OneshotRequestMap<R>,
|
||||||
|
request_streams: StreamRequestMap<R>,
|
||||||
task: Arc<JoinHandle<()>>,
|
task: Arc<JoinHandle<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,21 +125,122 @@ impl<R: IdmRequest, E: IdmSerializable> Drop for IdmClient<R, E> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct IdmClientStreamRequestHandle<R: IdmRequest, E: IdmSerializable> {
|
||||||
|
pub id: u64,
|
||||||
|
pub receiver: Receiver<R::Response>,
|
||||||
|
pub client: IdmClient<R, E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: IdmRequest, E: IdmSerializable> IdmClientStreamRequestHandle<R, E> {
|
||||||
|
pub async fn update(&self, request: R) -> Result<()> {
|
||||||
|
self.client
|
||||||
|
.tx_sender
|
||||||
|
.send(IdmTransportPacket {
|
||||||
|
id: self.id,
|
||||||
|
channel: self.client.channel,
|
||||||
|
form: IdmTransportPacketForm::StreamRequestUpdate.into(),
|
||||||
|
data: request.encode()?,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: IdmRequest, E: IdmSerializable> Drop for IdmClientStreamRequestHandle<R, E> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let id = self.id;
|
||||||
|
let client = self.client.clone();
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let _ = client
|
||||||
|
.tx_sender
|
||||||
|
.send(IdmTransportPacket {
|
||||||
|
id,
|
||||||
|
channel: client.channel,
|
||||||
|
form: IdmTransportPacketForm::StreamRequestClosed.into(),
|
||||||
|
data: vec![],
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct IdmClientStreamResponseHandle<R: IdmRequest> {
|
||||||
|
pub initial: R,
|
||||||
|
pub id: u64,
|
||||||
|
channel: u64,
|
||||||
|
tx_sender: Sender<IdmTransportPacket>,
|
||||||
|
receiver: Arc<Mutex<Option<Receiver<R>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: IdmRequest> IdmClientStreamResponseHandle<R> {
|
||||||
|
pub async fn respond(&self, response: R::Response) -> Result<()> {
|
||||||
|
self.tx_sender
|
||||||
|
.send(IdmTransportPacket {
|
||||||
|
id: self.id,
|
||||||
|
channel: self.channel,
|
||||||
|
form: IdmTransportPacketForm::StreamResponseUpdate.into(),
|
||||||
|
data: response.encode()?,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn take(&self) -> Result<Receiver<R>> {
|
||||||
|
let mut guard = self.receiver.lock().await;
|
||||||
|
let Some(receiver) = (*guard).take() else {
|
||||||
|
return Err(anyhow!("request has already been claimed!"));
|
||||||
|
};
|
||||||
|
Ok(receiver)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: IdmRequest> Drop for IdmClientStreamResponseHandle<R> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if Arc::strong_count(&self.receiver) <= 1 {
|
||||||
|
let id = self.id;
|
||||||
|
let channel = self.channel;
|
||||||
|
let tx_sender = self.tx_sender.clone();
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let _ = tx_sender
|
||||||
|
.send(IdmTransportPacket {
|
||||||
|
id,
|
||||||
|
channel,
|
||||||
|
form: IdmTransportPacketForm::StreamResponseClosed.into(),
|
||||||
|
data: vec![],
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
||||||
pub async fn new(channel: u64, backend: Box<dyn IdmBackend>) -> Result<Self> {
|
pub async fn new(channel: u64, backend: Box<dyn IdmBackend>) -> Result<Self> {
|
||||||
let requests = Arc::new(Mutex::new(HashMap::new()));
|
let requests = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
let request_streams = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
let request_update_streams = Arc::new(Mutex::new(HashMap::new()));
|
||||||
let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
|
let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
|
||||||
let (internal_request_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
|
let (internal_request_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
|
||||||
|
let (internal_request_stream_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
|
||||||
let (tx_sender, tx_receiver) = mpsc::channel(IDM_PACKET_QUEUE_LEN);
|
let (tx_sender, tx_receiver) = mpsc::channel(IDM_PACKET_QUEUE_LEN);
|
||||||
let backend_event_sender = event_sender.clone();
|
let backend_event_sender = event_sender.clone();
|
||||||
let request_backend_sender = internal_request_backend_sender.clone();
|
let request_backend_sender = internal_request_backend_sender.clone();
|
||||||
|
let request_stream_backend_sender = internal_request_stream_backend_sender.clone();
|
||||||
let requests_for_client = requests.clone();
|
let requests_for_client = requests.clone();
|
||||||
|
let request_streams_for_client = request_streams.clone();
|
||||||
|
let tx_sender_for_client = tx_sender.clone();
|
||||||
let task = tokio::task::spawn(async move {
|
let task = tokio::task::spawn(async move {
|
||||||
if let Err(error) = IdmClient::process(
|
if let Err(error) = IdmClient::process(
|
||||||
backend,
|
backend,
|
||||||
|
channel,
|
||||||
|
tx_sender,
|
||||||
backend_event_sender,
|
backend_event_sender,
|
||||||
requests,
|
requests,
|
||||||
|
request_streams,
|
||||||
|
request_update_streams,
|
||||||
internal_request_backend_sender,
|
internal_request_backend_sender,
|
||||||
|
internal_request_stream_backend_sender,
|
||||||
event_receiver,
|
event_receiver,
|
||||||
tx_receiver,
|
tx_receiver,
|
||||||
)
|
)
|
||||||
@ -149,8 +254,10 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
|||||||
next_request_id: Arc::new(Mutex::new(0)),
|
next_request_id: Arc::new(Mutex::new(0)),
|
||||||
event_receiver_sender: event_sender.clone(),
|
event_receiver_sender: event_sender.clone(),
|
||||||
request_backend_sender,
|
request_backend_sender,
|
||||||
|
request_stream_backend_sender,
|
||||||
requests: requests_for_client,
|
requests: requests_for_client,
|
||||||
tx_sender,
|
request_streams: request_streams_for_client,
|
||||||
|
tx_sender: tx_sender_for_client,
|
||||||
task: Arc::new(task),
|
task: Arc::new(task),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -194,6 +301,12 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
|||||||
Ok(self.request_backend_sender.subscribe())
|
Ok(self.request_backend_sender.subscribe())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn request_streams(
|
||||||
|
&self,
|
||||||
|
) -> Result<broadcast::Receiver<IdmClientStreamResponseHandle<R>>> {
|
||||||
|
Ok(self.request_stream_backend_sender.subscribe())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn respond<T: IdmSerializable>(&self, id: u64, response: T) -> Result<()> {
|
pub async fn respond<T: IdmSerializable>(&self, id: u64, response: T) -> Result<()> {
|
||||||
let packet = IdmTransportPacket {
|
let packet = IdmTransportPacket {
|
||||||
id,
|
id,
|
||||||
@ -244,11 +357,43 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
|||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn send_stream(&self, request: R) -> Result<IdmClientStreamRequestHandle<R, E>> {
|
||||||
|
let (sender, receiver) = mpsc::channel::<R::Response>(100);
|
||||||
|
let req = {
|
||||||
|
let mut guard = self.next_request_id.lock().await;
|
||||||
|
let req = *guard;
|
||||||
|
*guard = req.wrapping_add(1);
|
||||||
|
req
|
||||||
|
};
|
||||||
|
let mut requests = self.request_streams.lock().await;
|
||||||
|
requests.insert(req, sender);
|
||||||
|
drop(requests);
|
||||||
|
self.tx_sender
|
||||||
|
.send(IdmTransportPacket {
|
||||||
|
id: req,
|
||||||
|
channel: self.channel,
|
||||||
|
form: IdmTransportPacketForm::StreamRequest.into(),
|
||||||
|
data: request.encode()?,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(IdmClientStreamRequestHandle {
|
||||||
|
id: req,
|
||||||
|
receiver,
|
||||||
|
client: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn process(
|
async fn process(
|
||||||
mut backend: Box<dyn IdmBackend>,
|
mut backend: Box<dyn IdmBackend>,
|
||||||
|
channel: u64,
|
||||||
|
tx_sender: Sender<IdmTransportPacket>,
|
||||||
event_sender: broadcast::Sender<E>,
|
event_sender: broadcast::Sender<E>,
|
||||||
requests: RequestMap<R>,
|
requests: OneshotRequestMap<R>,
|
||||||
|
request_streams: StreamRequestMap<R>,
|
||||||
|
request_update_streams: StreamRequestUpdateMap<R>,
|
||||||
request_backend_sender: broadcast::Sender<(u64, R)>,
|
request_backend_sender: broadcast::Sender<(u64, R)>,
|
||||||
|
request_stream_backend_sender: broadcast::Sender<IdmClientStreamResponseHandle<R>>,
|
||||||
_event_receiver: broadcast::Receiver<E>,
|
_event_receiver: broadcast::Receiver<E>,
|
||||||
mut receiver: Receiver<IdmTransportPacket>,
|
mut receiver: Receiver<IdmTransportPacket>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
@ -256,6 +401,10 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
|||||||
select! {
|
select! {
|
||||||
x = backend.recv() => match x {
|
x = backend.recv() => match x {
|
||||||
Ok(packet) => {
|
Ok(packet) => {
|
||||||
|
if packet.channel != channel {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
match packet.form() {
|
match packet.form() {
|
||||||
IdmTransportPacketForm::Event => {
|
IdmTransportPacketForm::Event => {
|
||||||
if let Ok(event) = E::decode(&packet.data) {
|
if let Ok(event) = E::decode(&packet.data) {
|
||||||
@ -280,6 +429,50 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
IdmTransportPacketForm::StreamRequest => {
|
||||||
|
if let Ok(request) = R::decode(&packet.data) {
|
||||||
|
let mut update_streams = request_update_streams.lock().await;
|
||||||
|
let (sender, receiver) = mpsc::channel(100);
|
||||||
|
update_streams.insert(packet.id, sender.clone());
|
||||||
|
let handle = IdmClientStreamResponseHandle {
|
||||||
|
initial: request,
|
||||||
|
id: packet.id,
|
||||||
|
channel,
|
||||||
|
tx_sender: tx_sender.clone(),
|
||||||
|
receiver: Arc::new(Mutex::new(Some(receiver))),
|
||||||
|
};
|
||||||
|
let _ = request_stream_backend_sender.send(handle);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IdmTransportPacketForm::StreamRequestUpdate => {
|
||||||
|
if let Ok(request) = R::decode(&packet.data) {
|
||||||
|
let mut update_streams = request_update_streams.lock().await;
|
||||||
|
if let Some(stream) = update_streams.get_mut(&packet.id) {
|
||||||
|
let _ = stream.try_send(request);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IdmTransportPacketForm::StreamRequestClosed => {
|
||||||
|
let mut update_streams = request_update_streams.lock().await;
|
||||||
|
update_streams.remove(&packet.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
IdmTransportPacketForm::StreamResponseUpdate => {
|
||||||
|
let requests = request_streams.lock().await;
|
||||||
|
if let Some(sender) = requests.get(&packet.id) {
|
||||||
|
if let Ok(response) = R::Response::decode(&packet.data) {
|
||||||
|
let _ = sender.try_send(response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IdmTransportPacketForm::StreamResponseClosed => {
|
||||||
|
let mut requests = request_streams.lock().await;
|
||||||
|
requests.remove(&packet.id);
|
||||||
|
}
|
||||||
|
|
||||||
_ => {},
|
_ => {},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Loading…
Reference in New Issue
Block a user