diff --git a/rust/crates/runtime/src/mcp_lifecycle_hardened.rs b/rust/crates/runtime/src/mcp_lifecycle_hardened.rs index b41ab91..01000f8 100644 --- a/rust/crates/runtime/src/mcp_lifecycle_hardened.rs +++ b/rust/crates/runtime/src/mcp_lifecycle_hardened.rs @@ -124,11 +124,11 @@ pub enum McpPhaseResult { Failure { phase: McpLifecyclePhase, error: McpErrorSurface, - recoverable: bool, }, Timeout { phase: McpLifecyclePhase, waited: Duration, + error: McpErrorSurface, }, } @@ -200,6 +200,15 @@ impl McpLifecycleState { fn record_result(&mut self, result: McpPhaseResult) { self.phase_results.push(result); } + + fn can_resume_after_error(&self) -> bool { + match self.phase_results.last() { + Some(McpPhaseResult::Failure { error, .. } | McpPhaseResult::Timeout { error, .. }) => { + error.recoverable + } + _ => false, + } + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -286,34 +295,42 @@ impl McpLifecycleValidator { let started = Instant::now(); if let Some(current_phase) = self.state.current_phase() { - if !Self::validate_phase_transition(current_phase, phase) { - return self.record_failure( - phase, - McpErrorSurface::new( - phase, - None, - format!("invalid MCP lifecycle transition from {current_phase} to {phase}"), - BTreeMap::from([ - ("from".to_string(), current_phase.to_string()), - ("to".to_string(), phase.to_string()), - ]), - false, - ), - false, - ); - } - } else if phase != McpLifecyclePhase::ConfigLoad { - return self.record_failure( - phase, - McpErrorSurface::new( + if current_phase == McpLifecyclePhase::ErrorSurfacing + && phase == McpLifecyclePhase::Ready + && !self.state.can_resume_after_error() + { + return self.record_failure(McpErrorSurface::new( phase, None, - format!("invalid initial MCP lifecycle phase {phase}"), - BTreeMap::from([("phase".to_string(), phase.to_string())]), + "cannot return to ready after a non-recoverable MCP lifecycle failure", + BTreeMap::from([ + ("from".to_string(), current_phase.to_string()), + ("to".to_string(), phase.to_string()), + ]), false, - ), + )); + } + + if !Self::validate_phase_transition(current_phase, phase) { + return self.record_failure(McpErrorSurface::new( + phase, + None, + format!("invalid MCP lifecycle transition from {current_phase} to {phase}"), + BTreeMap::from([ + ("from".to_string(), current_phase.to_string()), + ("to".to_string(), phase.to_string()), + ]), + false, + )); + } + } else if phase != McpLifecyclePhase::ConfigLoad { + return self.record_failure(McpErrorSurface::new( + phase, + None, + format!("invalid initial MCP lifecycle phase {phase}"), + BTreeMap::from([("phase".to_string(), phase.to_string())]), false, - ); + )); } self.state.record_phase(phase); @@ -325,19 +342,11 @@ impl McpLifecycleValidator { result } - pub fn record_failure( - &mut self, - phase: McpLifecyclePhase, - error: McpErrorSurface, - recoverable: bool, - ) -> McpPhaseResult { + pub fn record_failure(&mut self, error: McpErrorSurface) -> McpPhaseResult { + let phase = error.phase; self.state.record_error(error.clone()); self.state.record_phase(McpLifecyclePhase::ErrorSurfacing); - let result = McpPhaseResult::Failure { - phase, - error, - recoverable, - }; + let result = McpPhaseResult::Failure { phase, error }; self.state.record_result(result.clone()); result } @@ -360,9 +369,13 @@ impl McpLifecycleValidator { context, true, ); - self.state.record_error(error); + self.state.record_error(error.clone()); self.state.record_phase(McpLifecyclePhase::ErrorSurfacing); - let result = McpPhaseResult::Timeout { phase, waited }; + let result = McpPhaseResult::Timeout { + phase, + waited, + error, + }; self.state.record_result(result.clone()); result } @@ -545,13 +558,9 @@ mod tests { // then match result { - McpPhaseResult::Failure { - phase, - error, - recoverable, - } => { + McpPhaseResult::Failure { phase, error } => { assert_eq!(phase, McpLifecyclePhase::Ready); - assert!(!recoverable); + assert!(!error.recoverable); assert_eq!(error.phase, McpLifecyclePhase::Ready); assert_eq!( error.context.get("from").map(String::as_str), @@ -581,27 +590,25 @@ mod tests { // when / then for phase in McpLifecyclePhase::all() { - let result = validator.record_failure( + let result = validator.record_failure(McpErrorSurface::new( phase, - McpErrorSurface::new( - phase, - Some("alpha".to_string()), - format!("failure at {phase}"), - BTreeMap::from([("server".to_string(), "alpha".to_string())]), - phase == McpLifecyclePhase::ResourceDiscovery, - ), + Some("alpha".to_string()), + format!("failure at {phase}"), + BTreeMap::from([("server".to_string(), "alpha".to_string())]), phase == McpLifecyclePhase::ResourceDiscovery, - ); + )); match result { McpPhaseResult::Failure { phase: failed_phase, error, - recoverable, } => { assert_eq!(failed_phase, phase); assert_eq!(error.phase, phase); - assert_eq!(recoverable, phase == McpLifecyclePhase::ResourceDiscovery); + assert_eq!( + error.recoverable, + phase == McpLifecyclePhase::ResourceDiscovery + ); } other => panic!("expected failure result, got {other:?}"), } @@ -628,9 +635,12 @@ mod tests { McpPhaseResult::Timeout { phase, waited: actual, + error, } => { assert_eq!(phase, McpLifecyclePhase::SpawnConnect); assert_eq!(actual, waited); + assert!(error.recoverable); + assert_eq!(error.server_name.as_deref(), Some("alpha")); } other => panic!("expected timeout result, got {other:?}"), } @@ -707,17 +717,13 @@ mod tests { let result = validator.run_phase(phase); assert!(matches!(result, McpPhaseResult::Success { .. })); } - let _ = validator.record_failure( + let _ = validator.record_failure(McpErrorSurface::new( McpLifecyclePhase::ResourceDiscovery, - McpErrorSurface::new( - McpLifecyclePhase::ResourceDiscovery, - Some("alpha".to_string()), - "resource listing failed", - BTreeMap::from([("reason".to_string(), "timeout".to_string())]), - true, - ), + Some("alpha".to_string()), + "resource listing failed", + BTreeMap::from([("reason".to_string(), "timeout".to_string())]), true, - ); + )); // when let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown); @@ -758,4 +764,79 @@ mod tests { let trait_object: &dyn std::error::Error = &error; assert_eq!(trait_object.to_string(), rendered); } + + #[test] + fn given_nonrecoverable_failure_when_returning_to_ready_then_validator_rejects_resume() { + // given + let mut validator = McpLifecycleValidator::new(); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::Ready, + ] { + let result = validator.run_phase(phase); + assert!(matches!(result, McpPhaseResult::Success { .. })); + } + let _ = validator.record_failure(McpErrorSurface::new( + McpLifecyclePhase::Invocation, + Some("alpha".to_string()), + "tool call corrupted the session", + BTreeMap::from([("reason".to_string(), "invalid frame".to_string())]), + false, + )); + + // when + let result = validator.run_phase(McpLifecyclePhase::Ready); + + // then + match result { + McpPhaseResult::Failure { phase, error } => { + assert_eq!(phase, McpLifecyclePhase::Ready); + assert!(!error.recoverable); + assert!(error.message.contains("non-recoverable")); + } + other => panic!("expected failure result, got {other:?}"), + } + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::ErrorSurfacing) + ); + } + + #[test] + fn given_recoverable_failure_when_returning_to_ready_then_validator_allows_resume() { + // given + let mut validator = McpLifecycleValidator::new(); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::Ready, + ] { + let result = validator.run_phase(phase); + assert!(matches!(result, McpPhaseResult::Success { .. })); + } + let _ = validator.record_failure(McpErrorSurface::new( + McpLifecyclePhase::Invocation, + Some("alpha".to_string()), + "tool call failed but can be retried", + BTreeMap::from([("reason".to_string(), "upstream timeout".to_string())]), + true, + )); + + // when + let result = validator.run_phase(McpLifecyclePhase::Ready); + + // then + assert!(matches!(result, McpPhaseResult::Success { .. })); + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::Ready) + ); + } }