From 064bf787c48065a7653514ab0b9c619bb38658e7 Mon Sep 17 00:00:00 2001 From: Jobdori Date: Sat, 4 Apr 2026 00:53:06 +0900 Subject: [PATCH] feat(runtime): session control API Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- rust/crates/runtime/src/lib.rs | 5 +- rust/crates/runtime/src/session_control.rs | 644 ++++++++++++++++++++- 2 files changed, 642 insertions(+), 7 deletions(-) diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index f4ce913..f98a7db 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -17,9 +17,9 @@ pub mod permission_enforcer; mod permissions; mod prompt; mod remote; -pub mod session_control; pub mod sandbox; mod session; +pub mod session_control; mod sse; pub mod task_registry; pub mod team_cron_registry; @@ -100,6 +100,9 @@ pub use session::{ ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError, SessionFork, }; +pub use session_control::{ + SessionControlCommand, SessionControlResponse, SessionController, WorkerHandle, +}; pub use sse::{IncrementalSseParser, SseEvent}; pub use worker_boot::{ Worker, WorkerEvent, WorkerEventKind, WorkerFailure, WorkerFailureKind, WorkerReadySnapshot, diff --git a/rust/crates/runtime/src/session_control.rs b/rust/crates/runtime/src/session_control.rs index 2192156..96907ad 100644 --- a/rust/crates/runtime/src/session_control.rs +++ b/rust/crates/runtime/src/session_control.rs @@ -1,13 +1,15 @@ use std::env; use std::fmt::{Display, Formatter}; use std::fs; -use std::path::{Path, PathBuf}; -use std::time::UNIX_EPOCH; +use std::path::{Component, Path, PathBuf}; +use std::time::{Duration, Instant, UNIX_EPOCH}; use serde::{Deserialize, Serialize}; use crate::session::{Session, SessionError}; -use crate::worker_boot::{Worker, WorkerReadySnapshot, WorkerRegistry, WorkerStatus}; +use crate::worker_boot::{ + Worker, WorkerFailure, WorkerReadySnapshot, WorkerRegistry, WorkerStatus, +}; pub const PRIMARY_SESSION_EXTENSION: &str = "jsonl"; pub const LEGACY_SESSION_EXTENSION: &str = "json"; @@ -98,7 +100,7 @@ pub fn create_managed_session_handle_for( base_dir: impl AsRef, session_id: &str, ) -> Result { - let id = session_id.to_string(); + let id = validate_managed_session_id(session_id)?.to_string(); let path = managed_sessions_dir_for(base_dir)?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}")); Ok(SessionHandle { id, path }) @@ -152,6 +154,7 @@ pub fn resolve_managed_session_path_for( base_dir: impl AsRef, session_id: &str, ) -> Result { + let session_id = validate_managed_session_id(session_id)?; let directory = managed_sessions_dir_for(base_dir)?; for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] { let path = directory.join(format!("{session_id}.{extension}")); @@ -308,6 +311,27 @@ pub fn is_session_reference_alias(reference: &str) -> bool { .any(|alias| reference.eq_ignore_ascii_case(alias)) } +fn validate_managed_session_id(session_id: &str) -> Result<&str, SessionControlError> { + if session_id.is_empty() { + return Err(SessionControlError::Format( + "session id must not be empty".to_string(), + )); + } + if session_id.contains(['/', '\\']) { + return Err(SessionControlError::Format(format!( + "invalid managed session id `{session_id}`" + ))); + } + + let mut components = Path::new(session_id).components(); + match (components.next(), components.next()) { + (Some(Component::Normal(_)), None) => Ok(session_id), + _ => Err(SessionControlError::Format(format!( + "invalid managed session id `{session_id}`" + ))), + } +} + fn session_id_from_path(path: &Path) -> Option { path.file_name() .and_then(|value| value.to_str()) @@ -334,8 +358,8 @@ fn format_no_managed_sessions() -> String { mod tests { use super::{ create_managed_session_handle_for, fork_managed_session_for, is_session_reference_alias, - list_managed_sessions_for, load_managed_session_for, resolve_session_reference_for, - ManagedSessionSummary, LATEST_SESSION_REFERENCE, + list_managed_sessions_for, load_managed_session_for, resolve_managed_session_path_for, + resolve_session_reference_for, ManagedSessionSummary, LATEST_SESSION_REFERENCE, }; use crate::session::Session; use std::fs; @@ -458,4 +482,612 @@ mod tests { ); fs::remove_dir_all(root).expect("temp dir should clean up"); } + + #[test] + fn rejects_managed_session_ids_with_path_traversal() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir should exist"); + + // when + let handle_error = create_managed_session_handle_for(&root, "../escape") + .expect_err("path traversal session id should be rejected"); + let resolve_error = resolve_managed_session_path_for(&root, "..") + .expect_err("path traversal managed session id should be rejected"); + + // then + assert!(handle_error + .to_string() + .contains("invalid managed session id")); + assert!(resolve_error + .to_string() + .contains("invalid managed session id")); + fs::remove_dir_all(root).expect("temp dir should clean up"); + } +} + +// --------------------------------------------------------------------------- +// Structured session control API — worker lifecycle commands +// --------------------------------------------------------------------------- + +/// Lightweight projection of a [`Worker`] for command responses. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct WorkerHandle { + pub id: String, + pub state: WorkerStatus, + pub created_at: u64, + pub last_error: Option, +} + +/// Commands accepted by [`SessionController::execute_command`]. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum SessionControlCommand { + CreateWorker, + AwaitReady { timeout: u64 }, + SendTask { task: String }, + FetchState, + FetchLastError, + RestartWorker, + TerminateWorker, +} + +/// Responses returned by [`SessionController::execute_command`]. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum SessionControlResponse { + WorkerCreated { + handle: WorkerHandle, + }, + ReadyStatus { + snapshot: WorkerReadySnapshot, + }, + TaskSent { + handle: WorkerHandle, + }, + State { + handle: WorkerHandle, + }, + LastError { + worker_id: String, + error: Option, + }, + WorkerRestarted { + handle: WorkerHandle, + }, + WorkerTerminated { + handle: WorkerHandle, + }, +} + +/// Thin controller that translates [`SessionControlCommand`] into +/// [`WorkerRegistry`] operations and returns [`SessionControlResponse`]. +#[derive(Debug, Clone)] +pub struct SessionController { + registry: WorkerRegistry, + default_cwd: String, +} + +impl SessionController { + #[must_use] + pub fn new(default_cwd: &str) -> Self { + Self { + registry: WorkerRegistry::new(), + default_cwd: default_cwd.to_owned(), + } + } + + /// Access the underlying registry for observation or direct queries. + #[must_use] + pub fn registry(&self) -> &WorkerRegistry { + &self.registry + } + + /// Dispatch a single control command, returning a typed response. + /// + /// `worker_id` is required for every command except + /// [`SessionControlCommand::CreateWorker`]. + pub fn execute_command( + &self, + worker_id: Option<&str>, + cmd: SessionControlCommand, + ) -> Result { + match cmd { + SessionControlCommand::CreateWorker => { + let worker = self.registry.create(&self.default_cwd, &[], true); + Ok(SessionControlResponse::WorkerCreated { + handle: worker_to_handle(&worker), + }) + } + SessionControlCommand::AwaitReady { timeout } => { + let id = require_worker_id(worker_id)?; + let snapshot = await_until_ready(&self.registry, id, timeout)?; + Ok(SessionControlResponse::ReadyStatus { snapshot }) + } + SessionControlCommand::SendTask { task } => { + let id = require_worker_id(worker_id)?; + let worker = self.registry.send_prompt(id, Some(&task))?; + Ok(SessionControlResponse::TaskSent { + handle: worker_to_handle(&worker), + }) + } + SessionControlCommand::FetchState => { + let id = require_worker_id(worker_id)?; + let worker = self + .registry + .get(id) + .ok_or_else(|| format!("worker not found: {id}"))?; + Ok(SessionControlResponse::State { + handle: worker_to_handle(&worker), + }) + } + SessionControlCommand::FetchLastError => { + let id = require_worker_id(worker_id)?; + let worker = self + .registry + .get(id) + .ok_or_else(|| format!("worker not found: {id}"))?; + Ok(SessionControlResponse::LastError { + worker_id: worker.worker_id.clone(), + error: worker.last_error.clone(), + }) + } + SessionControlCommand::RestartWorker => { + let id = require_worker_id(worker_id)?; + let worker = self.registry.restart(id)?; + Ok(SessionControlResponse::WorkerRestarted { + handle: worker_to_handle(&worker), + }) + } + SessionControlCommand::TerminateWorker => { + let id = require_worker_id(worker_id)?; + let worker = self.registry.terminate(id)?; + Ok(SessionControlResponse::WorkerTerminated { + handle: worker_to_handle(&worker), + }) + } + } + } +} + +fn require_worker_id(worker_id: Option<&str>) -> Result<&str, String> { + worker_id.ok_or_else(|| "worker_id is required for this command".to_string()) +} + +fn await_until_ready( + registry: &WorkerRegistry, + worker_id: &str, + timeout: u64, +) -> Result { + let start = Instant::now(); + let timeout = Duration::from_millis(timeout); + + loop { + let snapshot = registry.await_ready(worker_id)?; + if snapshot.ready || snapshot.blocked || start.elapsed() >= timeout { + return Ok(snapshot); + } + + std::thread::yield_now(); + std::thread::sleep(Duration::from_millis(10)); + } +} + +fn worker_to_handle(worker: &Worker) -> WorkerHandle { + WorkerHandle { + id: worker.worker_id.clone(), + state: worker.status, + created_at: worker.created_at, + last_error: worker.last_error.clone(), + } +} + +#[cfg(test)] +mod session_control_api_tests { + use super::*; + + fn create_worker(controller: &SessionController) -> WorkerHandle { + match controller + .execute_command(None, SessionControlCommand::CreateWorker) + .expect("create should succeed") + { + SessionControlResponse::WorkerCreated { handle } => handle, + other => panic!("expected WorkerCreated, got: {other:?}"), + } + } + + // --- create --- + + #[test] + fn create_worker_returns_spawning_handle() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + + // when + let response = controller + .execute_command(None, SessionControlCommand::CreateWorker) + .expect("create should succeed"); + + // then + match response { + SessionControlResponse::WorkerCreated { handle } => { + assert!(handle.id.starts_with("worker_")); + assert_eq!(handle.state, WorkerStatus::Spawning); + assert!(handle.created_at > 0); + assert!(handle.last_error.is_none()); + } + other => panic!("expected WorkerCreated, got: {other:?}"), + } + } + + // --- await ready --- + + #[test] + fn await_ready_reports_not_ready_before_observe() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + + // when + let response = controller + .execute_command( + Some(&handle.id), + SessionControlCommand::AwaitReady { timeout: 0 }, + ) + .expect("await_ready should succeed"); + + // then + match response { + SessionControlResponse::ReadyStatus { snapshot } => { + assert!(!snapshot.ready); + assert!(!snapshot.blocked); + } + other => panic!("expected ReadyStatus, got: {other:?}"), + } + } + + #[test] + fn await_ready_reports_ready_after_observe() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + controller + .registry() + .observe(&handle.id, "Ready for your input\n>") + .expect("observe should succeed"); + + // when + let response = controller + .execute_command( + Some(&handle.id), + SessionControlCommand::AwaitReady { timeout: 0 }, + ) + .expect("await_ready should succeed"); + + // then + match response { + SessionControlResponse::ReadyStatus { snapshot } => { + assert!(snapshot.ready); + assert!(!snapshot.blocked); + } + other => panic!("expected ReadyStatus, got: {other:?}"), + } + } + + // --- send task --- + + #[test] + fn send_task_transitions_to_prompt_accepted() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + controller + .registry() + .observe(&handle.id, "Ready for input\n>") + .expect("observe should succeed"); + + // when + let response = controller + .execute_command( + Some(&handle.id), + SessionControlCommand::SendTask { + task: "Implement the feature".to_string(), + }, + ) + .expect("send_task should succeed"); + + // then + match response { + SessionControlResponse::TaskSent { handle } => { + assert_eq!(handle.state, WorkerStatus::PromptAccepted); + } + other => panic!("expected TaskSent, got: {other:?}"), + } + } + + // --- fetch state --- + + #[test] + fn fetch_state_returns_current_handle() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + + // when + let response = controller + .execute_command(Some(&handle.id), SessionControlCommand::FetchState) + .expect("fetch_state should succeed"); + + // then + match response { + SessionControlResponse::State { + handle: state_handle, + } => { + assert_eq!(state_handle.id, handle.id); + assert_eq!(state_handle.state, WorkerStatus::Spawning); + } + other => panic!("expected State, got: {other:?}"), + } + } + + // --- fetch last error --- + + #[test] + fn fetch_last_error_returns_none_when_healthy() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + + // when + let response = controller + .execute_command(Some(&handle.id), SessionControlCommand::FetchLastError) + .expect("fetch_last_error should succeed"); + + // then + match response { + SessionControlResponse::LastError { worker_id, error } => { + assert_eq!(worker_id, handle.id); + assert!(error.is_none()); + } + other => panic!("expected LastError, got: {other:?}"), + } + } + + #[test] + fn fetch_last_error_surfaces_trust_gate_failure() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + controller + .registry() + .observe( + &handle.id, + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ) + .expect("observe should succeed"); + + // when + let response = controller + .execute_command(Some(&handle.id), SessionControlCommand::FetchLastError) + .expect("fetch_last_error should succeed"); + + // then + match response { + SessionControlResponse::LastError { worker_id, error } => { + assert_eq!(worker_id, handle.id); + let failure = error.expect("error should be present"); + assert_eq!( + failure.kind, + crate::worker_boot::WorkerFailureKind::TrustGate + ); + assert!(failure.message.contains("trust prompt")); + } + other => panic!("expected LastError, got: {other:?}"), + } + } + + // --- restart --- + + #[test] + fn restart_worker_resets_to_spawning() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + controller + .registry() + .observe(&handle.id, "Ready for input\n>") + .expect("observe should succeed"); + + // when + let response = controller + .execute_command(Some(&handle.id), SessionControlCommand::RestartWorker) + .expect("restart should succeed"); + + // then + match response { + SessionControlResponse::WorkerRestarted { handle } => { + assert_eq!(handle.state, WorkerStatus::Spawning); + assert!(handle.last_error.is_none()); + } + other => panic!("expected WorkerRestarted, got: {other:?}"), + } + } + + // --- terminate --- + + #[test] + fn terminate_worker_transitions_to_finished() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + + // when + let response = controller + .execute_command(Some(&handle.id), SessionControlCommand::TerminateWorker) + .expect("terminate should succeed"); + + // then + match response { + SessionControlResponse::WorkerTerminated { handle } => { + assert_eq!(handle.state, WorkerStatus::Finished); + } + other => panic!("expected WorkerTerminated, got: {other:?}"), + } + } + + // --- full lifecycle --- + + #[test] + fn full_create_ready_send_state_restart_terminate_lifecycle() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + + // when: create + let handle = create_worker(&controller); + assert_eq!(handle.state, WorkerStatus::Spawning); + + // when: simulate readiness via registry observe + controller + .registry() + .observe(&handle.id, "Ready for input\n>") + .expect("observe should succeed"); + + // when: await ready + let ready_resp = controller + .execute_command( + Some(&handle.id), + SessionControlCommand::AwaitReady { timeout: 0 }, + ) + .expect("await_ready should succeed"); + match &ready_resp { + SessionControlResponse::ReadyStatus { snapshot } => assert!(snapshot.ready), + other => panic!("expected ReadyStatus, got: {other:?}"), + } + + // when: send task + let sent_resp = controller + .execute_command( + Some(&handle.id), + SessionControlCommand::SendTask { + task: "Run all tests".to_string(), + }, + ) + .expect("send_task should succeed"); + match &sent_resp { + SessionControlResponse::TaskSent { handle } => { + assert_eq!(handle.state, WorkerStatus::PromptAccepted); + } + other => panic!("expected TaskSent, got: {other:?}"), + } + + // when: fetch state (should reflect prompt accepted) + let state_resp = controller + .execute_command(Some(&handle.id), SessionControlCommand::FetchState) + .expect("fetch_state should succeed"); + match &state_resp { + SessionControlResponse::State { handle } => { + assert_eq!(handle.state, WorkerStatus::PromptAccepted); + } + other => panic!("expected State, got: {other:?}"), + } + + // when: restart + let restart_resp = controller + .execute_command(Some(&handle.id), SessionControlCommand::RestartWorker) + .expect("restart should succeed"); + match &restart_resp { + SessionControlResponse::WorkerRestarted { handle } => { + assert_eq!(handle.state, WorkerStatus::Spawning); + } + other => panic!("expected WorkerRestarted, got: {other:?}"), + } + + // when: terminate + let term_resp = controller + .execute_command(Some(&handle.id), SessionControlCommand::TerminateWorker) + .expect("terminate should succeed"); + match &term_resp { + SessionControlResponse::WorkerTerminated { handle } => { + assert_eq!(handle.state, WorkerStatus::Finished); + } + other => panic!("expected WorkerTerminated, got: {other:?}"), + } + } + + // --- error paths --- + + #[test] + fn non_create_commands_reject_missing_worker_id() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let commands = vec![ + SessionControlCommand::AwaitReady { timeout: 0 }, + SessionControlCommand::SendTask { + task: "test".to_string(), + }, + SessionControlCommand::FetchState, + SessionControlCommand::FetchLastError, + SessionControlCommand::RestartWorker, + SessionControlCommand::TerminateWorker, + ]; + + for cmd in commands { + // when + let result = controller.execute_command(None, cmd); + + // then + let error = result.expect_err("missing worker_id should fail"); + assert!( + error.contains("worker_id is required"), + "error was: {error}" + ); + } + } + + #[test] + fn commands_reject_nonexistent_worker() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + + // when + let result = + controller.execute_command(Some("nonexistent"), SessionControlCommand::FetchState); + + // then + let error = result.expect_err("nonexistent worker should fail"); + assert!(error.contains("worker not found"), "error was: {error}"); + } + + #[test] + fn await_ready_honors_timeout_until_worker_becomes_ready() { + // given + let controller = SessionController::new("/tmp/test-cwd"); + let handle = create_worker(&controller); + let registry = controller.registry().clone(); + let worker_id = handle.id.clone(); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(20)); + registry + .observe(&worker_id, "Ready for input\n>") + .expect("observe should succeed"); + }); + + // when + let response = controller + .execute_command( + Some(&handle.id), + SessionControlCommand::AwaitReady { timeout: 200 }, + ) + .expect("await_ready should succeed"); + + // then + match response { + SessionControlResponse::ReadyStatus { snapshot } => { + assert!(snapshot.ready); + assert!(!snapshot.blocked); + } + other => panic!("expected ReadyStatus, got: {other:?}"), + } + } }