diff --git a/crates/ctl/src/cli/zone/attach.rs b/crates/ctl/src/cli/zone/attach.rs index 990e74f..9b3b487 100644 --- a/crates/ctl/src/cli/zone/attach.rs +++ b/crates/ctl/src/cli/zone/attach.rs @@ -26,7 +26,7 @@ impl ZoneAttachCommand { let input = StdioConsoleStream::stdin_stream(zone_id.clone()).await; let output = client.attach_zone_console(input).await?.into_inner(); let stdout_handle = - tokio::task::spawn(async move { StdioConsoleStream::stdout(output).await }); + tokio::task::spawn(async move { StdioConsoleStream::stdout(output, true).await }); let exit_hook_task = StdioConsoleStream::zone_exit_hook(zone_id.clone(), events).await?; let code = select! { x = stdout_handle => { diff --git a/crates/ctl/src/cli/zone/launch.rs b/crates/ctl/src/cli/zone/launch.rs index 896a620..53ca8b2 100644 --- a/crates/ctl/src/cli/zone/launch.rs +++ b/crates/ctl/src/cli/zone/launch.rs @@ -155,7 +155,7 @@ impl ZoneLaunchCommand { let input = StdioConsoleStream::stdin_stream(id.clone()).await; let output = client.attach_zone_console(input).await?.into_inner(); let stdout_handle = - tokio::task::spawn(async move { StdioConsoleStream::stdout(output).await }); + tokio::task::spawn(async move { StdioConsoleStream::stdout(output, true).await }); let exit_hook_task = StdioConsoleStream::zone_exit_hook(id.clone(), events).await?; select! { x = stdout_handle => { diff --git a/crates/ctl/src/cli/zone/logs.rs b/crates/ctl/src/cli/zone/logs.rs index 820183b..9d8a98b 100644 --- a/crates/ctl/src/cli/zone/logs.rs +++ b/crates/ctl/src/cli/zone/logs.rs @@ -43,7 +43,7 @@ impl ZoneLogsCommand { }; let output = client.attach_zone_console(input).await?.into_inner(); let stdout_handle = - tokio::task::spawn(async move { StdioConsoleStream::stdout(output).await }); + tokio::task::spawn(async move { StdioConsoleStream::stdout(output, false).await }); let exit_hook_task = StdioConsoleStream::zone_exit_hook(zone_id.clone(), events).await?; let code = select! { x = stdout_handle => { diff --git a/crates/ctl/src/console.rs b/crates/ctl/src/console.rs index 8884c36..96282e0 100644 --- a/crates/ctl/src/console.rs +++ b/crates/ctl/src/console.rs @@ -73,8 +73,8 @@ impl StdioConsoleStream { } } - pub async fn stdout(mut stream: Streaming) -> Result<()> { - if stdin().is_tty() { + pub async fn stdout(mut stream: Streaming, raw: bool) -> Result<()> { + if raw && stdin().is_tty() { enable_raw_mode()?; StdioConsoleStream::register_terminal_restore_hook()?; } diff --git a/crates/daemon/src/idm.rs b/crates/daemon/src/idm.rs index f5d9016..94d70f6 100644 --- a/crates/daemon/src/idm.rs +++ b/crates/daemon/src/idm.rs @@ -254,9 +254,9 @@ pub struct IdmDaemonBackend { #[async_trait::async_trait] impl IdmBackend for IdmDaemonBackend { - async fn recv(&mut self) -> Result { + async fn recv(&mut self) -> Result> { if let Some(packet) = self.rx_receiver.recv().await { - Ok(packet) + Ok(vec![packet]) } else { Err(anyhow!("idm receive channel closed")) } diff --git a/crates/krata/src/idm/client.rs b/crates/krata/src/idm/client.rs index 6e34f59..8e240f4 100644 --- a/crates/krata/src/idm/client.rs +++ b/crates/krata/src/idm/client.rs @@ -9,13 +9,13 @@ use std::{ }; use anyhow::{anyhow, Result}; -use bytes::{BufMut, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error}; use nix::sys::termios::{cfmakeraw, tcgetattr, tcsetattr, SetArg}; use prost::Message; use tokio::{ fs::File, - io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWriteExt}, select, sync::{ broadcast, @@ -43,12 +43,13 @@ const IDM_PACKET_MAX_SIZE: usize = 20 * 1024 * 1024; #[async_trait::async_trait] pub trait IdmBackend: Send { - async fn recv(&mut self) -> Result; + async fn recv(&mut self) -> Result>; async fn send(&mut self, packet: IdmTransportPacket) -> Result<()>; } pub struct IdmFileBackend { - read_fd: Arc>>, + read: Arc>, + read_buffer: BytesMut, write: Arc>, } @@ -57,7 +58,8 @@ impl IdmFileBackend { IdmFileBackend::set_raw_port(&read_file)?; IdmFileBackend::set_raw_port(&write_file)?; Ok(IdmFileBackend { - read_fd: Arc::new(Mutex::new(AsyncFd::new(read_file)?)), + read: Arc::new(Mutex::new(read_file)), + read_buffer: BytesMut::new(), write: Arc::new(Mutex::new(write_file)), }) } @@ -72,26 +74,58 @@ impl IdmFileBackend { #[async_trait::async_trait] impl IdmBackend for IdmFileBackend { - async fn recv(&mut self) -> Result { - let mut fd = self.read_fd.lock().await; - let mut guard = fd.readable_mut().await?; - let b1 = guard.get_inner_mut().read_u8().await?; - if b1 != 0xff { - return Ok(IdmTransportPacket::default()); - } - let b2 = guard.get_inner_mut().read_u8().await?; - if b2 != 0xff { - return Ok(IdmTransportPacket::default()); - } - let size = guard.get_inner_mut().read_u32_le().await?; - if size == 0 { - return Ok(IdmTransportPacket::default()); - } - let mut buffer = vec![0u8; size as usize]; - guard.get_inner_mut().read_exact(&mut buffer).await?; - match IdmTransportPacket::decode(buffer.as_slice()) { - Ok(packet) => Ok(packet), - Err(error) => Err(anyhow!("received invalid idm packet: {}", error)), + async fn recv(&mut self) -> Result> { + let mut data = vec![0; 8192]; + let mut first = true; + 'read_more: loop { + let mut packets = Vec::new(); + if !first { + if !packets.is_empty() { + return Ok(packets); + } + let size = self.read.lock().await.read(&mut data).await?; + self.read_buffer.extend_from_slice(&data[0..size]); + } + first = false; + loop { + if self.read_buffer.len() < 6 { + continue 'read_more; + } + + let b1 = self.read_buffer[0]; + let b2 = self.read_buffer[1]; + + if b1 != 0xff || b2 != 0xff { + self.read_buffer.clear(); + continue 'read_more; + } + + let size = (self.read_buffer[2] as u32 + | (self.read_buffer[3] as u32) << 8 + | (self.read_buffer[4] as u32) << 16 + | (self.read_buffer[5] as u32) << 24) as usize; + let needed = size + 6; + if self.read_buffer.len() < needed { + continue 'read_more; + } + + let mut packet = self.read_buffer.split_to(needed); + packet.advance(6); + + match IdmTransportPacket::decode(packet) { + Ok(packet) => { + packets.push(packet); + } + Err(error) => { + return Err(anyhow!("received invalid idm packet: {}", error)); + } + } + + if self.read_buffer.is_empty() { + break; + } + } + return Ok(packets); } } @@ -403,8 +437,9 @@ impl IdmClient { loop { select! { x = backend.recv() => match x { - Ok(packet) => { - if packet.channel != channel { + Ok(packets) => { + for packet in packets { + if packet.channel != channel { continue; } @@ -478,6 +513,7 @@ impl IdmClient { _ => {}, } + } }, Err(error) => {