From 36284221689039f823ef5ffa25b35678660aac15 Mon Sep 17 00:00:00 2001 From: Alex Zenla Date: Wed, 6 Mar 2024 12:05:01 +0000 Subject: [PATCH] krata: utilize gRPC for control service --- Cargo.toml | 10 +- controller/Cargo.toml | 3 + controller/bin/control.rs | 85 ++++------ controller/src/client.rs | 271 ++++--------------------------- controller/src/console.rs | 93 +++++------ daemon/Cargo.toml | 10 +- daemon/bin/daemon.rs | 22 +-- daemon/examples/dial.rs | 28 ---- daemon/src/control.rs | 172 ++++++++++++++++++++ daemon/src/handlers/console.rs | 91 ----------- daemon/src/handlers/destroy.rs | 44 ----- daemon/src/handlers/launch.rs | 55 ------- daemon/src/handlers/list.rs | 37 ----- daemon/src/handlers/mod.rs | 15 -- daemon/src/lib.rs | 85 +++++++--- daemon/src/listen.rs | 228 -------------------------- resources/systemd/kratad.service | 2 +- shared/Cargo.toml | 6 + shared/build.rs | 5 + shared/proto/krata/control.proto | 56 +++++++ shared/src/control.rs | 116 +------------ shared/src/dial.rs | 100 ++++++++++++ shared/src/lib.rs | 5 +- shared/src/stream.rs | 152 ----------------- 24 files changed, 532 insertions(+), 1159 deletions(-) delete mode 100644 daemon/examples/dial.rs create mode 100644 daemon/src/control.rs delete mode 100644 daemon/src/handlers/console.rs delete mode 100644 daemon/src/handlers/destroy.rs delete mode 100644 daemon/src/handlers/launch.rs delete mode 100644 daemon/src/handlers/list.rs delete mode 100644 daemon/src/handlers/mod.rs delete mode 100644 daemon/src/listen.rs create mode 100644 shared/build.rs create mode 100644 shared/proto/krata/control.proto create mode 100644 shared/src/dial.rs delete mode 100644 shared/src/stream.rs diff --git a/Cargo.toml b/Cargo.toml index b6503ae..6896203 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,10 @@ tokio-listener = "0.3.1" trait-variant = "0.1.1" tokio-native-tls = "0.3.1" signal-hook = "0.3.17" +tonic-build = "0.11.0" +prost = "0.12.3" +async-stream = "0.3.5" +tower = "0.4.13" [workspace.dependencies.uuid] version = "1.6.1" @@ -79,7 +83,7 @@ features = ["macros", "rt", "rt-multi-thread", "io-util"] [workspace.dependencies.tokio-stream] version = "0.1" -features = ["io-util"] +features = ["io-util", "net"] [workspace.dependencies.reqwest] version = "0.11.24" @@ -87,3 +91,7 @@ version = "0.11.24" [workspace.dependencies.serde] version = "1.0.196" features = ["derive"] + +[workspace.dependencies.tonic] +version = "0.11.0" +features = ["tls"] diff --git a/controller/Cargo.toml b/controller/Cargo.toml index e4c616c..cd6c48a 100644 --- a/controller/Cargo.toml +++ b/controller/Cargo.toml @@ -17,6 +17,9 @@ tokio = { workspace = true } tokio-stream = { workspace = true } tokio-native-tls = { workspace = true } url = { workspace = true } +tower = { workspace = true } +tonic = { workspace = true} +async-stream = { workspace = true } [dependencies.krata] path = "../shared" diff --git a/controller/bin/control.rs b/controller/bin/control.rs index f1aab14..3d0b6a2 100644 --- a/controller/bin/control.rs +++ b/controller/bin/control.rs @@ -1,14 +1,9 @@ use anyhow::{anyhow, Result}; use clap::{Parser, Subcommand}; use env_logger::Env; -use krata::control::{ - ConsoleStreamRequest, DestroyRequest, LaunchRequest, ListRequest, Request, Response, -}; -use kratactl::{ - client::{KrataClient, KrataClientTransport}, - console::XenConsole, -}; -use url::Url; +use krata::control::{DestroyGuestRequest, LaunchGuestRequest, ListGuestsRequest}; +use kratactl::{client::ControlClientProvider, console::StdioConsoleStream}; +use tonic::Request; #[derive(Parser, Debug)] #[command(version, about)] @@ -53,8 +48,7 @@ async fn main() -> Result<()> { env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init(); let args = ControllerArgs::parse(); - let transport = KrataClientTransport::dial(Url::parse(&args.connection)?).await?; - let client = KrataClient::new(transport).await?; + let mut client = ControlClientProvider::dial(args.connection.parse()?).await?; match args.command { Commands::Launch { @@ -65,67 +59,56 @@ async fn main() -> Result<()> { env, run, } => { - let request = LaunchRequest { + let request = LaunchGuestRequest { image, vcpus: cpus, mem, - env, - run: if run.is_empty() { None } else { Some(run) }, + env: env.unwrap_or_default(), + run, }; - let Response::Launch(response) = client.send(Request::Launch(request)).await? else { - return Err(anyhow!("invalid response type")); + let response = client + .launch_guest(Request::new(request)) + .await? + .into_inner(); + let Some(guest) = response.guest else { + return Err(anyhow!( + "control service did not return a guest in the response" + )); }; - println!("launched guest: {}", response.guest.id); + println!("launched guest: {}", guest.id); if attach { - let request = ConsoleStreamRequest { - guest: response.guest.id.clone(), - }; - let Response::ConsoleStream(response) = - client.send(Request::ConsoleStream(request)).await? - else { - return Err(anyhow!("invalid response type")); - }; - let stream = client.acquire(response.stream).await?; - let console = XenConsole::new(stream).await?; - console.attach().await?; + let input = StdioConsoleStream::stdin_stream(guest.id).await; + let output = client.console_data(input).await?.into_inner(); + StdioConsoleStream::stdout(output).await?; } } Commands::Destroy { guest } => { - let request = DestroyRequest { guest }; - let Response::Destroy(response) = client.send(Request::Destroy(request)).await? else { - return Err(anyhow!("invalid response type")); - }; - println!("destroyed guest: {}", response.guest); + let _ = client + .destroy_guest(Request::new(DestroyGuestRequest { + guest_id: guest.clone(), + })) + .await? + .into_inner(); + println!("destroyed guest: {}", guest); } Commands::Console { guest } => { - let request = ConsoleStreamRequest { guest }; - let Response::ConsoleStream(response) = - client.send(Request::ConsoleStream(request)).await? - else { - return Err(anyhow!("invalid response type")); - }; - let stream = client.acquire(response.stream).await?; - let console = XenConsole::new(stream).await?; - console.attach().await?; + let input = StdioConsoleStream::stdin_stream(guest).await; + let output = client.console_data(input).await?.into_inner(); + StdioConsoleStream::stdout(output).await?; } Commands::List { .. } => { - let request = ListRequest {}; - let Response::List(response) = client.send(Request::List(request)).await? else { - return Err(anyhow!("invalid response type")); - }; + let response = client + .list_guests(Request::new(ListGuestsRequest {})) + .await? + .into_inner(); let mut table = cli_tables::Table::new(); let header = vec!["uuid", "ipv4", "ipv6", "image"]; table.push_row(&header)?; for guest in response.guests { - table.push_row_string(&vec![ - guest.id, - guest.ipv4.unwrap_or("none".to_string()), - guest.ipv6.unwrap_or("none".to_string()), - guest.image, - ])?; + table.push_row_string(&vec![guest.id, guest.ipv4, guest.ipv6, guest.image])?; } if table.num_records() == 1 { println!("no guests have been launched"); diff --git a/controller/src/client.rs b/controller/src/client.rs index 1e06dd1..44fe5ed 100644 --- a/controller/src/client.rs +++ b/controller/src/client.rs @@ -1,249 +1,44 @@ -use std::{collections::HashMap, sync::Arc}; +use anyhow::Result; +use krata::{control::control_service_client::ControlServiceClient, dial::ControlDialAddress}; +use tokio::net::UnixStream; +use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri}; +use tower::service_fn; -use anyhow::{anyhow, Result}; -use krata::{ - control::{Message, Request, RequestBox, Response}, - stream::{ConnectionStreams, StreamContext}, - KRATA_DEFAULT_TCP_PORT, KRATA_DEFAULT_TLS_PORT, -}; -use log::{trace, warn}; -use tokio::{ - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - net::{TcpStream, UnixStream}, - select, - sync::{ - mpsc::{channel, Receiver, Sender}, - oneshot, Mutex, - }, - task::JoinHandle, -}; -use tokio_native_tls::{native_tls::TlsConnector, TlsStream}; -use tokio_stream::{wrappers::LinesStream, StreamExt}; -use url::{Host, Url}; +pub struct ControlClientProvider {} -const QUEUE_MAX_LEN: usize = 100; - -pub struct KrataClientTransport { - sender: Sender, - receiver: Receiver, - task: JoinHandle<()>, -} - -impl Drop for KrataClientTransport { - fn drop(&mut self) { - self.task.abort(); - } -} - -macro_rules! transport_new { - ($name:ident, $stream:ty, $processor:ident) => { - pub async fn $name(stream: $stream) -> Result { - let (tx_sender, tx_receiver) = channel::(QUEUE_MAX_LEN); - let (rx_sender, rx_receiver) = channel::(QUEUE_MAX_LEN); - - let task = tokio::task::spawn(async move { - if let Err(error) = - KrataClientTransport::$processor(stream, rx_sender, tx_receiver).await - { - warn!("failed to process krata transport messages: {}", error); - } - }); - - Ok(Self { - sender: tx_sender, - receiver: rx_receiver, - task, - }) - } - }; -} - -macro_rules! transport_processor { - ($name:ident, $stream:ty) => { - async fn $name( - stream: $stream, - rx_sender: Sender, - mut tx_receiver: Receiver, - ) -> Result<()> { - let (read, mut write) = tokio::io::split(stream); - let mut read = LinesStream::new(BufReader::new(read).lines()); - loop { - select! { - x = tx_receiver.recv() => match x { - Some(message) => { - let mut line = serde_json::to_string(&message)?; - trace!("sending line '{}'", line); - line.push('\n'); - write.write_all(line.as_bytes()).await?; - }, - - None => { - break; - } - }, - - x = read.next() => match x { - Some(Ok(line)) => { - let message = serde_json::from_str::(&line)?; - rx_sender.send(message).await?; - }, - - Some(Err(error)) => { - return Err(error.into()); - }, - - None => { - break; - } - } - }; - } - Ok(()) - } - }; -} - -impl KrataClientTransport { - transport_new!(from_unix, UnixStream, process_unix_stream); - transport_new!(from_tcp, TcpStream, process_tcp_stream); - transport_new!(from_tls_tcp, TlsStream, process_tls_tcp_stream); - - pub async fn dial(url: Url) -> Result { - match url.scheme() { - "unix" => { - let stream = UnixStream::connect(url.path()).await?; - Ok(KrataClientTransport::from_unix(stream).await?) +impl ControlClientProvider { + pub async fn dial(addr: ControlDialAddress) -> Result> { + let channel = match addr { + ControlDialAddress::UnixSocket { path } => { + // This URL is not actually used but is required to be specified. + Endpoint::try_from(format!("unix://localhost/{}", path))? + .connect_with_connector(service_fn(|uri: Uri| { + let path = uri.path().to_string(); + UnixStream::connect(path) + })) + .await? } - "tcp" => { - let address = format!( - "{}:{}", - url.host().unwrap_or(Host::Domain("localhost")), - url.port().unwrap_or(KRATA_DEFAULT_TCP_PORT) - ); - let stream = TcpStream::connect(address).await?; - Ok(KrataClientTransport::from_tcp(stream).await?) + ControlDialAddress::Tcp { host, port } => { + Endpoint::try_from(format!("http://{}:{}", host, port))? + .connect() + .await? } - "tls" | "tls-insecure" => { - let insecure = url.scheme() == "tls-insecure"; - let host = format!("{}", url.host().unwrap_or(Host::Domain("localhost"))); - let address = format!("{}:{}", host, url.port().unwrap_or(KRATA_DEFAULT_TLS_PORT)); - let stream = TcpStream::connect(address).await?; - let mut connector = TlsConnector::builder(); - if insecure { - connector.danger_accept_invalid_certs(true); - } - let connector = connector.build()?; - let connector = tokio_native_tls::TlsConnector::from(connector); - let stream = connector.connect(&host, stream).await?; - Ok(KrataClientTransport::from_tls_tcp(stream).await?) + ControlDialAddress::Tls { + host, + port, + insecure: _, + } => { + let tls_config = ClientTlsConfig::new().domain_name(&host); + let address = format!("https://{}:{}", host, port); + Channel::from_shared(address)? + .tls_config(tls_config)? + .connect() + .await? } - - _ => Err(anyhow!("unsupported url scheme: {}", url.scheme())), - } - } - - transport_processor!(process_unix_stream, UnixStream); - transport_processor!(process_tcp_stream, TcpStream); - transport_processor!(process_tls_tcp_stream, TlsStream); -} - -type RequestsMap = Arc>>>; - -#[derive(Clone)] -pub struct KrataClient { - tx_sender: Sender, - next: Arc>, - streams: ConnectionStreams, - requests: RequestsMap, - task: Arc>, -} - -impl KrataClient { - pub async fn new(transport: KrataClientTransport) -> Result { - let tx_sender = transport.sender.clone(); - let streams = ConnectionStreams::new(tx_sender.clone()); - let requests = Arc::new(Mutex::new(HashMap::new())); - let task = { - let requests = requests.clone(); - let streams = streams.clone(); - tokio::task::spawn(async move { - if let Err(error) = KrataClient::process(transport, streams, requests).await { - warn!("failed to process krata client messages: {}", error); - } - }) }; - Ok(Self { - tx_sender, - next: Arc::new(Mutex::new(0)), - requests, - streams, - task: Arc::new(task), - }) - } - - pub async fn send(&self, request: Request) -> Result { - let id = { - let mut next = self.next.lock().await; - let id = *next; - *next = id + 1; - id - }; - let (sender, receiver) = oneshot::channel(); - self.requests.lock().await.insert(id, sender); - self.tx_sender - .send(Message::Request(RequestBox { id, request })) - .await?; - let response = receiver.await?; - if let Response::Error(error) = response { - Err(anyhow!("krata error: {}", error.message)) - } else { - Ok(response) - } - } - - pub async fn acquire(&self, stream: u64) -> Result { - self.streams.acquire(stream).await - } - - async fn process( - mut transport: KrataClientTransport, - streams: ConnectionStreams, - requests: RequestsMap, - ) -> Result<()> { - loop { - let Some(message) = transport.receiver.recv().await else { - break; - }; - - match message { - Message::Request(_) => { - return Err(anyhow!("received request from service")); - } - - Message::Response(resp) => { - let Some(sender) = requests.lock().await.remove(&resp.id) else { - continue; - }; - - let _ = sender.send(resp.response); - } - - Message::StreamUpdated(updated) => { - streams.incoming(updated).await?; - } - } - } - Ok(()) - } -} - -impl Drop for KrataClient { - fn drop(&mut self) { - if Arc::strong_count(&self.task) <= 1 { - self.task.abort(); - } + Ok(ControlServiceClient::new(channel)) } } diff --git a/controller/src/console.rs b/controller/src/console.rs index 45b1d36..ded6a2b 100644 --- a/controller/src/console.rs +++ b/controller/src/console.rs @@ -1,75 +1,56 @@ use std::{ - io::{stdin, stdout}, + io::stdout, os::fd::{AsRawFd, FromRawFd}, }; use anyhow::Result; -use krata::{ - control::{ConsoleStreamUpdate, StreamUpdate}, - stream::StreamContext, -}; +use async_stream::stream; +use krata::control::{ConsoleDataReply, ConsoleDataRequest}; use log::debug; -use std::process::exit; use termion::raw::IntoRawMode; use tokio::{ fs::File, - io::{AsyncReadExt, AsyncWriteExt}, - select, + io::{stdin, AsyncReadExt, AsyncWriteExt}, }; +use tokio_stream::{Stream, StreamExt}; +use tonic::Streaming; -pub struct XenConsole { - stream: StreamContext, -} +pub struct StdioConsoleStream; -impl XenConsole { - pub async fn new(stream: StreamContext) -> Result { - Ok(XenConsole { stream }) - } +impl StdioConsoleStream { + pub async fn stdin_stream(guest: String) -> impl Stream { + let mut stdin = stdin(); + stream! { + yield ConsoleDataRequest { guest, data: vec![] }; - pub async fn attach(self) -> Result<()> { - let stdin = unsafe { File::from_raw_fd(stdin().as_raw_fd()) }; - let terminal = stdout().into_raw_mode()?; - let stdout = unsafe { File::from_raw_fd(terminal.as_raw_fd()) }; - - if let Err(error) = XenConsole::process(stdin, stdout, self.stream).await { - debug!("failed to process console stream: {}", error); - } - - Ok(()) - } - - async fn process(mut stdin: File, mut stdout: File, mut stream: StreamContext) -> Result<()> { - let mut buffer = vec![0u8; 60]; - loop { - select! { - x = stream.receiver.recv() => match x { - Some(StreamUpdate::ConsoleStream(update)) => { - stdout.write_all(&update.data).await?; - stdout.flush().await?; - }, - - None => { + let mut buffer = vec![0u8; 60]; + loop { + let size = match stdin.read(&mut buffer).await { + Ok(size) => size, + Err(error) => { + debug!("failed to read stdin: {}", error); break; } - }, - - x = stdin.read(&mut buffer) => match x { - Ok(size) => { - if size == 1 && buffer[0] == 0x1d { - exit(0); - } - - let data = buffer[0..size].to_vec(); - stream.send(StreamUpdate::ConsoleStream(ConsoleStreamUpdate { - data, - })).await?; - }, - - Err(error) => { - return Err(error.into()); - } + }; + let data = buffer[0..size].to_vec(); + if size == 1 && buffer[0] == 0x1d { + break; } - }; + yield ConsoleDataRequest { guest: String::default(), data }; + } + } + } + + pub async fn stdout(mut stream: Streaming) -> Result<()> { + let terminal = stdout().into_raw_mode()?; + let mut stdout = unsafe { File::from_raw_fd(terminal.as_raw_fd()) }; + while let Some(reply) = stream.next().await { + let reply = reply?; + if reply.data.is_empty() { + continue; + } + stdout.write_all(&reply.data).await?; + stdout.flush().await?; } Ok(()) } diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index ee9335e..7447eae 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -32,10 +32,8 @@ bytes = { workspace = true } tokio-stream = { workspace = true } async-trait = { workspace = true } signal-hook = { workspace = true } - -[dependencies.tokio-listener] -workspace = true -features = ["clap"] +async-stream = { workspace = true } +tonic = { workspace = true, features = ["tls"]} [dependencies.krata] path = "../shared" @@ -62,7 +60,3 @@ path = "src/lib.rs" [[bin]] name = "kratad" path = "bin/daemon.rs" - -[[example]] -name = "kratad-dial" -path = "examples/dial.rs" diff --git a/daemon/bin/daemon.rs b/daemon/bin/daemon.rs index 0741174..407beb9 100644 --- a/daemon/bin/daemon.rs +++ b/daemon/bin/daemon.rs @@ -1,15 +1,17 @@ -use std::sync::{atomic::AtomicBool, Arc}; - -use anyhow::{anyhow, Result}; +use anyhow::Result; use clap::Parser; use env_logger::Env; +use krata::dial::ControlDialAddress; use kratad::{runtime::Runtime, Daemon}; -use tokio_listener::ListenerAddressLFlag; +use std::{ + str::FromStr, + sync::{atomic::AtomicBool, Arc}, +}; #[derive(Parser)] struct Args { - #[clap(flatten)] - listener: ListenerAddressLFlag, + #[arg(short, long, default_value = "unix:///var/lib/krata/daemon.socket")] + listen: String, #[arg(short, long, default_value = "/var/lib/krata")] store: String, } @@ -20,12 +22,10 @@ async fn main() -> Result<()> { mask_sighup()?; let args = Args::parse(); - let Some(listener) = args.listener.bind().await else { - return Err(anyhow!("no listener specified")); - }; + let addr = ControlDialAddress::from_str(&args.listen)?; let runtime = Runtime::new(args.store.clone()).await?; - let mut daemon = Daemon::new(runtime).await?; - daemon.listen(listener?).await?; + let mut daemon = Daemon::new(args.store.clone(), runtime).await?; + daemon.listen(addr).await?; Ok(()) } diff --git a/daemon/examples/dial.rs b/daemon/examples/dial.rs deleted file mode 100644 index 193823d..0000000 --- a/daemon/examples/dial.rs +++ /dev/null @@ -1,28 +0,0 @@ -use anyhow::Result; -use krata::control::{ListRequest, Message, Request, RequestBox}; -use tokio::{ - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - net::TcpStream, -}; -use tokio_stream::{wrappers::LinesStream, StreamExt}; - -#[tokio::main] -async fn main() -> Result<()> { - let mut stream = TcpStream::connect("127.0.0.1:4050").await?; - let (read, mut write) = stream.split(); - let mut read = LinesStream::new(BufReader::new(read).lines()); - - let send = Message::Request(RequestBox { - id: 1, - request: Request::List(ListRequest {}), - }); - let mut line = serde_json::to_string(&send)?; - line.push('\n'); - write.write_all(line.as_bytes()).await?; - println!("sent: {:?}", send); - while let Some(line) = read.try_next().await? { - let message: Message = serde_json::from_str(&line)?; - println!("received: {:?}", message); - } - Ok(()) -} diff --git a/daemon/src/control.rs b/daemon/src/control.rs new file mode 100644 index 0000000..efeaf28 --- /dev/null +++ b/daemon/src/control.rs @@ -0,0 +1,172 @@ +use std::{io, pin::Pin}; + +use async_stream::try_stream; +use futures::Stream; +use krata::control::{ + control_service_server::ControlService, ConsoleDataReply, ConsoleDataRequest, + DestroyGuestReply, DestroyGuestRequest, GuestInfo, LaunchGuestReply, LaunchGuestRequest, + ListGuestsReply, ListGuestsRequest, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + select, +}; +use tokio_stream::StreamExt; +use tonic::{Request, Response, Status, Streaming}; + +use crate::runtime::{launch::GuestLaunchRequest, Runtime}; + +pub struct ApiError { + message: String, +} + +impl From for ApiError { + fn from(value: anyhow::Error) -> Self { + ApiError { + message: value.to_string(), + } + } +} + +impl From for Status { + fn from(value: ApiError) -> Self { + Status::unknown(value.message) + } +} + +#[derive(Clone)] +pub struct RuntimeControlService { + runtime: Runtime, +} + +impl RuntimeControlService { + pub fn new(runtime: Runtime) -> Self { + Self { runtime } + } +} + +enum ConsoleDataSelect { + Read(io::Result), + Write(Option>), +} + +#[tonic::async_trait] +impl ControlService for RuntimeControlService { + type ConsoleDataStream = + Pin> + Send + 'static>>; + + async fn launch_guest( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let guest: GuestInfo = self + .runtime + .launch(GuestLaunchRequest { + image: &request.image, + vcpus: request.vcpus, + mem: request.mem, + env: empty_vec_optional(request.env), + run: empty_vec_optional(request.run), + debug: false, + }) + .await + .map_err(ApiError::from)? + .into(); + Ok(Response::new(LaunchGuestReply { guest: Some(guest) })) + } + + async fn destroy_guest( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + self.runtime + .destroy(&request.guest_id) + .await + .map_err(ApiError::from)?; + Ok(Response::new(DestroyGuestReply {})) + } + + async fn list_guests( + &self, + request: Request, + ) -> Result, Status> { + let _ = request.into_inner(); + let guests = self.runtime.list().await.map_err(ApiError::from)?; + let guests = guests + .into_iter() + .map(GuestInfo::from) + .collect::>(); + Ok(Response::new(ListGuestsReply { guests })) + } + + async fn console_data( + &self, + request: Request>, + ) -> Result, Status> { + let mut input = request.into_inner(); + let Some(request) = input.next().await else { + return Err(ApiError { + message: "expected to have at least one request".to_string(), + } + .into()); + }; + let request = request?; + let mut console = self + .runtime + .console(&request.guest) + .await + .map_err(ApiError::from)?; + + let output = try_stream! { + let mut buffer: Vec = vec![0u8; 256]; + loop { + let what = select! { + x = console.read_handle.read(&mut buffer) => ConsoleDataSelect::Read(x), + x = input.next() => ConsoleDataSelect::Write(x), + }; + + match what { + ConsoleDataSelect::Read(result) => { + let size = result?; + let data = buffer[0..size].to_vec(); + yield ConsoleDataReply { data, }; + }, + + ConsoleDataSelect::Write(Some(request)) => { + let request = request?; + if !request.data.is_empty() { + console.write_handle.write_all(&request.data).await?; + } + }, + + ConsoleDataSelect::Write(None) => { + break; + } + } + } + }; + + Ok(Response::new(Box::pin(output) as Self::ConsoleDataStream)) + } +} + +impl From for GuestInfo { + fn from(value: crate::runtime::GuestInfo) -> Self { + GuestInfo { + id: value.uuid.to_string(), + image: value.image, + ipv4: value.ipv4.map(|x| x.ip().to_string()).unwrap_or_default(), + ipv6: value.ipv6.map(|x| x.ip().to_string()).unwrap_or_default(), + } + } +} + +fn empty_vec_optional(value: Vec) -> Option> { + if value.is_empty() { + None + } else { + Some(value) + } +} diff --git a/daemon/src/handlers/console.rs b/daemon/src/handlers/console.rs deleted file mode 100644 index 404d7f2..0000000 --- a/daemon/src/handlers/console.rs +++ /dev/null @@ -1,91 +0,0 @@ -use anyhow::{anyhow, Result}; -use krata::control::{ConsoleStreamResponse, ConsoleStreamUpdate, Request, Response, StreamUpdate}; -use log::warn; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - select, -}; - -use crate::{ - listen::DaemonRequestHandler, - runtime::{console::XenConsole, Runtime}, -}; -use krata::stream::{ConnectionStreams, StreamContext}; -pub struct ConsoleStreamRequestHandler {} - -impl Default for ConsoleStreamRequestHandler { - fn default() -> Self { - Self::new() - } -} - -impl ConsoleStreamRequestHandler { - pub fn new() -> Self { - Self {} - } - - async fn link_console_stream(mut stream: StreamContext, mut console: XenConsole) -> Result<()> { - loop { - let mut buffer = vec![0u8; 256]; - select! { - x = console.read_handle.read(&mut buffer) => match x { - Ok(size) => { - let data = buffer[0..size].to_vec(); - let update = StreamUpdate::ConsoleStream(ConsoleStreamUpdate { - data, - }); - stream.send(update).await?; - }, - - Err(error) => { - return Err(error.into()); - } - }, - - x = stream.receiver.recv() => match x { - Some(StreamUpdate::ConsoleStream(update)) => { - console.write_handle.write_all(&update.data).await?; - } - - None => { - break; - } - } - }; - } - Ok(()) - } -} - -#[async_trait::async_trait] -impl DaemonRequestHandler for ConsoleStreamRequestHandler { - fn accepts(&self, request: &Request) -> bool { - matches!(request, Request::ConsoleStream(_)) - } - - async fn handle( - &self, - streams: ConnectionStreams, - runtime: Runtime, - request: Request, - ) -> Result { - let console_stream = match request { - Request::ConsoleStream(stream) => stream, - _ => return Err(anyhow!("unknown request")), - }; - let console = runtime.console(&console_stream.guest).await?; - let stream = streams.open().await?; - let id = stream.id; - tokio::task::spawn(async move { - if let Err(error) = - ConsoleStreamRequestHandler::link_console_stream(stream, console).await - { - warn!("failed to process console stream: {}", error); - } - }); - - Ok(Response::ConsoleStream(ConsoleStreamResponse { - stream: id, - })) - } -} diff --git a/daemon/src/handlers/destroy.rs b/daemon/src/handlers/destroy.rs deleted file mode 100644 index 7af1e13..0000000 --- a/daemon/src/handlers/destroy.rs +++ /dev/null @@ -1,44 +0,0 @@ -use anyhow::{anyhow, Result}; -use krata::{ - control::{DestroyResponse, Request, Response}, - stream::ConnectionStreams, -}; - -use crate::{listen::DaemonRequestHandler, runtime::Runtime}; - -pub struct DestroyRequestHandler {} - -impl Default for DestroyRequestHandler { - fn default() -> Self { - Self::new() - } -} - -impl DestroyRequestHandler { - pub fn new() -> Self { - Self {} - } -} - -#[async_trait::async_trait] -impl DaemonRequestHandler for DestroyRequestHandler { - fn accepts(&self, request: &Request) -> bool { - matches!(request, Request::Destroy(_)) - } - - async fn handle( - &self, - _: ConnectionStreams, - runtime: Runtime, - request: Request, - ) -> Result { - let destroy = match request { - Request::Destroy(destroy) => destroy, - _ => return Err(anyhow!("unknown request")), - }; - let guest = runtime.destroy(&destroy.guest).await?; - Ok(Response::Destroy(DestroyResponse { - guest: guest.to_string(), - })) - } -} diff --git a/daemon/src/handlers/launch.rs b/daemon/src/handlers/launch.rs deleted file mode 100644 index 2fa575d..0000000 --- a/daemon/src/handlers/launch.rs +++ /dev/null @@ -1,55 +0,0 @@ -use anyhow::{anyhow, Result}; -use krata::{ - control::{GuestInfo, LaunchResponse, Request, Response}, - stream::ConnectionStreams, -}; - -use crate::{ - listen::DaemonRequestHandler, - runtime::{launch::GuestLaunchRequest, Runtime}, -}; - -pub struct LaunchRequestHandler {} - -impl Default for LaunchRequestHandler { - fn default() -> Self { - Self::new() - } -} - -impl LaunchRequestHandler { - pub fn new() -> Self { - Self {} - } -} - -#[async_trait::async_trait] -impl DaemonRequestHandler for LaunchRequestHandler { - fn accepts(&self, request: &Request) -> bool { - matches!(request, Request::Launch(_)) - } - - async fn handle( - &self, - _: ConnectionStreams, - runtime: Runtime, - request: Request, - ) -> Result { - let launch = match request { - Request::Launch(launch) => launch, - _ => return Err(anyhow!("unknown request")), - }; - let guest: GuestInfo = runtime - .launch(GuestLaunchRequest { - image: &launch.image, - vcpus: launch.vcpus, - mem: launch.mem, - env: launch.env, - run: launch.run, - debug: false, - }) - .await? - .into(); - Ok(Response::Launch(LaunchResponse { guest })) - } -} diff --git a/daemon/src/handlers/list.rs b/daemon/src/handlers/list.rs deleted file mode 100644 index 2e48b5d..0000000 --- a/daemon/src/handlers/list.rs +++ /dev/null @@ -1,37 +0,0 @@ -use anyhow::Result; -use krata::{ - control::{GuestInfo, ListResponse, Request, Response}, - stream::ConnectionStreams, -}; - -use crate::{listen::DaemonRequestHandler, runtime::Runtime}; - -pub struct ListRequestHandler {} - -impl Default for ListRequestHandler { - fn default() -> Self { - Self::new() - } -} - -impl ListRequestHandler { - pub fn new() -> Self { - Self {} - } -} - -#[async_trait::async_trait] -impl DaemonRequestHandler for ListRequestHandler { - fn accepts(&self, request: &Request) -> bool { - matches!(request, Request::List(_)) - } - - async fn handle(&self, _: ConnectionStreams, runtime: Runtime, _: Request) -> Result { - let guests = runtime.list().await?; - let guests = guests - .into_iter() - .map(GuestInfo::from) - .collect::>(); - Ok(Response::List(ListResponse { guests })) - } -} diff --git a/daemon/src/handlers/mod.rs b/daemon/src/handlers/mod.rs deleted file mode 100644 index 46406fd..0000000 --- a/daemon/src/handlers/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -pub mod console; -pub mod destroy; -pub mod launch; -pub mod list; - -impl From for krata::control::GuestInfo { - fn from(value: crate::runtime::GuestInfo) -> Self { - krata::control::GuestInfo { - id: value.uuid.to_string(), - image: value.image.clone(), - ipv4: value.ipv4.map(|x| x.ip().to_string()), - ipv6: value.ipv6.map(|x| x.ip().to_string()), - } - } -} diff --git a/daemon/src/lib.rs b/daemon/src/lib.rs index f3a0e4c..a4c0a16 100644 --- a/daemon/src/lib.rs +++ b/daemon/src/lib.rs @@ -1,37 +1,74 @@ -use anyhow::Result; -use handlers::{ - console::ConsoleStreamRequestHandler, destroy::DestroyRequestHandler, - launch::LaunchRequestHandler, list::ListRequestHandler, -}; -use listen::{DaemonListener, DaemonRequestHandlers}; -use runtime::Runtime; -use tokio_listener::Listener; +use std::{net::SocketAddr, path::PathBuf, str::FromStr}; -pub mod handlers; -pub mod listen; +use anyhow::Result; +use control::RuntimeControlService; +use krata::{control::control_service_server::ControlServiceServer, dial::ControlDialAddress}; +use log::info; +use runtime::Runtime; +use tokio::net::UnixListener; +use tokio_stream::wrappers::UnixListenerStream; +use tonic::transport::{Identity, Server, ServerTlsConfig}; + +pub mod control; pub mod runtime; pub struct Daemon { + store: String, runtime: Runtime, } impl Daemon { - pub async fn new(runtime: Runtime) -> Result { - Ok(Self { runtime }) + pub async fn new(store: String, runtime: Runtime) -> Result { + Ok(Self { store, runtime }) } - pub async fn listen(&mut self, listener: Listener) -> Result<()> { - let handlers = DaemonRequestHandlers::new( - self.runtime.clone(), - vec![ - Box::new(LaunchRequestHandler::new()), - Box::new(DestroyRequestHandler::new()), - Box::new(ConsoleStreamRequestHandler::new()), - Box::new(ListRequestHandler::new()), - ], - ); - let mut listener = DaemonListener::new(listener, handlers); - listener.handle().await?; + pub async fn listen(&mut self, addr: ControlDialAddress) -> Result<()> { + let control_service = RuntimeControlService::new(self.runtime.clone()); + + let mut server = Server::builder(); + + if let ControlDialAddress::Tls { + host: _, + port: _, + insecure, + } = &addr + { + let mut tls_config = ServerTlsConfig::new(); + if !insecure { + let certificate_path = format!("{}/tls/daemon.pem", self.store); + let key_path = format!("{}/tls/daemon.key", self.store); + tls_config = tls_config.identity(Identity::from_pem(certificate_path, key_path)); + } + server = server.tls_config(tls_config)?; + } + + let server = server.add_service(ControlServiceServer::new(control_service)); + info!("listening on address {}", addr); + match addr { + ControlDialAddress::UnixSocket { path } => { + let path = PathBuf::from(path); + if path.exists() { + tokio::fs::remove_file(&path).await?; + } + let listener = UnixListener::bind(path)?; + let stream = UnixListenerStream::new(listener); + server.serve_with_incoming(stream).await?; + } + + ControlDialAddress::Tcp { host, port } => { + let address = format!("{}:{}", host, port); + server.serve(SocketAddr::from_str(&address)?).await?; + } + + ControlDialAddress::Tls { + host, + port, + insecure: _, + } => { + let address = format!("{}:{}", host, port); + server.serve(SocketAddr::from_str(&address)?).await?; + } + } Ok(()) } } diff --git a/daemon/src/listen.rs b/daemon/src/listen.rs deleted file mode 100644 index 9f0f0b5..0000000 --- a/daemon/src/listen.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use anyhow::{anyhow, Result}; -use krata::control::{ErrorResponse, Message, Request, RequestBox, Response, ResponseBox}; -use log::trace; -use log::warn; -use tokio::sync::Mutex; -use tokio::{ - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - select, - sync::mpsc::{channel, Receiver, Sender}, -}; -use tokio_listener::{Connection, Listener, SomeSocketAddrClonable}; -use tokio_stream::{wrappers::LinesStream, StreamExt}; - -use crate::runtime::Runtime; -use krata::stream::ConnectionStreams; - -const QUEUE_MAX_LEN: usize = 100; - -#[async_trait::async_trait] -pub trait DaemonRequestHandler: Send + Sync { - fn accepts(&self, request: &Request) -> bool; - async fn handle( - &self, - streams: ConnectionStreams, - runtime: Runtime, - request: Request, - ) -> Result; -} - -#[derive(Clone)] -pub struct DaemonRequestHandlers { - runtime: Runtime, - handlers: Arc>>, -} - -impl DaemonRequestHandlers { - pub fn new(runtime: Runtime, handlers: Vec>) -> Self { - DaemonRequestHandlers { - runtime, - handlers: Arc::new(handlers), - } - } - - async fn dispatch(&self, streams: ConnectionStreams, request: Request) -> Result { - for handler in self.handlers.iter() { - if handler.accepts(&request) { - return handler.handle(streams, self.runtime.clone(), request).await; - } - } - Err(anyhow!("daemon cannot handle that request")) - } -} - -pub struct DaemonListener { - listener: Listener, - handlers: DaemonRequestHandlers, - connections: Arc>>, - next: Arc>, -} - -impl DaemonListener { - pub fn new(listener: Listener, handlers: DaemonRequestHandlers) -> DaemonListener { - DaemonListener { - listener, - handlers, - connections: Arc::new(Mutex::new(HashMap::new())), - next: Arc::new(Mutex::new(0)), - } - } - - pub async fn handle(&mut self) -> Result<()> { - loop { - let (connection, addr) = self.listener.accept().await?; - let connection = - DaemonConnection::new(connection, addr.clonable(), self.handlers.clone()).await?; - let id = { - let mut next = self.next.lock().await; - let id = *next; - *next = id + 1; - id - }; - trace!("new connection from {}", connection.addr); - let tx_channel = connection.tx_sender.clone(); - let addr = connection.addr.clone(); - self.connections.lock().await.insert(id, connection); - let connections_for_close = self.connections.clone(); - tokio::task::spawn(async move { - tx_channel.closed().await; - trace!("connection from {} closed", addr); - connections_for_close.lock().await.remove(&id); - }); - } - } -} - -#[derive(Clone)] -pub struct DaemonConnection { - tx_sender: Sender, - addr: SomeSocketAddrClonable, - handlers: DaemonRequestHandlers, - streams: ConnectionStreams, -} - -impl DaemonConnection { - pub async fn new( - connection: Connection, - addr: SomeSocketAddrClonable, - handlers: DaemonRequestHandlers, - ) -> Result { - let (tx_sender, tx_receiver) = channel::(QUEUE_MAX_LEN); - let streams_tx_sender = tx_sender.clone(); - let instance = DaemonConnection { - tx_sender, - addr, - handlers, - streams: ConnectionStreams::new(streams_tx_sender), - }; - - { - let mut instance = instance.clone(); - tokio::task::spawn(async move { - if let Err(error) = instance.process(tx_receiver, connection).await { - warn!( - "failed to process daemon connection for {}: {}", - instance.addr, error - ); - } - }); - } - - Ok(instance) - } - - async fn process( - &mut self, - mut tx_receiver: Receiver, - connection: Connection, - ) -> Result<()> { - let (read, mut write) = tokio::io::split(connection); - let mut read = LinesStream::new(BufReader::new(read).lines()); - - loop { - select! { - x = read.next() => match x { - Some(Ok(line)) => { - let message: Message = serde_json::from_str(&line)?; - trace!("received message '{}' from {}", serde_json::to_string(&message)?, self.addr); - let mut context = self.clone(); - tokio::task::spawn(async move { - if let Err(error) = context.handle_message(&message).await { - let line = serde_json::to_string(&message).unwrap_or("".to_string()); - warn!("failed to handle message '{}' from {}: {}", line, context.addr, error); - } - }); - }, - - Some(Err(error)) => { - return Err(error.into()); - }, - - None => { - break; - } - }, - - x = tx_receiver.recv() => match x { - Some(message) => { - if let Message::StreamUpdated(ref update) = message { - self.streams.outgoing(update).await?; - } - let mut line = serde_json::to_string(&message)?; - trace!("sending message '{}' to {}", line, self.addr); - line.push('\n'); - write.write_all(line.as_bytes()).await?; - }, - None => { - break; - } - } - }; - } - Ok(()) - } - - async fn handle_message(&mut self, message: &Message) -> Result<()> { - match message { - Message::Request(req) => { - self.handle_request(req.clone()).await?; - } - - Message::Response(_) => { - return Err(anyhow!( - "received a response message from client {}, but this is the daemon", - self.addr - )); - } - - Message::StreamUpdated(updated) => { - self.streams.incoming(updated.clone()).await?; - } - } - Ok(()) - } - - async fn handle_request(&mut self, req: RequestBox) -> Result<()> { - let id = req.id; - let response = self - .handlers - .dispatch(self.streams.clone(), req.request) - .await - .map_err(|error| { - Response::Error(ErrorResponse { - message: error.to_string(), - }) - }); - let response = if let Err(response) = response { - response - } else { - response.unwrap() - }; - let resp = ResponseBox { id, response }; - self.tx_sender.send(Message::Response(resp)).await?; - Ok(()) - } -} diff --git a/resources/systemd/kratad.service b/resources/systemd/kratad.service index dac8460..84053e6 100644 --- a/resources/systemd/kratad.service +++ b/resources/systemd/kratad.service @@ -5,7 +5,7 @@ Description=Krata Controller Daemon Restart=on-failure Type=simple WorkingDirectory=/var/lib/krata -ExecStart=/usr/local/bin/kratad -l /var/lib/krata/daemon.socket --unix-listen-unlink +ExecStart=/usr/local/bin/kratad -l unix:///var/lib/krata/daemon.socket Environment=RUST_LOG=info User=root diff --git a/shared/Cargo.toml b/shared/Cargo.toml index 45ad9e1..0e31bd3 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -10,6 +10,12 @@ serde = { workspace = true } libc = { workspace = true } log = { workspace = true } tokio = { workspace = true } +url = { workspace = true } +tonic = { workspace = true } +prost = { workspace = true } + +[build-dependencies] +tonic-build = { workspace = true } [dependencies.nix] workspace = true diff --git a/shared/build.rs b/shared/build.rs new file mode 100644 index 0000000..6a18bf2 --- /dev/null +++ b/shared/build.rs @@ -0,0 +1,5 @@ +fn main() { + tonic_build::configure() + .compile(&["proto/krata/control.proto"], &["proto"]) + .unwrap(); +} diff --git a/shared/proto/krata/control.proto b/shared/proto/krata/control.proto new file mode 100644 index 0000000..c3801bb --- /dev/null +++ b/shared/proto/krata/control.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "dev.krata.proto.control"; +option java_outer_classname = "ControlProto"; + +package krata.control; + +message GuestInfo { + string id = 1; + string image = 2; + string ipv4 = 3; + string ipv6 = 4; +} + +message LaunchGuestRequest { + string image = 1; + uint32 vcpus = 2; + uint64 mem = 3; + repeated string env = 4; + repeated string run = 5; +} + +message LaunchGuestReply { + GuestInfo guest = 1; +} + +message ListGuestsRequest {} + +message ListGuestsReply { + repeated GuestInfo guests = 1; +} + +message DestroyGuestRequest { + string guest_id = 1; +} + +message DestroyGuestReply {} + +message ConsoleDataRequest { + string guest = 1; + bytes data = 2; +} + +message ConsoleDataReply { + bytes data = 1; +} + +service ControlService { + rpc LaunchGuest(LaunchGuestRequest) returns (LaunchGuestReply); + rpc DestroyGuest(DestroyGuestRequest) returns (DestroyGuestReply); + + rpc ListGuests(ListGuestsRequest) returns (ListGuestsReply); + + rpc ConsoleData(stream ConsoleDataRequest) returns (stream ConsoleDataReply); +} diff --git a/shared/src/control.rs b/shared/src/control.rs index 2f19243..5576d5c 100644 --- a/shared/src/control.rs +++ b/shared/src/control.rs @@ -1,115 +1 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GuestInfo { - pub id: String, - pub image: String, - pub ipv4: Option, - pub ipv6: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LaunchRequest { - pub image: String, - pub vcpus: u32, - pub mem: u64, - pub env: Option>, - pub run: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LaunchResponse { - pub guest: GuestInfo, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ListRequest {} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ListResponse { - pub guests: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DestroyRequest { - pub guest: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DestroyResponse { - pub guest: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConsoleStreamRequest { - pub guest: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConsoleStreamResponse { - pub stream: u64, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ConsoleStreamUpdate { - pub data: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ErrorResponse { - pub message: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Request { - Launch(LaunchRequest), - Destroy(DestroyRequest), - List(ListRequest), - ConsoleStream(ConsoleStreamRequest), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Response { - Error(ErrorResponse), - Launch(LaunchResponse), - Destroy(DestroyResponse), - List(ListResponse), - ConsoleStream(ConsoleStreamResponse), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RequestBox { - pub id: u64, - pub request: Request, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ResponseBox { - pub id: u64, - pub response: Response, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum StreamStatus { - Open, - Closed, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum StreamUpdate { - ConsoleStream(ConsoleStreamUpdate), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamUpdated { - pub id: u64, - pub update: Option, - pub status: StreamStatus, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Message { - Request(RequestBox), - Response(ResponseBox), - StreamUpdated(StreamUpdated), -} +tonic::include_proto!("krata.control"); diff --git a/shared/src/dial.rs b/shared/src/dial.rs new file mode 100644 index 0000000..4f44c39 --- /dev/null +++ b/shared/src/dial.rs @@ -0,0 +1,100 @@ +use std::{fmt::Display, str::FromStr}; + +use anyhow::anyhow; +use url::{Host, Url}; + +pub const KRATA_DEFAULT_TCP_PORT: u16 = 4350; +pub const KRATA_DEFAULT_TLS_PORT: u16 = 4353; + +#[derive(Clone)] +pub enum ControlDialAddress { + UnixSocket { + path: String, + }, + Tcp { + host: String, + port: u16, + }, + Tls { + host: String, + port: u16, + insecure: bool, + }, +} + +impl FromStr for ControlDialAddress { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let url: Url = s.parse()?; + + let host = url.host().unwrap_or(Host::Domain("localhost")).to_string(); + + match url.scheme() { + "unix" => Ok(ControlDialAddress::UnixSocket { + path: url.path().to_string(), + }), + + "tcp" => { + let port = url.port().unwrap_or(KRATA_DEFAULT_TCP_PORT); + Ok(ControlDialAddress::Tcp { host, port }) + } + + "tls" | "tls-insecure" => { + let insecure = url.scheme() == "tls-insecure"; + let port = url.port().unwrap_or(KRATA_DEFAULT_TLS_PORT); + Ok(ControlDialAddress::Tls { + host, + port, + insecure, + }) + } + + _ => Err(anyhow!("unknown control address scheme: {}", url.scheme())), + } + } +} + +impl From for Url { + fn from(val: ControlDialAddress) -> Self { + match val { + ControlDialAddress::UnixSocket { path } => { + let mut url = Url::parse("unix:///").unwrap(); + url.set_path(&path); + url + } + + ControlDialAddress::Tcp { host, port } => { + let mut url = Url::parse("tcp://").unwrap(); + url.set_host(Some(&host)).unwrap(); + if port != KRATA_DEFAULT_TCP_PORT { + url.set_port(Some(port)).unwrap(); + } + url + } + + ControlDialAddress::Tls { + host, + port, + insecure, + } => { + let mut url = Url::parse("tls://").unwrap(); + if insecure { + url.set_scheme("tls-insecure").unwrap(); + } + url.set_host(Some(&host)).unwrap(); + if port != KRATA_DEFAULT_TLS_PORT { + url.set_port(Some(port)).unwrap(); + } + url + } + } + } +} + +impl Display for ControlDialAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let url: Url = self.clone().into(); + write!(f, "{}", url) + } +} diff --git a/shared/src/lib.rs b/shared/src/lib.rs index 43c5b0a..5aceb24 100644 --- a/shared/src/lib.rs +++ b/shared/src/lib.rs @@ -1,7 +1,4 @@ pub mod control; +pub mod dial; pub mod ethtool; pub mod launchcfg; -pub mod stream; - -pub const KRATA_DEFAULT_TCP_PORT: u16 = 4350; -pub const KRATA_DEFAULT_TLS_PORT: u16 = 4353; diff --git a/shared/src/stream.rs b/shared/src/stream.rs deleted file mode 100644 index 60388d0..0000000 --- a/shared/src/stream.rs +++ /dev/null @@ -1,152 +0,0 @@ -use crate::control::{Message, StreamStatus, StreamUpdate, StreamUpdated}; -use anyhow::{anyhow, Result}; -use log::warn; -use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{ - mpsc::{channel, Receiver, Sender}, - Mutex, -}; - -pub struct StreamContext { - pub id: u64, - pub receiver: Receiver, - sender: Sender, -} - -impl StreamContext { - pub async fn send(&self, update: StreamUpdate) -> Result<()> { - self.sender - .send(Message::StreamUpdated(StreamUpdated { - id: self.id, - update: Some(update), - status: StreamStatus::Open, - })) - .await?; - Ok(()) - } -} - -impl Drop for StreamContext { - fn drop(&mut self) { - if self.sender.is_closed() { - return; - } - let result = self.sender.try_send(Message::StreamUpdated(StreamUpdated { - id: self.id, - update: None, - status: StreamStatus::Closed, - })); - - if let Err(error) = result { - warn!( - "failed to send close message for stream {}: {}", - self.id, error - ); - } - } -} - -struct StreamStorage { - rx_sender: Sender, - rx_receiver: Option>, -} - -#[derive(Clone)] -pub struct ConnectionStreams { - next: Arc>, - streams: Arc>>, - tx_sender: Sender, -} - -const QUEUE_MAX_LEN: usize = 100; - -impl ConnectionStreams { - pub fn new(tx_sender: Sender) -> Self { - Self { - next: Arc::new(Mutex::new(0)), - streams: Arc::new(Mutex::new(HashMap::new())), - tx_sender, - } - } - - pub async fn open(&self) -> Result { - let id = { - let mut next = self.next.lock().await; - let id = *next; - *next = id + 1; - id - }; - - let (rx_sender, rx_receiver) = channel(QUEUE_MAX_LEN); - let store = StreamStorage { - rx_sender, - rx_receiver: None, - }; - - self.streams.lock().await.insert(id, store); - - let open = Message::StreamUpdated(StreamUpdated { - id, - update: None, - status: StreamStatus::Open, - }); - self.tx_sender.send(open).await?; - - Ok(StreamContext { - id, - sender: self.tx_sender.clone(), - receiver: rx_receiver, - }) - } - - pub async fn incoming(&self, updated: StreamUpdated) -> Result<()> { - let mut streams = self.streams.lock().await; - if updated.update.is_none() && updated.status == StreamStatus::Open { - let (rx_sender, rx_receiver) = channel(QUEUE_MAX_LEN); - let store = StreamStorage { - rx_sender, - rx_receiver: Some(rx_receiver), - }; - streams.insert(updated.id, store); - } - - let Some(storage) = streams.get(&updated.id) else { - return Ok(()); - }; - - if let Some(update) = updated.update { - storage.rx_sender.send(update).await?; - } - - if updated.status == StreamStatus::Closed { - streams.remove(&updated.id); - } - - Ok(()) - } - - pub async fn outgoing(&self, updated: &StreamUpdated) -> Result<()> { - if updated.status == StreamStatus::Closed { - let mut streams = self.streams.lock().await; - streams.remove(&updated.id); - } - Ok(()) - } - - pub async fn acquire(&self, id: u64) -> Result { - let mut streams = self.streams.lock().await; - let Some(storage) = streams.get_mut(&id) else { - return Err(anyhow!("stream {} has not been opened", id)); - }; - - let Some(receiver) = storage.rx_receiver.take() else { - return Err(anyhow!("stream has already been acquired")); - }; - - Ok(StreamContext { - id, - receiver, - sender: self.tx_sender.clone(), - }) - } -}