krata: utilize gRPC for control service

This commit is contained in:
Alex Zenla 2024-03-06 12:05:01 +00:00
parent 31cf3044a4
commit 3628422168
No known key found for this signature in database
GPG Key ID: 067B238899B51269
24 changed files with 532 additions and 1159 deletions

View File

@ -60,6 +60,10 @@ tokio-listener = "0.3.1"
trait-variant = "0.1.1" trait-variant = "0.1.1"
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
signal-hook = "0.3.17" signal-hook = "0.3.17"
tonic-build = "0.11.0"
prost = "0.12.3"
async-stream = "0.3.5"
tower = "0.4.13"
[workspace.dependencies.uuid] [workspace.dependencies.uuid]
version = "1.6.1" version = "1.6.1"
@ -79,7 +83,7 @@ features = ["macros", "rt", "rt-multi-thread", "io-util"]
[workspace.dependencies.tokio-stream] [workspace.dependencies.tokio-stream]
version = "0.1" version = "0.1"
features = ["io-util"] features = ["io-util", "net"]
[workspace.dependencies.reqwest] [workspace.dependencies.reqwest]
version = "0.11.24" version = "0.11.24"
@ -87,3 +91,7 @@ version = "0.11.24"
[workspace.dependencies.serde] [workspace.dependencies.serde]
version = "1.0.196" version = "1.0.196"
features = ["derive"] features = ["derive"]
[workspace.dependencies.tonic]
version = "0.11.0"
features = ["tls"]

View File

@ -17,6 +17,9 @@ tokio = { workspace = true }
tokio-stream = { workspace = true } tokio-stream = { workspace = true }
tokio-native-tls = { workspace = true } tokio-native-tls = { workspace = true }
url = { workspace = true } url = { workspace = true }
tower = { workspace = true }
tonic = { workspace = true}
async-stream = { workspace = true }
[dependencies.krata] [dependencies.krata]
path = "../shared" path = "../shared"

View File

