fix(idm): repair idm bugs on the file backend

This commit is contained in:
Alex Zenla 2024-08-13 15:57:45 -07:00
parent 621ae536f6
commit 29ce7ef5e4
No known key found for this signature in database
GPG Key ID: 067B238899B51269
6 changed files with 70 additions and 34 deletions

View File

@ -26,7 +26,7 @@ impl ZoneAttachCommand {
let input = StdioConsoleStream::stdin_stream(zone_id.clone()).await; let input = StdioConsoleStream::stdin_stream(zone_id.clone()).await;
let output = client.attach_zone_console(input).await?.into_inner(); let output = client.attach_zone_console(input).await?.into_inner();
let stdout_handle = let stdout_handle =
tokio::task::spawn(async move { StdioConsoleStream::stdout(output).await }); tokio::task::spawn(async move { StdioConsoleStream::stdout(output, true).await });
let exit_hook_task = StdioConsoleStream::zone_exit_hook(zone_id.clone(), events).await?; let exit_hook_task = StdioConsoleStream::zone_exit_hook(zone_id.clone(), events).await?;
let code = select! { let code = select! {
x = stdout_handle => { x = stdout_handle => {

View File

@ -155,7 +155,7 @@ impl ZoneLaunchCommand {
let input = StdioConsoleStream::stdin_stream(id.clone()).await; let input = StdioConsoleStream::stdin_stream(id.clone()).await;
let output = client.attach_zone_console(input).await?.into_inner(); let output = client.attach_zone_console(input).await?.into_inner();
let stdout_handle = let stdout_handle =
tokio::task::spawn(async move { StdioConsoleStream::stdout(output).await }); tokio::task::spawn(async move { StdioConsoleStream::stdout(output, true).await });
let exit_hook_task = StdioConsoleStream::zone_exit_hook(id.clone(), events).await?; let exit_hook_task = StdioConsoleStream::zone_exit_hook(id.clone(), events).await?;
select! { select! {
x = stdout_handle => { x = stdout_handle => {

View File

@ -43,7 +43,7 @@ impl ZoneLogsCommand {
}; };
let output = client.attach_zone_console(input).await?.into_inner(); let output = client.attach_zone_console(input).await?.into_inner();
let stdout_handle = let stdout_handle =
tokio::task::spawn(async move { StdioConsoleStream::stdout(output).await }); tokio::task::spawn(async move { StdioConsoleStream::stdout(output, false).await });
let exit_hook_task = StdioConsoleStream::zone_exit_hook(zone_id.clone(), events).await?; let exit_hook_task = StdioConsoleStream::zone_exit_hook(zone_id.clone(), events).await?;
let code = select! { let code = select! {
x = stdout_handle => { x = stdout_handle => {

View File

@ -73,8 +73,8 @@ impl StdioConsoleStream {
} }
} }
pub async fn stdout(mut stream: Streaming<ZoneConsoleReply>) -> Result<()> { pub async fn stdout(mut stream: Streaming<ZoneConsoleReply>, raw: bool) -> Result<()> {
if stdin().is_tty() { if raw && stdin().is_tty() {
enable_raw_mode()?; enable_raw_mode()?;
StdioConsoleStream::register_terminal_restore_hook()?; StdioConsoleStream::register_terminal_restore_hook()?;
} }

View File

@ -254,9 +254,9 @@ pub struct IdmDaemonBackend {
#[async_trait::async_trait] #[async_trait::async_trait]
impl IdmBackend for IdmDaemonBackend { impl IdmBackend for IdmDaemonBackend {
async fn recv(&mut self) -> Result<IdmTransportPacket> { async fn recv(&mut self) -> Result<Vec<IdmTransportPacket>> {
if let Some(packet) = self.rx_receiver.recv().await { if let Some(packet) = self.rx_receiver.recv().await {
Ok(packet) Ok(vec![packet])
} else { } else {
Err(anyhow!("idm receive channel closed")) Err(anyhow!("idm receive channel closed"))
} }

View File

@ -9,13 +9,13 @@ use std::{
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use bytes::{BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error}; use log::{debug, error};
use nix::sys::termios::{cfmakeraw, tcgetattr, tcsetattr, SetArg}; use nix::sys::termios::{cfmakeraw, tcgetattr, tcsetattr, SetArg};
use prost::Message; use prost::Message;
use tokio::{ use tokio::{
fs::File, fs::File,
io::{unix::AsyncFd, AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
select, select,
sync::{ sync::{
broadcast, broadcast,
@ -43,12 +43,13 @@ const IDM_PACKET_MAX_SIZE: usize = 20 * 1024 * 1024;
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait IdmBackend: Send { pub trait IdmBackend: Send {
async fn recv(&mut self) -> Result<IdmTransportPacket>; async fn recv(&mut self) -> Result<Vec<IdmTransportPacket>>;
async fn send(&mut self, packet: IdmTransportPacket) -> Result<()>; async fn send(&mut self, packet: IdmTransportPacket) -> Result<()>;
} }
pub struct IdmFileBackend { pub struct IdmFileBackend {
read_fd: Arc<Mutex<AsyncFd<File>>>, read: Arc<Mutex<File>>,
read_buffer: BytesMut,
write: Arc<Mutex<File>>, write: Arc<Mutex<File>>,
} }
@ -57,7 +58,8 @@ impl IdmFileBackend {
IdmFileBackend::set_raw_port(&read_file)?; IdmFileBackend::set_raw_port(&read_file)?;
IdmFileBackend::set_raw_port(&write_file)?; IdmFileBackend::set_raw_port(&write_file)?;
Ok(IdmFileBackend { Ok(IdmFileBackend {
read_fd: Arc::new(Mutex::new(AsyncFd::new(read_file)?)), read: Arc::new(Mutex::new(read_file)),
read_buffer: BytesMut::new(),
write: Arc::new(Mutex::new(write_file)), write: Arc::new(Mutex::new(write_file)),
}) })
} }
@ -72,26 +74,58 @@ impl IdmFileBackend {
#[async_trait::async_trait] #[async_trait::async_trait]
impl IdmBackend for IdmFileBackend { impl IdmBackend for IdmFileBackend {
async fn recv(&mut self) -> Result<IdmTransportPacket> { async fn recv(&mut self) -> Result<Vec<IdmTransportPacket>> {
let mut fd = self.read_fd.lock().await; let mut data = vec![0; 8192];
let mut guard = fd.readable_mut().await?; let mut first = true;
let b1 = guard.get_inner_mut().read_u8().await?; 'read_more: loop {
if b1 != 0xff { let mut packets = Vec::new();
return Ok(IdmTransportPacket::default()); if !first {
} if !packets.is_empty() {
let b2 = guard.get_inner_mut().read_u8().await?; return Ok(packets);
if b2 != 0xff { }
return Ok(IdmTransportPacket::default()); let size = self.read.lock().await.read(&mut data).await?;
} self.read_buffer.extend_from_slice(&data[0..size]);
let size = guard.get_inner_mut().read_u32_le().await?; }
if size == 0 { first = false;
return Ok(IdmTransportPacket::default()); loop {
} if self.read_buffer.len() < 6 {
let mut buffer = vec![0u8; size as usize]; continue 'read_more;
guard.get_inner_mut().read_exact(&mut buffer).await?; }
match IdmTransportPacket::decode(buffer.as_slice()) {
Ok(packet) => Ok(packet), let b1 = self.read_buffer[0];
Err(error) => Err(anyhow!("received invalid idm packet: {}", error)), let b2 = self.read_buffer[1];
if b1 != 0xff || b2 != 0xff {
self.read_buffer.clear();
continue 'read_more;
}
let size = (self.read_buffer[2] as u32
| (self.read_buffer[3] as u32) << 8
| (self.read_buffer[4] as u32) << 16
| (self.read_buffer[5] as u32) << 24) as usize;
let needed = size + 6;
if self.read_buffer.len() < needed {
continue 'read_more;
}
let mut packet = self.read_buffer.split_to(needed);
packet.advance(6);
match IdmTransportPacket::decode(packet) {
Ok(packet) => {
packets.push(packet);
}
Err(error) => {
return Err(anyhow!("received invalid idm packet: {}", error));
}
}
if self.read_buffer.is_empty() {
break;
}
}
return Ok(packets);
} }
} }
@ -403,8 +437,9 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
loop { loop {
select! { select! {
x = backend.recv() => match x { x = backend.recv() => match x {
Ok(packet) => { Ok(packets) => {
if packet.channel != channel { for packet in packets {
if packet.channel != channel {
continue; continue;
} }
@ -478,6 +513,7 @@ impl<R: IdmRequest, E: IdmSerializable> IdmClient<R, E> {
_ => {}, _ => {},
} }
}
}, },
Err(error) => { Err(error) => {