feat: now added ssh

This commit is contained in:
robcholz
2026-02-06 18:32:09 -05:00
parent 52bcd88d7e
commit cd76a4f4e0
5 changed files with 654 additions and 48 deletions
+123 -42
View File
@@ -15,6 +15,7 @@ use std::{
process::{Command, Stdio},
sync::{
Arc, Condvar, Mutex,
atomic::{AtomicBool, Ordering},
mpsc::{self, Receiver, Sender},
},
thread,
@@ -43,7 +44,7 @@ const LOGIN_EXPECT_TIMEOUT: Duration = Duration::from_secs(120);
const PROVISION_SCRIPT: &str = include_str!("provision.sh");
#[derive(Clone)]
enum LoginAction {
pub(crate) enum LoginAction {
Expect { text: String, timeout: Duration },
Send(String),
Script { path: PathBuf, index: usize },
@@ -51,14 +52,14 @@ enum LoginAction {
use LoginAction::*;
#[derive(Clone)]
struct DirectoryShare {
pub(crate) struct DirectoryShare {
host: PathBuf,
guest: PathBuf,
read_only: bool,
}
impl DirectoryShare {
fn new(
pub(crate) fn new(
host: PathBuf,
mut guest: PathBuf,
read_only: bool,
@@ -76,7 +77,7 @@ impl DirectoryShare {
})
}
fn from_mount_spec(spec: &str) -> Result<Self, Box<dyn std::error::Error>> {
pub(crate) fn from_mount_spec(spec: &str) -> Result<Self, Box<dyn std::error::Error>> {
let parts: Vec<&str> = spec.split(':').collect();
if parts.len() < 2 || parts.len() > 3 {
return Err(format!("Invalid mount spec: {spec}").into());
@@ -115,23 +116,19 @@ impl DirectoryShare {
}
}
pub fn run_cli() -> Result<(), Box<dyn std::error::Error>> {
let args = parse_cli()?;
if args.version() {
print_version();
return Ok(());
}
if args.help() {
print_help();
return Ok(());
}
run_with_args(args, spawn_vm_io)
pub fn run_with_args<F>(args: CliArgs, io_handler: F) -> Result<(), Box<dyn std::error::Error>>
where
F: FnOnce(Arc<OutputMonitor>, OwnedFd, OwnedFd) -> IoContext,
{
run_with_args_and_extras(args, io_handler, Vec::new(), Vec::new())
}
pub fn run_with_args<F>(args: CliArgs, io_handler: F) -> Result<(), Box<dyn std::error::Error>>
pub(crate) fn run_with_args_and_extras<F>(
args: CliArgs,
io_handler: F,
extra_login_actions: Vec<LoginAction>,
extra_directory_shares: Vec<DirectoryShare>,
) -> Result<(), Box<dyn std::error::Error>>
where
F: FnOnce(Arc<OutputMonitor>, OwnedFd, OwnedFd) -> IoContext,
{
@@ -233,6 +230,8 @@ where
}
}
directory_shares.extend(extra_directory_shares);
for spec in &args.mounts {
directory_shares.push(DirectoryShare::from_mount_spec(spec)?);
}
@@ -241,6 +240,8 @@ where
login_actions.push(motd_action);
}
login_actions.extend(extra_login_actions);
// Any user-provided login actions must come after our system ones
login_actions.extend(args.login_actions);
@@ -461,18 +462,6 @@ fn motd_login_action(directory_shares: &[DirectoryShare]) -> Option<LoginAction>
}
let mut output = String::new();
output.push_str(
"
░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░▒▓███████▓▒░░▒▓████████▓▒░
░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░
░▒▓█▓▒▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░
░▒▓█▓▒▒▓█▓▒░░▒▓█▓▒░▒▓███████▓▒░░▒▓██████▓▒░
░▒▓█▓▓█▓▒░ ░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░
░▒▓█▓▓█▓▒░ ░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░
░▒▓██▓▒░ ░▒▓█▓▒░▒▓███████▓▒░░▒▓████████▓▒░
",
);
output.push_str(&format!(
"{host_header:<host_width$} {guest_header:<guest_width$} {mode_header}\n",
host_width = host_width
@@ -548,6 +537,49 @@ impl OutputMonitor {
}
}
#[derive(Debug)]
pub struct IoControl {
forward_input: AtomicBool,
forward_output: AtomicBool,
restore_terminal: AtomicBool,
}
impl IoControl {
pub fn new() -> Arc<Self> {
Arc::new(Self {
forward_input: AtomicBool::new(true),
forward_output: AtomicBool::new(true),
restore_terminal: AtomicBool::new(false),
})
}
pub fn set_forward_input(&self, enabled: bool) {
self.forward_input.store(enabled, Ordering::SeqCst);
}
pub fn set_forward_output(&self, enabled: bool) {
self.forward_output.store(enabled, Ordering::SeqCst);
}
pub fn request_terminal_restore(&self) {
self.restore_terminal.store(true, Ordering::SeqCst);
}
fn forward_input(&self) -> bool {
self.forward_input.load(Ordering::SeqCst)
}
fn forward_output(&self) -> bool {
self.forward_output.load(Ordering::SeqCst)
}
fn take_restore_terminal(&self) -> bool {
self.restore_terminal
.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}
}
fn ensure_base_image(
base_raw: &Path,
base_compressed: &Path,
@@ -667,14 +699,17 @@ pub fn create_pipe() -> (OwnedFd, OwnedFd) {
(read_stream.into(), write_stream.into())
}
pub fn spawn_vm_io_with_line_handler<F>(
pub fn spawn_vm_io_with_hooks<F, G>(
output_monitor: Arc<OutputMonitor>,
vm_output_fd: OwnedFd,
vm_input_fd: OwnedFd,
io_control: Arc<IoControl>,
mut on_line: F,
mut on_output: G,
) -> IoContext
where
F: FnMut(&str) -> bool + ::std::marker::Send + 'static,
G: FnMut(&[u8]) + ::std::marker::Send + 'static,
{
let (input_tx, input_rx): (Sender<VmInput>, Receiver<VmInput>) = mpsc::channel();
@@ -721,17 +756,36 @@ where
}
}
fn poll_wakeup_only(wakeup_fd: RawFd, timeout_ms: i32) -> bool {
let mut fds = [libc::pollfd {
fd: wakeup_fd,
events: libc::POLLIN,
revents: 0,
}];
let ret = unsafe { libc::poll(fds.as_mut_ptr(), 1, timeout_ms) };
ret > 0 && fds[0].revents & libc::POLLIN != 0
}
// Copies from stdin to the VM; also polls wakeup_read to exit the thread when it's time to shutdown.
let stdin_thread = thread::spawn({
let input_tx = input_tx.clone();
let raw_guard = raw_guard.clone();
let wakeup_read = wakeup_read.try_clone().unwrap();
let io_control = io_control.clone();
move || {
let mut buf = [0u8; 1024];
let mut pending_command: Vec<u8> = Vec::new();
let mut command_mode = false;
loop {
if !io_control.forward_input() {
if poll_wakeup_only(wakeup_read.as_raw_fd(), 100) {
break;
}
continue;
}
match poll_with_wakeup(libc::STDIN_FILENO, wakeup_read.as_raw_fd(), &mut buf) {
PollResult::Shutdown | PollResult::Error => break,
PollResult::Spurious => continue,
@@ -776,28 +830,36 @@ where
let stdout_thread = thread::spawn({
let raw_guard = raw_guard.clone();
let wakeup_read = wakeup_read.try_clone().unwrap();
let io_control = io_control.clone();
move || {
let mut stdout = std::io::stdout().lock();
let mut buf = [0u8; 1024];
loop {
if io_control.take_restore_terminal() {
let mut guard = raw_guard.lock().unwrap();
*guard = None;
}
match poll_with_wakeup(vm_output_fd.as_raw_fd(), wakeup_read.as_raw_fd(), &mut buf)
{
PollResult::Shutdown | PollResult::Error => break,
PollResult::Spurious => continue,
PollResult::Ready(bytes) => {
// enable raw mode, if we haven't already
if raw_guard.lock().unwrap().is_none()
&& let Ok(guard) = enable_raw_mode(libc::STDIN_FILENO)
{
*raw_guard.lock().unwrap() = Some(guard);
}
if io_control.forward_output() {
// enable raw mode, if we haven't already
if raw_guard.lock().unwrap().is_none()
&& let Ok(guard) = enable_raw_mode(libc::STDIN_FILENO)
{
*raw_guard.lock().unwrap() = Some(guard);
}
if stdout.write_all(bytes).is_err() {
break;
let mut stdout = std::io::stdout().lock();
if stdout.write_all(bytes).is_err() {
break;
}
let _ = stdout.flush();
}
let _ = stdout.flush();
output_monitor.push(bytes);
on_output(bytes);
}
}
}
@@ -829,6 +891,25 @@ where
}
}
pub fn spawn_vm_io_with_line_handler<F>(
output_monitor: Arc<OutputMonitor>,
vm_output_fd: OwnedFd,
vm_input_fd: OwnedFd,
on_line: F,
) -> IoContext
where
F: FnMut(&str) -> bool + ::std::marker::Send + 'static,
{
spawn_vm_io_with_hooks(
output_monitor,
vm_output_fd,
vm_input_fd,
IoControl::new(),
on_line,
|_| {},
)
}
pub fn spawn_vm_io(
output_monitor: Arc<OutputMonitor>,
vm_output_fd: OwnedFd,