@ -1,14 +1,9 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use env_logger::Env; use env_logger::Env;
use krata::control::{ use krata::control::{DestroyGuestRequest, LaunchGuestRequest, ListGuestsRequest};
ConsoleStreamRequest, DestroyRequest, LaunchRequest, ListRequest, Request, Response, use kratactl::{client::ControlClientProvider, console::StdioConsoleStream};
}; use tonic::Request;
use kratactl::{
client::{KrataClient, KrataClientTransport},
console::XenConsole,
};
use url::Url;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version, about)] #[command(version, about)]
@ -53,8 +48,7 @@ async fn main() -> Result<()> {
env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init(); env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init();
let args = ControllerArgs::parse(); let args = ControllerArgs::parse();
let transport = KrataClientTransport::dial(Url::parse(&args.connection)?).await?; let mut client = ControlClientProvider::dial(args.connection.parse()?).await?;
let client = KrataClient::new(transport).await?;
match args.command { match args.command {
Commands::Launch { Commands::Launch {
@ -65,67 +59,56 @@ async fn main() -> Result<()> {
env, env,
run, run,
} => { } => {
let request = LaunchRequest { let request = LaunchGuestRequest {
image, image,
vcpus: cpus, vcpus: cpus,
mem, mem,
env, env: env.unwrap_or_default(),
run: if run.is_empty() { None } else { Some(run) }, run,
}; };
let Response::Launch(response) = client.send(Request::Launch(request)).await? else { let response = client
return Err(anyhow!("invalid response type")); .launch_guest(Request::new(request))
.await?
.into_inner();
let Some(guest) = response.guest else {
return Err(anyhow!(
"control service did not return a guest in the response"
));
}; };
println!("launched guest: {}", response.guest.id); println!("launched guest: {}", guest.id);
if attach { if attach {
let request = ConsoleStreamRequest { let input = StdioConsoleStream::stdin_stream(guest.id).await;
guest: response.guest.id.clone(), let output = client.console_data(input).await?.into_inner();
}; StdioConsoleStream::stdout(output).await?;
let Response::ConsoleStream(response) =
client.send(Request::ConsoleStream(request)).await?
else {
return Err(anyhow!("invalid response type"));
};
let stream = client.acquire(response.stream).await?;
let console = XenConsole::new(stream).await?;
console.attach().await?;
} }
} }
Commands::Destroy { guest } => { Commands::Destroy { guest } => {
let request = DestroyRequest { guest }; let _ = client
let Response::Destroy(response) = client.send(Request::Destroy(request)).await? else { .destroy_guest(Request::new(DestroyGuestRequest {
return Err(anyhow!("invalid response type")); guest_id: guest.clone(),
}; }))
println!("destroyed guest: {}", response.guest); .await?
.into_inner();
println!("destroyed guest: {}", guest);
} }
Commands::Console { guest } => { Commands::Console { guest } => {
let request = ConsoleStreamRequest { guest }; let input = StdioConsoleStream::stdin_stream(guest).await;
let Response::ConsoleStream(response) = let output = client.console_data(input).await?.into_inner();
client.send(Request::ConsoleStream(request)).await? StdioConsoleStream::stdout(output).await?;
else {
return Err(anyhow!("invalid response type"));
};
let stream = client.acquire(response.stream).await?;
let console = XenConsole::new(stream).await?;
console.attach().await?;
} }
Commands::List { .. } => { Commands::List { .. } => {
let request = ListRequest {}; let response = client
let Response::List(response) = client.send(Request::List(request)).await? else { .list_guests(Request::new(ListGuestsRequest {}))
return Err(anyhow!("invalid response type")); .await?
}; .into_inner();
let mut table = cli_tables::Table::new(); let mut table = cli_tables::Table::new();
let header = vec!["uuid", "ipv4", "ipv6", "image"]; let header = vec!["uuid", "ipv4", "ipv6", "image"];
table.push_row(&header)?; table.push_row(&header)?;
for guest in response.guests { for guest in response.guests {
table.push_row_string(&vec![ table.push_row_string(&vec![guest.id, guest.ipv4, guest.ipv6, guest.image])?;
guest.id,
guest.ipv4.unwrap_or("none".to_string()),
guest.ipv6.unwrap_or("none".to_string()),
guest.image,
])?;
} }
if table.num_records() == 1 { if table.num_records() == 1 {
println!("no guests have been launched"); println!("no guests have been launched");

View File

@ -1,249 +1,44 @@
use std::{collections::HashMap, sync::Arc}; use anyhow::Result;
use krata::{control::control_service_client::ControlServiceClient, dial::ControlDialAddress};
use tokio::net::UnixStream;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
use tower::service_fn;
use anyhow::{anyhow, Result}; pub struct ControlClientProvider {}
use krata::{
control::{Message, Request, RequestBox, Response},
stream::{ConnectionStreams, StreamContext},
KRATA_DEFAULT_TCP_PORT, KRATA_DEFAULT_TLS_PORT,
};
use log::{trace, warn};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{TcpStream, UnixStream},
select,
sync::{
mpsc::{channel, Receiver, Sender},
oneshot, Mutex,
},
task::JoinHandle,
};
use tokio_native_tls::{native_tls::TlsConnector, TlsStream};
use tokio_stream::{wrappers::LinesStream, StreamExt};
use url::{Host, Url};
const QUEUE_MAX_LEN: usize = 100; impl ControlClientProvider {
pub async fn dial(addr: ControlDialAddress) -> Result<ControlServiceClient<Channel>> {
pub struct KrataClientTransport { let channel = match addr {
sender: Sender<Message>, ControlDialAddress::UnixSocket { path } => {
receiver: Receiver<Message>, // This URL is not actually used but is required to be specified.
task: JoinHandle<()>, Endpoint::try_from(format!("unix://localhost/{}", path))?
} .connect_with_connector(service_fn(|uri: Uri| {
let path = uri.path().to_string();
impl Drop for KrataClientTransport { UnixStream::connect(path)
fn drop(&mut self) { }))
self.task.abort(); .await?
} }
}
macro_rules! transport_new { ControlDialAddress::Tcp { host, port } => {
($name:ident, $stream:ty, $processor:ident) => { Endpoint::try_from(format!("http://{}:{}", host, port))?
pub async fn $name(stream: $stream) -> Result<Self> { .connect()
let (tx_sender, tx_receiver) = channel::<Message>(QUEUE_MAX_LEN); .await?
let (rx_sender, rx_receiver) = channel::<Message>(QUEUE_MAX_LEN);
let task = tokio::task::spawn(async move {
if let Err(error) =
KrataClientTransport::$processor(stream, rx_sender, tx_receiver).await
{
warn!("failed to process krata transport messages: {}", error);
} }
});
Ok(Self { ControlDialAddress::Tls {
sender: tx_sender, host,
receiver: rx_receiver, port,
task, insecure: _,
}) } => {
let tls_config = ClientTlsConfig::new().domain_name(&host);
let address = format!("https://{}:{}", host, port);
Channel::from_shared(address)?
.tls_config(tls_config)?
.connect()
.await?
} }
}; };
}
macro_rules! transport_processor { Ok(ControlServiceClient::new(channel))
($name:ident, $stream:ty) => {
async fn $name(
stream: $stream,
rx_sender: Sender<Message>,
mut tx_receiver: Receiver<Message>,
) -> Result<()> {
let (read, mut write) = tokio::io::split(stream);
let mut read = LinesStream::new(BufReader::new(read).lines());
loop {
select! {
x = tx_receiver.recv() => match x {
Some(message) => {
let mut line = serde_json::to_string(&message)?;
trace!("sending line '{}'", line);
line.push('\n');
write.write_all(line.as_bytes()).await?;
},
None => {
break;
}
},
x = read.next() => match x {
Some(Ok(line)) => {
let message = serde_json::from_str::<Message>(&line)?;
rx_sender.send(message).await?;
},
Some(Err(error)) => {
return Err(error.into());
},
None => {
break;
}
}
};
}
Ok(())
}
};
}
impl KrataClientTransport {
transport_new!(from_unix, UnixStream, process_unix_stream);
transport_new!(from_tcp, TcpStream, process_tcp_stream);
transport_new!(from_tls_tcp, TlsStream<TcpStream>, process_tls_tcp_stream);
pub async fn dial(url: Url) -> Result<KrataClientTransport> {
match url.scheme() {
"unix" => {
let stream = UnixStream::connect(url.path()).await?;
Ok(KrataClientTransport::from_unix(stream).await?)
}
"tcp" => {
let address = format!(
"{}:{}",
url.host().unwrap_or(Host::Domain("localhost")),
url.port().unwrap_or(KRATA_DEFAULT_TCP_PORT)
);
let stream = TcpStream::connect(address).await?;
Ok(KrataClientTransport::from_tcp(stream).await?)
}
"tls" | "tls-insecure" => {
let insecure = url.scheme() == "tls-insecure";
let host = format!("{}", url.host().unwrap_or(Host::Domain("localhost")));
let address = format!("{}:{}", host, url.port().unwrap_or(KRATA_DEFAULT_TLS_PORT));
let stream = TcpStream::connect(address).await?;
let mut connector = TlsConnector::builder();
if insecure {
connector.danger_accept_invalid_certs(true);
}
let connector = connector.build()?;
let connector = tokio_native_tls::TlsConnector::from(connector);
let stream = connector.connect(&host, stream).await?;
Ok(KrataClientTransport::from_tls_tcp(stream).await?)
}
_ => Err(anyhow!("unsupported url scheme: {}", url.scheme())),
}
}
transport_processor!(process_unix_stream, UnixStream);
transport_processor!(process_tcp_stream, TcpStream);
transport_processor!(process_tls_tcp_stream, TlsStream<TcpStream>);
}
type RequestsMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>;
#[derive(Clone)]
pub struct KrataClient {
tx_sender: Sender<Message>,
next: Arc<Mutex<u64>>,
streams: ConnectionStreams,
requests: RequestsMap,
task: Arc<JoinHandle<()>>,
}
impl KrataClient {
pub async fn new(transport: KrataClientTransport) -> Result<Self> {
let tx_sender = transport.sender.clone();
let streams = ConnectionStreams::new(tx_sender.clone());
let requests = Arc::new(Mutex::new(HashMap::new()));
let task = {
let requests = requests.clone();
let streams = streams.clone();
tokio::task::spawn(async move {
if let Err(error) = KrataClient::process(transport, streams, requests).await {
warn!("failed to process krata client messages: {}", error);
}
})
};
Ok(Self {
tx_sender,
next: Arc::new(Mutex::new(0)),
requests,
streams,
task: Arc::new(task),
})
}
pub async fn send(&self, request: Request) -> Result<Response> {
let id = {
let mut next = self.next.lock().await;
let id = *next;
*next = id + 1;
id
};
let (sender, receiver) = oneshot::channel();
self.requests.lock().await.insert(id, sender);
self.tx_sender
.send(Message::Request(RequestBox { id, request }))
.await?;
let response = receiver.await?;
if let Response::Error(error) = response {
Err(anyhow!("krata error: {}", error.message))
} else {
Ok(response)
}
}
pub async fn acquire(&self, stream: u64) -> Result<StreamContext> {
self.streams.acquire(stream).await
}
async fn process(
mut transport: KrataClientTransport,
streams: ConnectionStreams,
requests: RequestsMap,
) -> Result<()> {
loop {
let Some(message) = transport.receiver.recv().await else {
break;
};
match message {
Message::Request(_) => {
return Err(anyhow!("received request from service"));
}
Message::Response(resp) => {
let Some(sender) = requests.lock().await.remove(&resp.id) else {
continue;
};
let _ = sender.send(resp.response);
}
Message::StreamUpdated(updated) => {
streams.incoming(updated).await?;
}
}
}
Ok(())
}
}
impl Drop for KrataClient {
fn drop(&mut self) {
if Arc::strong_count(&self.task) <= 1 {
self.task.abort();
}
} }
} }

View File

@ -1,75 +1,56 @@
use std::{ use std::{
io::{stdin, stdout}, io::stdout,
os::fd::{AsRawFd, FromRawFd}, os::fd::{AsRawFd, FromRawFd},
}; };
use anyhow::Result; use anyhow::Result;
use krata::{ use async_stream::stream;
control::{ConsoleStreamUpdate, StreamUpdate}, use krata::control::{ConsoleDataReply, ConsoleDataRequest};
stream::StreamContext,
};
use log::debug; use log::debug;
use std::process::exit;
use termion::raw::IntoRawMode; use termion::raw::IntoRawMode;
use tokio::{ use tokio::{
fs::File, fs::File,
io::{AsyncReadExt, AsyncWriteExt}, io::{stdin, AsyncReadExt, AsyncWriteExt},
select,
}; };
use tokio_stream::{Stream, StreamExt};
use tonic::Streaming;
pub struct XenConsole { pub struct StdioConsoleStream;
stream: StreamContext,
}
impl XenConsole { impl StdioConsoleStream {
pub async fn new(stream: StreamContext) -> Result<XenConsole> { pub async fn stdin_stream(guest: String) -> impl Stream<Item = ConsoleDataRequest> {
Ok(XenConsole { stream }) let mut stdin = stdin();
} stream! {
yield ConsoleDataRequest { guest, data: vec![] };
pub async fn attach(self) -> Result<()> {
let stdin = unsafe { File::from_raw_fd(stdin().as_raw_fd()) };
let terminal = stdout().into_raw_mode()?;
let stdout = unsafe { File::from_raw_fd(terminal.as_raw_fd()) };
if let Err(error) = XenConsole::process(stdin, stdout, self.stream).await {
debug!("failed to process console stream: {}", error);
}
Ok(())
}
async fn process(mut stdin: File, mut stdout: File, mut stream: StreamContext) -> Result<()> {
let mut buffer = vec![0u8; 60]; let mut buffer = vec![0u8; 60];
loop { loop {
select! { let size = match stdin.read(&mut buffer).await {
x = stream.receiver.recv() => match x { Ok(size) => size,
Some(StreamUpdate::ConsoleStream(update)) => { Err(error) => {
stdout.write_all(&update.data).await?; debug!("failed to read stdin: {}", error);
stdout.flush().await?;
},
None => {
break; break;
} }
},
x = stdin.read(&mut buffer) => match x {
Ok(size) => {
if size == 1 && buffer[0] == 0x1d {
exit(0);
}
let data = buffer[0..size].to_vec();
stream.send(StreamUpdate::ConsoleStream(ConsoleStreamUpdate {
data,
})).await?;
},
Err(error) => {
return Err(error.into());
}
}
}; };
let data = buffer[0..size].to_vec();
if size == 1 && buffer[0] == 0x1d {
break;
}
yield ConsoleDataRequest { guest: String::default(), data };
}
}
}
pub async fn stdout(mut stream: Streaming<ConsoleDataReply>) -> Result<()> {
let terminal = stdout().into_raw_mode()?;
let mut stdout = unsafe { File::from_raw_fd(terminal.as_raw_fd()) };
while let Some(reply) = stream.next().await {
let reply = reply?;
if reply.data.is_empty() {
continue;
}
stdout.write_all(&reply.data).await?;
stdout.flush().await?;
} }
Ok(()) Ok(())
} }

View File

@ -32,10 +32,8 @@ bytes = { workspace = true }
tokio-stream = { workspace = true } tokio-stream = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
signal-hook = { workspace = true } signal-hook = { workspace = true }
async-stream = { workspace = true }
[dependencies.tokio-listener] tonic = { workspace = true, features = ["tls"]}
workspace = true
features = ["clap"]
[dependencies.krata] [dependencies.krata]
path = "../shared" path = "../shared"
@ -62,7 +60,3 @@ path = "src/lib.rs"
[[bin]] [[bin]]
name = "kratad" name = "kratad"
path = "bin/daemon.rs" path = "bin/daemon.rs"
[[example]]
name = "kratad-dial"
path = "examples/dial.rs"

View File

@ -1,15 +1,17 @@
use std::sync::{atomic::AtomicBool, Arc}; use anyhow::Result;
use anyhow::{anyhow, Result};
use clap::Parser; use clap::Parser;
use env_logger::Env; use env_logger::Env;
use krata::dial::ControlDialAddress;
use kratad::{runtime::Runtime, Daemon}; use kratad::{runtime::Runtime, Daemon};
use tokio_listener::ListenerAddressLFlag; use std::{
str::FromStr,
sync::{atomic::AtomicBool, Arc},
};
#[derive(Parser)] #[derive(Parser)]
struct Args { struct Args {
#[clap(flatten)] #[arg(short, long, default_value = "unix:///var/lib/krata/daemon.socket")]
listener: ListenerAddressLFlag, listen: String,
#[arg(short, long, default_value = "/var/lib/krata")] #[arg(short, long, default_value = "/var/lib/krata")]
store: String, store: String,
} }
@ -20,12 +22,10 @@ async fn main() -> Result<()> {
mask_sighup()?; mask_sighup()?;
let args = Args::parse(); let args = Args::parse();
let Some(listener) = args.listener.bind().await else { let addr = ControlDialAddress::from_str(&args.listen)?;
return Err(anyhow!("no listener specified"));
};
let runtime = Runtime::new(args.store.clone()).await?; let runtime = Runtime::new(args.store.clone()).await?;
let mut daemon = Daemon::new(runtime).await?; let mut daemon = Daemon::new(args.store.clone(), runtime).await?;
daemon.listen(listener?).await?; daemon.listen(addr).await?;
Ok(()) Ok(())
} }

View File

@ -1,28 +0,0 @@
use anyhow::Result;
use krata::control::{ListRequest, Message, Request, RequestBox};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::TcpStream,
};
use tokio_stream::{wrappers::LinesStream, StreamExt};
#[tokio::main]
async fn main() -> Result<()> {
let mut stream = TcpStream::connect("127.0.0.1:4050").await?;
let (read, mut write) = stream.split();
let mut read = LinesStream::new(BufReader::new(read).lines());
let send = Message::Request(RequestBox {
id: 1,
request: Request::List(ListRequest {}),
});
let mut line = serde_json::to_string(&send)?;
line.push('\n');
write.write_all(line.as_bytes()).await?;
println!("sent: {:?}", send);
while let Some(line) = read.try_next().await? {
let message: Message = serde_json::from_str(&line)?;
println!("received: {:?}", message);
}
Ok(())
}

172
daemon/src/control.rs Normal file
View File

@ -0,0 +1,172 @@
use std::{io, pin::Pin};
use async_stream::try_stream;
use futures::Stream;
use krata::control::{
control_service_server::ControlService, ConsoleDataReply, ConsoleDataRequest,
DestroyGuestReply, DestroyGuestRequest, GuestInfo, LaunchGuestReply, LaunchGuestRequest,
ListGuestsReply, ListGuestsRequest,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
};
use tokio_stream::StreamExt;
use tonic::{Request, Response, Status, Streaming};
use crate::runtime::{launch::GuestLaunchRequest, Runtime};
pub struct ApiError {
message: String,
}
impl From<anyhow::Error> for ApiError {
fn from(value: anyhow::Error) -> Self {
ApiError {
message: value.to_string(),
}
}
}
impl From<ApiError> for Status {
fn from(value: ApiError) -> Self {
Status::unknown(value.message)
}
}
#[derive(Clone)]
pub struct RuntimeControlService {
runtime: Runtime,
}
impl RuntimeControlService {
pub fn new(runtime: Runtime) -> Self {
Self { runtime }
}
}
enum ConsoleDataSelect {
Read(io::Result<usize>),
Write(Option<Result<ConsoleDataRequest, tonic::Status>>),
}
#[tonic::async_trait]
impl ControlService for RuntimeControlService {
type ConsoleDataStream =
Pin<Box<dyn Stream<Item = Result<ConsoleDataReply, Status>> + Send + 'static>>;
async fn launch_guest(
&self,
request: Request<LaunchGuestRequest>,
) -> Result<Response<LaunchGuestReply>, Status> {
let request = request.into_inner();
let guest: GuestInfo = self
.runtime
.launch(GuestLaunchRequest {
image: &request.image,
vcpus: request.vcpus,
mem: request.mem,
env: empty_vec_optional(request.env),
run: empty_vec_optional(request.run),
debug: false,
})
.await
.map_err(ApiError::from)?
.into();
Ok(Response::new(LaunchGuestReply { guest: Some(guest) }))
}
async fn destroy_guest(
&self,
request: Request<DestroyGuestRequest>,
) -> Result<Response<DestroyGuestReply>, Status> {
let request = request.into_inner();
self.runtime
.destroy(&request.guest_id)
.await
.map_err(ApiError::from)?;
Ok(Response::new(DestroyGuestReply {}))
}
async fn list_guests(
&self,
request: Request<ListGuestsRequest>,
) -> Result<Response<ListGuestsReply>, Status> {
let _ = request.into_inner();
let guests = self.runtime.list().await.map_err(ApiError::from)?;
let guests = guests
.into_iter()
.map(GuestInfo::from)
.collect::<Vec<GuestInfo>>();
Ok(Response::new(ListGuestsReply { guests }))
}
async fn console_data(
&self,
request: Request<Streaming<ConsoleDataRequest>>,
) -> Result<Response<Self::ConsoleDataStream>, Status> {
let mut input = request.into_inner();
let Some(request) = input.next().await else {
return Err(ApiError {
message: "expected to have at least one request".to_string(),
}
.into());
};
let request = request?;
let mut console = self
.runtime
.console(&request.guest)
.await
.map_err(ApiError::from)?;
let output = try_stream! {
let mut buffer: Vec<u8> = vec![0u8; 256];
loop {
let what = select! {
x = console.read_handle.read(&mut buffer) => ConsoleDataSelect::Read(x),
x = input.next() => ConsoleDataSelect::Write(x),
};
match what {
ConsoleDataSelect::Read(result) => {
let size = result?;
let data = buffer[0..size].to_vec();
yield ConsoleDataReply { data, };
},
ConsoleDataSelect::Write(Some(request)) => {
let request = request?;
if !request.data.is_empty() {
console.write_handle.write_all(&request.data).await?;
}
},
ConsoleDataSelect::Write(None) => {
break;
}
}
}
};
Ok(Response::new(Box::pin(output) as Self::ConsoleDataStream))
}
}
impl From<crate::runtime::GuestInfo> for GuestInfo {
fn from(value: crate::runtime::GuestInfo) -> Self {
GuestInfo {
id: value.uuid.to_string(),
image: value.image,
ipv4: value.ipv4.map(|x| x.ip().to_string()).unwrap_or_default(),
ipv6: value.ipv6.map(|x| x.ip().to_string()).unwrap_or_default(),
}
}
}
fn empty_vec_optional<T>(value: Vec<T>) -> Option<Vec<T>> {
if value.is_empty() {
None
} else {
Some(value)
}
}

View File

@ -1,91 +0,0 @@
use anyhow::{anyhow, Result};
use krata::control::{ConsoleStreamResponse, ConsoleStreamUpdate, Request, Response, StreamUpdate};
use log::warn;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
select,
};
use crate::{
listen::DaemonRequestHandler,
runtime::{console::XenConsole, Runtime},
};
use krata::stream::{ConnectionStreams, StreamContext};
pub struct ConsoleStreamRequestHandler {}
impl Default for ConsoleStreamRequestHandler {
fn default() -> Self {
Self::new()
}
}
impl ConsoleStreamRequestHandler {
pub fn new() -> Self {
Self {}
}
async fn link_console_stream(mut stream: StreamContext, mut console: XenConsole) -> Result<()> {
loop {
let mut buffer = vec![0u8; 256];
select! {
x = console.read_handle.read(&mut buffer) => match x {
Ok(size) => {
let data = buffer[0..size].to_vec();
let update = StreamUpdate::ConsoleStream(ConsoleStreamUpdate {
data,
});
stream.send(update).await?;
},
Err(error) => {
return Err(error.into());
}
},
x = stream.receiver.recv() => match x {
Some(StreamUpdate::ConsoleStream(update)) => {
console.write_handle.write_all(&update.data).await?;
}
None => {
break;
}
}
};
}
Ok(())
}
}
#[async_trait::async_trait]
impl DaemonRequestHandler for ConsoleStreamRequestHandler {
fn accepts(&self, request: &Request) -> bool {
matches!(request, Request::ConsoleStream(_))
}
async fn handle(
&self,
streams: ConnectionStreams,
runtime: Runtime,
request: Request,
) -> Result<Response> {
let console_stream = match request {
Request::ConsoleStream(stream) => stream,
_ => return Err(anyhow!("unknown request")),
};
let console = runtime.console(&console_stream.guest).await?;
let stream = streams.open().await?;
let id = stream.id;
tokio::task::spawn(async move {
if let Err(error) =
ConsoleStreamRequestHandler::link_console_stream(stream, console).await
{
warn!("failed to process console stream: {}", error);
}
});
Ok(Response::ConsoleStream(ConsoleStreamResponse {
stream: id,
}))
}
}

View File

@ -1,44 +0,0 @@
use anyhow::{anyhow, Result};
use krata::{
control::{DestroyResponse, Request, Response},
stream::ConnectionStreams,
};
use crate::{listen::DaemonRequestHandler, runtime::Runtime};
pub struct DestroyRequestHandler {}
impl Default for DestroyRequestHandler {
fn default() -> Self {
Self::new()
}
}
impl DestroyRequestHandler {
pub fn new() -> Self {
Self {}
}
}
#[async_trait::async_trait]
impl DaemonRequestHandler for DestroyRequestHandler {
fn accepts(&self, request: &Request) -> bool {
matches!(request, Request::Destroy(_))
}
async fn handle(
&self,
_: ConnectionStreams,
runtime: Runtime,
request: Request,
) -> Result<Response> {
let destroy = match request {
Request::Destroy(destroy) => destroy,
_ => return Err(anyhow!("unknown request")),
};
let guest = runtime.destroy(&destroy.guest).await?;
Ok(Response::Destroy(DestroyResponse {
guest: guest.to_string(),
}))
}
}

View File

@ -1,55 +0,0 @@
use anyhow::{anyhow, Result};
use krata::{
control::{GuestInfo, LaunchResponse, Request, Response},
stream::ConnectionStreams,
};
use crate::{
listen::DaemonRequestHandler,
runtime::{launch::GuestLaunchRequest, Runtime},
};
pub struct LaunchRequestHandler {}
impl Default for LaunchRequestHandler {
fn default() -> Self {
Self::new()
}
}
impl LaunchRequestHandler {
pub fn new() -> Self {
Self {}
}
}
#[async_trait::async_trait]
impl DaemonRequestHandler for LaunchRequestHandler {
fn accepts(&self, request: &Request) -> bool {
matches!(request, Request::Launch(_))
}
async fn handle(
&self,
_: ConnectionStreams,
runtime: Runtime,
request: Request,
) -> Result<Response> {
let launch = match request {
Request::Launch(launch) => launch,
_ => return Err(anyhow!("unknown request")),
};
let guest: GuestInfo = runtime
.launch(GuestLaunchRequest {
image: &launch.image,
vcpus: launch.vcpus,
mem: launch.mem,
env: launch.env,
run: launch.run,
debug: false,
})
.await?
.into();
Ok(Response::Launch(LaunchResponse { guest }))
}
}

View File

@ -1,37 +0,0 @@
use anyhow::Result;
use krata::{
control::{GuestInfo, ListResponse, Request, Response},
stream::ConnectionStreams,
};
use crate::{listen::DaemonRequestHandler, runtime::Runtime};
pub struct ListRequestHandler {}
impl Default for ListRequestHandler {
fn default() -> Self {
Self::new()
}
}
impl ListRequestHandler {
pub fn new() -> Self {
Self {}
}
}
#[async_trait::async_trait]
impl DaemonRequestHandler for ListRequestHandler {
fn accepts(&self, request: &Request) -> bool {
matches!(request, Request::List(_))
}
async fn handle(&self, _: ConnectionStreams, runtime: Runtime, _: Request) -> Result<Response> {
let guests = runtime.list().await?;
let guests = guests
.into_iter()
.map(GuestInfo::from)
.collect::<Vec<GuestInfo>>();
Ok(Response::List(ListResponse { guests }))
}
}

View File

@ -1,15 +0,0 @@
pub mod console;
pub mod destroy;
pub mod launch;
pub mod list;
impl From<crate::runtime::GuestInfo> for krata::control::GuestInfo {
fn from(value: crate::runtime::GuestInfo) -> Self {
krata::control::GuestInfo {
id: value.uuid.to_string(),
image: value.image.clone(),
ipv4: value.ipv4.map(|x| x.ip().to_string()),
ipv6: value.ipv6.map(|x| x.ip().to_string()),
}
}
}

View File

@ -1,37 +1,74 @@
use anyhow::Result; use std::{net::SocketAddr, path::PathBuf, str::FromStr};
use handlers::{
console::ConsoleStreamRequestHandler, destroy::DestroyRequestHandler,
launch::LaunchRequestHandler, list::ListRequestHandler,
};
use listen::{DaemonListener, DaemonRequestHandlers};
use runtime::Runtime;
use tokio_listener::Listener;
pub mod handlers; use anyhow::Result;
pub mod listen; use control::RuntimeControlService;
use krata::{control::control_service_server::ControlServiceServer, dial::ControlDialAddress};
use log::info;
use runtime::Runtime;
use tokio::net::UnixListener;
use tokio_stream::wrappers::UnixListenerStream;
use tonic::transport::{Identity, Server, ServerTlsConfig};
pub mod control;
pub mod runtime; pub mod runtime;
pub struct Daemon { pub struct Daemon {
store: String,
runtime: Runtime, runtime: Runtime,
} }
impl Daemon { impl Daemon {
pub async fn new(runtime: Runtime) -> Result<Self> { pub async fn new(store: String, runtime: Runtime) -> Result<Self> {
Ok(Self { runtime }) Ok(Self { store, runtime })
} }
pub async fn listen(&mut self, listener: Listener) -> Result<()> { pub async fn listen(&mut self, addr: ControlDialAddress) -> Result<()> {
let handlers = DaemonRequestHandlers::new( let control_service = RuntimeControlService::new(self.runtime.clone());
self.runtime.clone(),
vec![ let mut server = Server::builder();
Box::new(LaunchRequestHandler::new()),
Box::new(DestroyRequestHandler::new()), if let ControlDialAddress::Tls {
Box::new(ConsoleStreamRequestHandler::new()), host: _,
Box::new(ListRequestHandler::new()), port: _,
], insecure,
); } = &addr
let mut listener = DaemonListener::new(listener, handlers); {
listener.handle().await?; let mut tls_config = ServerTlsConfig::new();
if !insecure {
let certificate_path = format!("{}/tls/daemon.pem", self.store);
let key_path = format!("{}/tls/daemon.key", self.store);
tls_config = tls_config.identity(Identity::from_pem(certificate_path, key_path));
}
server = server.tls_config(tls_config)?;
}
let server = server.add_service(ControlServiceServer::new(control_service));
info!("listening on address {}", addr);
match addr {
ControlDialAddress::UnixSocket { path } => {
let path = PathBuf::from(path);
if path.exists() {
tokio::fs::remove_file(&path).await?;
}
let listener = UnixListener::bind(path)?;
let stream = UnixListenerStream::new(listener);
server.serve_with_incoming(stream).await?;
}
ControlDialAddress::Tcp { host, port } => {
let address = format!("{}:{}", host, port);
server.serve(SocketAddr::from_str(&address)?).await?;
}
ControlDialAddress::Tls {
host,
port,
insecure: _,
} => {
let address = format!("{}:{}", host, port);
server.serve(SocketAddr::from_str(&address)?).await?;
}
}
Ok(()) Ok(())
} }
} }

View File

@ -1,228 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use krata::control::{ErrorResponse, Message, Request, RequestBox, Response, ResponseBox};
use log::trace;
use log::warn;
use tokio::sync::Mutex;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
select,
sync::mpsc::{channel, Receiver, Sender},
};
use tokio_listener::{Connection, Listener, SomeSocketAddrClonable};
use tokio_stream::{wrappers::LinesStream, StreamExt};
use crate::runtime::Runtime;
use krata::stream::ConnectionStreams;
const QUEUE_MAX_LEN: usize = 100;
#[async_trait::async_trait]
pub trait DaemonRequestHandler: Send + Sync {
fn accepts(&self, request: &Request) -> bool;
async fn handle(
&self,
streams: ConnectionStreams,
runtime: Runtime,
request: Request,
) -> Result<Response>;
}
#[derive(Clone)]
pub struct DaemonRequestHandlers {
runtime: Runtime,
handlers: Arc<Vec<Box<dyn DaemonRequestHandler>>>,
}
impl DaemonRequestHandlers {
pub fn new(runtime: Runtime, handlers: Vec<Box<dyn DaemonRequestHandler>>) -> Self {
DaemonRequestHandlers {
runtime,
handlers: Arc::new(handlers),
}
}
async fn dispatch(&self, streams: ConnectionStreams, request: Request) -> Result<Response> {
for handler in self.handlers.iter() {
if handler.accepts(&request) {
return handler.handle(streams, self.runtime.clone(), request).await;
}
}
Err(anyhow!("daemon cannot handle that request"))
}
}
pub struct DaemonListener {
listener: Listener,
handlers: DaemonRequestHandlers,
connections: Arc<Mutex<HashMap<u64, DaemonConnection>>>,
next: Arc<Mutex<u64>>,
}
impl DaemonListener {
pub fn new(listener: Listener, handlers: DaemonRequestHandlers) -> DaemonListener {
DaemonListener {
listener,
handlers,
connections: Arc::new(Mutex::new(HashMap::new())),
next: Arc::new(Mutex::new(0)),
}
}
pub async fn handle(&mut self) -> Result<()> {
loop {
let (connection, addr) = self.listener.accept().await?;
let connection =
DaemonConnection::new(connection, addr.clonable(), self.handlers.clone()).await?;
let id = {
let mut next = self.next.lock().await;
let id = *next;
*next = id + 1;
id
};
trace!("new connection from {}", connection.addr);
let tx_channel = connection.tx_sender.clone();
let addr = connection.addr.clone();
self.connections.lock().await.insert(id, connection);
let connections_for_close = self.connections.clone();
tokio::task::spawn(async move {
tx_channel.closed().await;
trace!("connection from {} closed", addr);
connections_for_close.lock().await.remove(&id);
});
}
}
}
#[derive(Clone)]
pub struct DaemonConnection {
tx_sender: Sender<Message>,
addr: SomeSocketAddrClonable,
handlers: DaemonRequestHandlers,
streams: ConnectionStreams,
}
impl DaemonConnection {
pub async fn new(
connection: Connection,
addr: SomeSocketAddrClonable,
handlers: DaemonRequestHandlers,
) -> Result<Self> {
let (tx_sender, tx_receiver) = channel::<Message>(QUEUE_MAX_LEN);
let streams_tx_sender = tx_sender.clone();
let instance = DaemonConnection {
tx_sender,
addr,
handlers,
streams: ConnectionStreams::new(streams_tx_sender),
};
{
let mut instance = instance.clone();
tokio::task::spawn(async move {
if let Err(error) = instance.process(tx_receiver, connection).await {
warn!(
"failed to process daemon connection for {}: {}",
instance.addr, error
);
}
});
}
Ok(instance)
}
async fn process(
&mut self,
mut tx_receiver: Receiver<Message>,
connection: Connection,
) -> Result<()> {
let (read, mut write) = tokio::io::split(connection);
let mut read = LinesStream::new(BufReader::new(read).lines());
loop {
select! {
x = read.next() => match x {
Some(Ok(line)) => {
let message: Message = serde_json::from_str(&line)?;
trace!("received message '{}' from {}", serde_json::to_string(&message)?, self.addr);
let mut context = self.clone();
tokio::task::spawn(async move {
if let Err(error) = context.handle_message(&message).await {
let line = serde_json::to_string(&message).unwrap_or("<invalid>".to_string());
warn!("failed to handle message '{}' from {}: {}", line, context.addr, error);
}
});
},
Some(Err(error)) => {
return Err(error.into());
},
None => {
break;
}
},
x = tx_receiver.recv() => match x {
Some(message) => {
if let Message::StreamUpdated(ref update) = message {
self.streams.outgoing(update).await?;
}
let mut line = serde_json::to_string(&message)?;
trace!("sending message '{}' to {}", line, self.addr);
line.push('\n');
write.write_all(line.as_bytes()).await?;
},
None => {
break;
}
}
};
}
Ok(())
}
async fn handle_message(&mut self, message: &Message) -> Result<()> {
match message {
Message::Request(req) => {
self.handle_request(req.clone()).await?;
}
Message::Response(_) => {
return Err(anyhow!(
"received a response message from client {}, but this is the daemon",
self.addr
));
}
Message::StreamUpdated(updated) => {
self.streams.incoming(updated.clone()).await?;
}
}
Ok(())
}
async fn handle_request(&mut self, req: RequestBox) -> Result<()> {
let id = req.id;
let response = self
.handlers
.dispatch(self.streams.clone(), req.request)
.await
.map_err(|error| {
Response::Error(ErrorResponse {
message: error.to_string(),
})
});
let response = if let Err(response) = response {
response
} else {
response.unwrap()
};
let resp = ResponseBox { id, response };
self.tx_sender.send(Message::Response(resp)).await?;
Ok(())
}
}

View File

@ -5,7 +5,7 @@ Description=Krata Controller Daemon
Restart=on-failure Restart=on-failure
Type=simple Type=simple
WorkingDirectory=/var/lib/krata WorkingDirectory=/var/lib/krata
ExecStart=/usr/local/bin/kratad -l /var/lib/krata/daemon.socket --unix-listen-unlink ExecStart=/usr/local/bin/kratad -l unix:///var/lib/krata/daemon.socket
Environment=RUST_LOG=info Environment=RUST_LOG=info
User=root User=root

View File

@ -10,6 +10,12 @@ serde = { workspace = true }
libc = { workspace = true } libc = { workspace = true }
log = { workspace = true } log = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
url = { workspace = true }
tonic = { workspace = true }
prost = { workspace = true }
[build-dependencies]
tonic-build = { workspace = true }
[dependencies.nix] [dependencies.nix]
workspace = true workspace = true

5
shared/build.rs Normal file
View File

@ -0,0 +1,5 @@
fn main() {
tonic_build::configure()
.compile(&["proto/krata/control.proto"], &["proto"])
.unwrap();
}

View File

@ -0,0 +1,56 @@
syntax = "proto3";
option java_multiple_files = true;
option java_package = "dev.krata.proto.control";
option java_outer_classname = "ControlProto";
package krata.control;
message GuestInfo {
string id = 1;
string image = 2;
string ipv4 = 3;
string ipv6 = 4;
}
message LaunchGuestRequest {
string image = 1;
uint32 vcpus = 2;
uint64 mem = 3;
repeated string env = 4;
repeated string run = 5;
}
message LaunchGuestReply {
GuestInfo guest = 1;
}
message ListGuestsRequest {}
message ListGuestsReply {
repeated GuestInfo guests = 1;
}
message DestroyGuestRequest {
string guest_id = 1;
}
message DestroyGuestReply {}
message ConsoleDataRequest {
string guest = 1;
bytes data = 2;
}
message ConsoleDataReply {
bytes data = 1;
}
service ControlService {
rpc LaunchGuest(LaunchGuestRequest) returns (LaunchGuestReply);
rpc DestroyGuest(DestroyGuestRequest) returns (DestroyGuestReply);
rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply);
rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply);
}

View File

@ -1,115 +1 @@
use serde::{Deserialize, Serialize}; tonic::include_proto!("krata.control");
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GuestInfo {
pub id: String,
pub image: String,
pub ipv4: Option<String>,
pub ipv6: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LaunchRequest {
pub image: String,
pub vcpus: u32,
pub mem: u64,
pub env: Option<Vec<String>>,
pub run: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LaunchResponse {
pub guest: GuestInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListRequest {}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListResponse {
pub guests: Vec<GuestInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DestroyRequest {
pub guest: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DestroyResponse {
pub guest: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsoleStreamRequest {
pub guest: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsoleStreamResponse {
pub stream: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsoleStreamUpdate {
pub data: Vec<u8>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Request {
Launch(LaunchRequest),
Destroy(DestroyRequest),
List(ListRequest),
ConsoleStream(ConsoleStreamRequest),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Response {
Error(ErrorResponse),
Launch(LaunchResponse),
Destroy(DestroyResponse),
List(ListResponse),
ConsoleStream(ConsoleStreamResponse),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBox {
pub id: u64,
pub request: Request,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseBox {
pub id: u64,
pub response: Response,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum StreamStatus {
Open,
Closed,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamUpdate {
ConsoleStream(ConsoleStreamUpdate),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamUpdated {
pub id: u64,
pub update: Option<StreamUpdate>,
pub status: StreamStatus,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Message {
Request(RequestBox),
Response(ResponseBox),
StreamUpdated(StreamUpdated),
}

100
shared/src/dial.rs Normal file
View File

@ -0,0 +1,100 @@
use std::{fmt::Display, str::FromStr};
use anyhow::anyhow;
use url::{Host, Url};
pub const KRATA_DEFAULT_TCP_PORT: u16 = 4350;
pub const KRATA_DEFAULT_TLS_PORT: u16 = 4353;
#[derive(Clone)]
pub enum ControlDialAddress {
UnixSocket {
path: String,
},
Tcp {
host: String,
port: u16,
},
Tls {
host: String,
port: u16,
insecure: bool,
},
}
impl FromStr for ControlDialAddress {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let url: Url = s.parse()?;
let host = url.host().unwrap_or(Host::Domain("localhost")).to_string();
match url.scheme() {
"unix" => Ok(ControlDialAddress::UnixSocket {
path: url.path().to_string(),
}),
"tcp" => {
let port = url.port().unwrap_or(KRATA_DEFAULT_TCP_PORT);
Ok(ControlDialAddress::Tcp { host, port })
}
"tls" | "tls-insecure" => {
let insecure = url.scheme() == "tls-insecure";
let port = url.port().unwrap_or(KRATA_DEFAULT_TLS_PORT);
Ok(ControlDialAddress::Tls {
host,
port,
insecure,
})
}
_ => Err(anyhow!("unknown control address scheme: {}", url.scheme())),
}
}
}
impl From<ControlDialAddress> for Url {
fn from(val: ControlDialAddress) -> Self {
match val {
ControlDialAddress::UnixSocket { path } => {
let mut url = Url::parse("unix:///").unwrap();
url.set_path(&path);
url
}
ControlDialAddress::Tcp { host, port } => {
let mut url = Url::parse("tcp://").unwrap();
url.set_host(Some(&host)).unwrap();
if port != KRATA_DEFAULT_TCP_PORT {
url.set_port(Some(port)).unwrap();
}
url
}
ControlDialAddress::Tls {
host,
port,
insecure,
} => {
let mut url = Url::parse("tls://").unwrap();
if insecure {
url.set_scheme("tls-insecure").unwrap();
}
url.set_host(Some(&host)).unwrap();
if port != KRATA_DEFAULT_TLS_PORT {
url.set_port(Some(port)).unwrap();
}
url
}
}
}
}
impl Display for ControlDialAddress {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let url: Url = self.clone().into();
write!(f, "{}", url)
}
}

View File

@ -1,7 +1,4 @@
pub mod control; pub mod control;
pub mod dial;
pub mod ethtool; pub mod ethtool;
pub mod launchcfg; pub mod launchcfg;
pub mod stream;
pub const KRATA_DEFAULT_TCP_PORT: u16 = 4350;
pub const KRATA_DEFAULT_TLS_PORT: u16 = 4353;

View File

@ -1,152 +0,0 @@
use crate::control::{Message, StreamStatus, StreamUpdate, StreamUpdated};
use anyhow::{anyhow, Result};
use log::warn;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{
mpsc::{channel, Receiver, Sender},
Mutex,
};
pub struct StreamContext {
pub id: u64,
pub receiver: Receiver<StreamUpdate>,
sender: Sender<Message>,
}
impl StreamContext {
pub async fn send(&self, update: StreamUpdate) -> Result<()> {
self.sender
.send(Message::StreamUpdated(StreamUpdated {
id: self.id,
update: Some(update),
status: StreamStatus::Open,
}))
.await?;
Ok(())
}
}
impl Drop for StreamContext {
fn drop(&mut self) {
if self.sender.is_closed() {
return;
}
let result = self.sender.try_send(Message::StreamUpdated(StreamUpdated {
id: self.id,
update: None,
status: StreamStatus::Closed,
}));
if let Err(error) = result {
warn!(
"failed to send close message for stream {}: {}",
self.id, error
);
}
}
}
struct StreamStorage {
rx_sender: Sender<StreamUpdate>,
rx_receiver: Option<Receiver<StreamUpdate>>,
}
#[derive(Clone)]
pub struct ConnectionStreams {
next: Arc<Mutex<u64>>,
streams: Arc<Mutex<HashMap<u64, StreamStorage>>>,
tx_sender: Sender<Message>,
}
const QUEUE_MAX_LEN: usize = 100;
impl ConnectionStreams {
pub fn new(tx_sender: Sender<Message>) -> Self {
Self {
next: Arc::new(Mutex::new(0)),
streams: Arc::new(Mutex::new(HashMap::new())),
tx_sender,
}
}
pub async fn open(&self) -> Result<StreamContext> {
let id = {
let mut next = self.next.lock().await;
let id = *next;
*next = id + 1;
id
};
let (rx_sender, rx_receiver) = channel(QUEUE_MAX_LEN);
let store = StreamStorage {
rx_sender,
rx_receiver: None,
};
self.streams.lock().await.insert(id, store);
let open = Message::StreamUpdated(StreamUpdated {
id,
update: None,
status: StreamStatus::Open,
});
self.tx_sender.send(open).await?;
Ok(StreamContext {
id,
sender: self.tx_sender.clone(),
receiver: rx_receiver,
})
}
pub async fn incoming(&self, updated: StreamUpdated) -> Result<()> {
let mut streams = self.streams.lock().await;
if updated.update.is_none() && updated.status == StreamStatus::Open {
let (rx_sender, rx_receiver) = channel(QUEUE_MAX_LEN);
let store = StreamStorage {
rx_sender,
rx_receiver: Some(rx_receiver),
};
streams.insert(updated.id, store);
}
let Some(storage) = streams.get(&updated.id) else {
return Ok(());
};
if let Some(update) = updated.update {
storage.rx_sender.send(update).await?;
}
if updated.status == StreamStatus::Closed {
streams.remove(&updated.id);
}
Ok(())
}
pub async fn outgoing(&self, updated: &StreamUpdated) -> Result<()> {
if updated.status == StreamStatus::Closed {
let mut streams = self.streams.lock().await;
streams.remove(&updated.id);
}
Ok(())
}
pub async fn acquire(&self, id: u64) -> Result<StreamContext> {
let mut streams = self.streams.lock().await;
let Some(storage) = streams.get_mut(&id) else {
return Err(anyhow!("stream {} has not been opened", id));
};
let Some(receiver) = storage.rx_receiver.take() else {
return Err(anyhow!("stream has already been acquired"));
};
Ok(StreamContext {
id,
receiver,
sender: self.tx_sender.clone(),
})
}
}