feat: implement guest exec (#107)

This commit is contained in:
Alex Zenla
2024-04-22 13:13:43 -07:00
committed by GitHub
parent 82576df7b7
commit 284ed8f17b
15 changed files with 755 additions and 53 deletions

View File

@ -36,6 +36,36 @@ enum MetricFormat {
METRIC_FORMAT_DURATION_SECONDS = 3;
}
message ExecEnvVar {
string key = 1;
string value = 2;
}
message ExecStreamRequestStart {
repeated ExecEnvVar environment = 1;
repeated string command = 2;
string working_directory = 3;
}
message ExecStreamRequestStdin {
bytes data = 1;
}
message ExecStreamRequestUpdate {
oneof update {
ExecStreamRequestStart start = 1;
ExecStreamRequestStdin stdin = 2;
}
}
message ExecStreamResponseUpdate {
bool exited = 1;
string error = 2;
int32 exit_code = 3;
bytes stdout = 4;
bytes stderr = 5;
}
message Event {
oneof event {
ExitEvent exit = 1;
@ -46,6 +76,7 @@ message Request {
oneof request {
PingRequest ping = 1;
MetricsRequest metrics = 2;
ExecStreamRequestUpdate exec_stream = 3;
}
}
@ -53,5 +84,6 @@ message Response {
oneof response {
PingResponse ping = 1;
MetricsResponse metrics = 2;
ExecStreamResponseUpdate exec_stream = 3;
}
}

View File

@ -19,4 +19,9 @@ enum IdmTransportPacketForm {
IDM_TRANSPORT_PACKET_FORM_EVENT = 2;
IDM_TRANSPORT_PACKET_FORM_REQUEST = 3;
IDM_TRANSPORT_PACKET_FORM_RESPONSE = 4;
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST = 5;
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_UPDATE = 6;
IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_UPDATE = 7;
IDM_TRANSPORT_PACKET_FORM_STREAM_REQUEST_CLOSED = 8;
IDM_TRANSPORT_PACKET_FORM_STREAM_RESPONSE_CLOSED = 9;
}

View File

@ -49,6 +49,7 @@ message GuestOciImageSpec {
message GuestTaskSpec {
repeated GuestTaskSpecEnvVar environment = 1;
repeated string command = 2;
string working_directory = 3;
}
message GuestTaskSpecEnvVar {

View File

@ -17,6 +17,8 @@ service ControlService {
rpc ResolveGuest(ResolveGuestRequest) returns (ResolveGuestReply);
rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply);
rpc ExecGuest(stream ExecGuestRequest) returns (stream ExecGuestReply);
rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply);
rpc ReadGuestMetrics(ReadGuestMetricsRequest) returns (ReadGuestMetricsReply);
@ -62,6 +64,20 @@ message ListGuestsReply {
repeated krata.v1.common.Guest guests = 1;
}
message ExecGuestRequest {
string guest_id = 1;
krata.v1.common.GuestTaskSpec task = 2;
bytes data = 3;
}
message ExecGuestReply {
bool exited = 1;
string error = 2;
int32 exit_code = 3;
bytes stdout = 4;
bytes stderr = 5;
}
message ConsoleDataRequest {
string guest_id = 1;
bytes data = 2;

View File

@ -31,7 +31,9 @@ use super::{
transport::{IdmTransportPacket, IdmTransportPacketForm},
};
type RequestMap<R> = Arc<Mutex<HashMap<u64, oneshot::Sender<<R as IdmRequest>::Response>>>>;
type OneshotRequestMap<R> = Arc<Mutex<HashMap<u64, oneshot::Sender<<R as IdmRequest>::Response>>>>;
type StreamRequestMap<R> = Arc<Mutex<HashMap<u64, Sender<<R as IdmRequest>::Response>>>>;
type StreamRequestUpdateMap<R> = Arc<Mutex<HashMap<u64, mpsc::Sender<R>>>>;
pub type IdmInternalClient = IdmClient<internal::Request, internal::Event>;
const IDM_PACKET_QUEUE_LEN: usize = 100;
@ -106,10 +108,12 @@ impl IdmBackend for IdmFileBackend {
pub struct IdmClient<R: IdmRequest, E: IdmSerializable> {
channel: u64,
request_backend_sender: broadcast::Sender<(u64, R)>,
request_stream_backend_sender: broadcast::Sender<IdmClientStreamResponseHandle<R>>,
next_request_id: Arc<Mutex<u64>>,
event_receiver_sender: broadcast::Sender<E>,
tx_sender: Sender<IdmTransportPacket>,
requests: RequestMap<R>,
requests: OneshotRequestMap<R>,
request_streams: StreamRequestMap<R>,
task: Arc<JoinHandle<()>>,
}
@ -121,21 +125,122 @@ impl<R: IdmRequest, E: IdmSerializable> Drop for IdmClient<R, E> {
}
}
pub struct IdmClientStreamRequestHandle<R: IdmRequest, E: IdmSerializable> {
pub id: u64,
pub receiver: Receiver<R::Response>,
pub client: IdmClient<R, E>,
}
impl<R: IdmRequest, E: IdmSerializable> IdmClientStreamRequestHandle<R, E> {
pub async fn update(&self, request: R) -> Result<()> {
self.client
.tx_sender
.send(IdmTransportPacket {
id: self.id,
channel: self.client.channel,
form: IdmTransportPacketForm::StreamRequestUpdate.into(),
data: request.encode()?,
})
.await?;
Ok(())
}
}
impl<R: IdmRequest, E: IdmSerializable> Drop for IdmClientStreamRequestHandle<R, E> {
fn drop(&mut self) {
let id = self.id;
let client = self.client.clone();
tokio::task::spawn(async move {
let _ = client
.tx_sender
.send(IdmTransportPacket {
id,
channel: client.channel,
form: IdmTransportPacketForm::StreamRequestClosed.into(),
data: vec![],
})
.await;
});
}
}
#[derive(Clone)]
pub struct IdmClientStreamResponseHandle<R: IdmRequest> {
pub initial: R,
pub id: u64,
channel: u64,
tx_sender: Sender<IdmTransportPacket>,
receiver: Arc<Mutex<Option<Receiver<R>>>>,
}
impl<R: IdmRequest> IdmClientStreamResponseHandle<R> {
pub async fn respond(&self, response: R::Response) -> Result<()> {
self.tx_sender
.send(IdmTransportPacket {
id: self.id,
channel: self.channel,
form: IdmTransportPacketForm::StreamResponseUpdate.into(),
data: response.encode()?,
})
.await?;
Ok(())
}
pub async fn take(&self) -> Result<Receiver<R>> {
let mut guard = self.receiver.lock().await;
let Some(receiver) = (*guard).take() else {
return Err(anyhow!("request has already been claimed!"));
};
Ok(receiver)
}
}
impl<R: IdmRequest> Drop for IdmClientStreamResponseHandle<R> {
fn drop(&mut self) {
if Arc::strong_count(&self.receiver) <= 1 {
let id = self.id;
let channel = self.channel;
let tx_sender = self.tx_sender.clone();
tokio::task::spawn(async move {
let _ = tx_sender
.send(IdmTransportPacket {
id,
channel,
form: IdmTransportPacketForm::StreamResponseClosed.into(),
data: vec![],
})
.await;
});
}
}
}
impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
pub async fn new(channel: u64, backend: Box<dyn IdmBackend>) -> Result<Self> {
let requests = Arc::new(Mutex::new(HashMap::new()));
let request_streams = Arc::new(Mutex::new(HashMap::new()));
let request_update_streams = Arc::new(Mutex::new(HashMap::new()));
let (event_sender, event_receiver) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
let (internal_request_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
let (internal_request_stream_backend_sender, _) = broadcast::channel(IDM_PACKET_QUEUE_LEN);
let (tx_sender, tx_receiver) = mpsc::channel(IDM_PACKET_QUEUE_LEN);
let backend_event_sender = event_sender.clone();
let request_backend_sender = internal_request_backend_sender.clone();
let request_stream_backend_sender = internal_request_stream_backend_sender.clone();
let requests_for_client = requests.clone();
let request_streams_for_client = request_streams.clone();
let tx_sender_for_client = tx_sender.clone();
let task = tokio::task::spawn(async move {
if let Err(error) = IdmClient::process(
backend,
channel,
tx_sender,
backend_event_sender,
requests,
request_streams,
request_update_streams,
internal_request_backend_sender,
internal_request_stream_backend_sender,
event_receiver,
tx_receiver,
)
@ -149,8 +254,10 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
next_request_id: Arc::new(Mutex::new(0)),
event_receiver_sender: event_sender.clone(),
request_backend_sender,
request_stream_backend_sender,
requests: requests_for_client,
tx_sender,
request_streams: request_streams_for_client,
tx_sender: tx_sender_for_client,
task: Arc::new(task),
})
}
@ -194,6 +301,12 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
Ok(self.request_backend_sender.subscribe())
}
pub async fn request_streams(
&self,
) -> Result<broadcast::Receiver<IdmClientStreamResponseHandle<R>>> {
Ok(self.request_stream_backend_sender.subscribe())
}
pub async fn respond<T: IdmSerializable>(&self, id: u64, response: T) -> Result<()> {
let packet = IdmTransportPacket {
id,
@ -244,11 +357,43 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
Ok(response)
}
pub async fn send_stream(&self, request: R) -> Result<IdmClientStreamRequestHandle<R, E>> {
let (sender, receiver) = mpsc::channel::<R::Response>(100);
let req = {
let mut guard = self.next_request_id.lock().await;
let req = *guard;
*guard = req.wrapping_add(1);
req
};
let mut requests = self.request_streams.lock().await;
requests.insert(req, sender);
drop(requests);
self.tx_sender
.send(IdmTransportPacket {
id: req,
channel: self.channel,
form: IdmTransportPacketForm::StreamRequest.into(),
data: request.encode()?,
})
.await?;
Ok(IdmClientStreamRequestHandle {
id: req,
receiver,
client: self.clone(),
})
}
#[allow(clippy::too_many_arguments)]
async fn process(
mut backend: Box<dyn IdmBackend>,
channel: u64,
tx_sender: Sender<IdmTransportPacket>,
event_sender: broadcast::Sender<E>,
requests: RequestMap<R>,
requests: OneshotRequestMap<R>,
request_streams: StreamRequestMap<R>,
request_update_streams: StreamRequestUpdateMap<R>,
request_backend_sender: broadcast::Sender<(u64, R)>,
request_stream_backend_sender: broadcast::Sender<IdmClientStreamResponseHandle<R>>,
_event_receiver: broadcast::Receiver<E>,
mut receiver: Receiver<IdmTransportPacket>,
) -> Result<()> {
@ -256,6 +401,10 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
select! {
x = backend.recv() => match x {
Ok(packet) => {
if packet.channel != channel {
continue;
}
match packet.form() {
IdmTransportPacketForm::Event => {
if let Ok(event) = E::decode(&packet.data) {
@ -280,6 +429,50 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
}
},
IdmTransportPacketForm::StreamRequest => {
if let Ok(request) = R::decode(&packet.data) {
let mut update_streams = request_update_streams.lock().await;
let (sender, receiver) = mpsc::channel(100);
update_streams.insert(packet.id, sender.clone());
let handle = IdmClientStreamResponseHandle {
initial: request,
id: packet.id,
channel,
tx_sender: tx_sender.clone(),
receiver: Arc::new(Mutex::new(Some(receiver))),
};
let _ = request_stream_backend_sender.send(handle);
}
}
IdmTransportPacketForm::StreamRequestUpdate => {
if let Ok(request) = R::decode(&packet.data) {
let mut update_streams = request_update_streams.lock().await;
if let Some(stream) = update_streams.get_mut(&packet.id) {
let _ = stream.try_send(request);
}
}
}
IdmTransportPacketForm::StreamRequestClosed => {
let mut update_streams = request_update_streams.lock().await;
update_streams.remove(&packet.id);
}
IdmTransportPacketForm::StreamResponseUpdate => {
let requests = request_streams.lock().await;
if let Some(sender) = requests.get(&packet.id) {
if let Ok(response) = R::Response::decode(&packet.data) {
let _ = sender.try_send(response);
}
}
}
IdmTransportPacketForm::StreamResponseClosed => {
let mut requests = request_streams.lock().await;
requests.remove(&packet.id);
}
_ => {},
}
},