From f932fad2edae6b1e539646740062ce0dfbe0a6f0 Mon Sep 17 00:00:00 2001 From: ezl-keygraph Date: Fri, 13 Feb 2026 20:26:16 +0530 Subject: [PATCH] feat: add workflow resume from workspace via --workspace flag When a workflow is interrupted (VM crash, Ctrl+C, Docker restart), it can now be resumed by passing the workspace name. The system reads session.json to determine which agents completed, validates deliverables exist on disk, restores the git checkpoint, and skips already-completed agents. - Add --workspace CLI flag and auto-terminate conflicting workflows - Add loadResumeState, restoreGitCheckpoint, recordResumeAttempt activities - Add skip logic for all 5 pipeline phases including parallel execution - Separate sessionId (persistent directory) from workflowId (execution ID) - Track resume attempts in session.json for audit trail - Derive AgentName type from ALL_AGENTS array to eliminate duplication - Add getDeliverablePath mapping for deliverable validation --- src/audit/audit-session.ts | 30 ++++- src/audit/metrics-tracker.ts | 71 +++++++++- src/temporal/activities.ts | 238 +++++++++++++++++++++++++++++++-- src/temporal/client.ts | 126 +++++++++++++++++- src/temporal/shared.ts | 10 ++ src/temporal/workflows.ts | 247 ++++++++++++++++++++++++++--------- src/types/agents.ts | 66 ++++++++-- 7 files changed, 691 insertions(+), 97 deletions(-) diff --git a/src/audit/audit-session.ts b/src/audit/audit-session.ts index bdfb3c5..6e09133 100644 --- a/src/audit/audit-session.ts +++ b/src/audit/audit-session.ts @@ -64,8 +64,10 @@ export class AuditSession { /** * Initialize audit session (creates directories, session.json) * Idempotent and race-safe + * + * @param workflowId - Optional workflow ID for tracking original or resume workflows */ - async initialize(): Promise { + async initialize(workflowId?: string): Promise { if (this.initialized) { return; // Already initialized } @@ -74,7 +76,7 @@ export class AuditSession { await initializeAuditStructure(this.sessionMetadata); // Initialize metrics tracker (loads or creates session.json) - await this.metricsTracker.initialize(); + await this.metricsTracker.initialize(workflowId); // Initialize workflow logger await this.workflowLogger.initialize(); @@ -252,4 +254,28 @@ export class AuditSession { await this.ensureInitialized(); await this.workflowLogger.logWorkflowComplete(summary); } + + /** + * Add a resume attempt to the session + * Call this when a workflow is resuming from an existing workspace + * + * @param workflowId - The new workflow ID for this resume attempt + * @param terminatedWorkflows - IDs of workflows that were terminated + * @param checkpointHash - Git checkpoint hash that was restored + */ + async addResumeAttempt( + workflowId: string, + terminatedWorkflows: string[], + checkpointHash?: string + ): Promise { + await this.ensureInitialized(); + + const unlock = await sessionMutex.lock(this.sessionId); + try { + await this.metricsTracker.reload(); + await this.metricsTracker.addResumeAttempt(workflowId, terminatedWorkflows, checkpointHash); + } finally { + unlock(); + } + } } diff --git a/src/audit/metrics-tracker.ts b/src/audit/metrics-tracker.ts index 1096fe0..4462d90 100644 --- a/src/audit/metrics-tracker.ts +++ b/src/audit/metrics-tracker.ts @@ -46,6 +46,13 @@ interface PhaseMetrics { agent_count: number; } +export interface ResumeAttempt { + workflowId: string; + timestamp: string; + terminatedPrevious?: string; + resumedFromCheckpoint?: string; +} + interface SessionData { session: { id: string; @@ -54,6 +61,8 @@ interface SessionData { status: 'in-progress' | 'completed' | 'failed'; createdAt: string; completedAt?: string; + originalWorkflowId?: string; // First workflow that created this workspace + resumeAttempts?: ResumeAttempt[]; // Track all resume attempts }; metrics: { total_duration_ms: number; @@ -95,8 +104,10 @@ export class MetricsTracker { /** * Initialize session.json (idempotent) + * + * @param workflowId - Optional workflow ID to set as originalWorkflowId for new sessions */ - async initialize(): Promise { + async initialize(workflowId?: string): Promise { // Check if session.json already exists const exists = await fileExists(this.sessionJsonPath); @@ -105,21 +116,24 @@ export class MetricsTracker { this.data = await readJson(this.sessionJsonPath); } else { // Create new session.json - this.data = this.createInitialData(); + this.data = this.createInitialData(workflowId); await this.save(); } } /** * Create initial session.json structure + * + * @param workflowId - Optional workflow ID to set as originalWorkflowId */ - private createInitialData(): SessionData { + private createInitialData(workflowId?: string): SessionData { const sessionData: SessionData = { session: { id: this.sessionMetadata.id, webUrl: this.sessionMetadata.webUrl, status: 'in-progress', createdAt: (this.sessionMetadata as { createdAt?: string }).createdAt || formatTimestamp(), + resumeAttempts: [], }, metrics: { total_duration_ms: 0, @@ -128,6 +142,12 @@ export class MetricsTracker { agents: {}, // Agent-level metrics }, }; + + // Set originalWorkflowId if provided (for new workspaces) + if (workflowId) { + sessionData.session.originalWorkflowId = workflowId; + } + // Only add repoPath if it exists if (this.sessionMetadata.repoPath) { sessionData.session.repoPath = this.sessionMetadata.repoPath; @@ -229,6 +249,51 @@ export class MetricsTracker { await this.save(); } + /** + * Add a resume attempt to the session + * + * @param workflowId - The new workflow ID for this resume attempt + * @param terminatedWorkflows - IDs of workflows that were terminated + * @param checkpointHash - Git checkpoint hash that was restored + */ + async addResumeAttempt( + workflowId: string, + terminatedWorkflows: string[], + checkpointHash?: string + ): Promise { + if (!this.data) { + throw new Error('MetricsTracker not initialized'); + } + + // Ensure originalWorkflowId is set (backfill if missing from old sessions) + if (!this.data.session.originalWorkflowId) { + this.data.session.originalWorkflowId = this.data.session.id; + } + + // Ensure resumeAttempts array exists + if (!this.data.session.resumeAttempts) { + this.data.session.resumeAttempts = []; + } + + // Add new resume attempt + const resumeAttempt: ResumeAttempt = { + workflowId, + timestamp: formatTimestamp(), + }; + + if (terminatedWorkflows.length > 0) { + resumeAttempt.terminatedPrevious = terminatedWorkflows.join(','); + } + + if (checkpointHash) { + resumeAttempt.resumedFromCheckpoint = checkpointHash; + } + + this.data.session.resumeAttempts.push(resumeAttempt); + + await this.save(); + } + /** * Recalculate aggregations (total duration, total cost, phases) */ diff --git a/src/temporal/activities.ts b/src/temporal/activities.ts index a351a94..40fcf4f 100644 --- a/src/temporal/activities.ts +++ b/src/temporal/activities.ts @@ -72,9 +72,14 @@ import { getPromptNameForAgent } from '../types/agents.js'; import { AuditSession } from '../audit/index.js'; import type { WorkflowSummary } from '../audit/workflow-logger.js'; import type { AgentName } from '../types/agents.js'; -import type { AgentMetrics } from './shared.js'; +import { getDeliverablePath, ALL_AGENTS } from '../types/agents.js'; +import type { AgentMetrics, ResumeState } from './shared.js'; import type { DistributedConfig } from '../types/config.js'; -import { copyDeliverablesToAudit, type SessionMetadata } from '../audit/utils.js'; +import { copyDeliverablesToAudit, type SessionMetadata, readJson, fileExists } from '../audit/utils.js'; +import type { ResumeAttempt } from '../audit/metrics-tracker.js'; +import { executeGitCommandWithRetry } from '../utils/git-manager.js'; +import path from 'path'; +import fs from 'fs/promises'; const HEARTBEAT_INTERVAL_MS = 2000; // Must be < heartbeatTimeout (10min production, 5min testing) @@ -89,6 +94,7 @@ export interface ActivityInput { outputPath?: string; pipelineTestingMode?: boolean; workflowId: string; + sessionId: string; // Workspace name (for resume) or workflowId (for new runs) } /** @@ -142,8 +148,9 @@ async function runAgentActivity( } // 2. Build session metadata for audit + // Use sessionId (workspace name) for directory, workflowId for tracking const sessionMetadata: SessionMetadata = { - id: workflowId, + id: input.sessionId, webUrl, repoPath, ...(outputPath && { outputPath }), @@ -151,7 +158,7 @@ async function runAgentActivity( // 3. Initialize audit session (idempotent, safe across retries) const auditSession = new AuditSession(sessionMetadata); - await auditSession.initialize(); + await auditSession.initialize(workflowId); // 4. Load prompt const promptName = getPromptNameForAgent(agentName); @@ -449,6 +456,219 @@ export async function checkExploitationQueue( }; } +// === Resume Activities === + +/** + * Session.json structure for resume state loading + */ +interface SessionJson { + session: { + id: string; + webUrl: string; + repoPath?: string; + originalWorkflowId?: string; + resumeAttempts?: ResumeAttempt[]; + }; + metrics: { + agents: Record; + }; +} + +/** + * Load resume state from an existing workspace. + * Validates workspace exists, URL matches, and determines which agents to skip. + * + * @throws ApplicationFailure.nonRetryable if workspace not found or URL mismatch + */ +export async function loadResumeState( + workspaceName: string, + expectedUrl: string, + expectedRepoPath: string +): Promise { + const sessionPath = path.join('./audit-logs', workspaceName, 'session.json'); + + // Validate workspace exists + const exists = await fileExists(sessionPath); + if (!exists) { + throw ApplicationFailure.nonRetryable( + `Workspace not found: ${workspaceName}\nExpected path: ${sessionPath}`, + 'WorkspaceNotFoundError' + ); + } + + // Load session.json + let session: SessionJson; + try { + session = await readJson(sessionPath); + } catch (error) { + const errorMsg = error instanceof Error ? error.message : String(error); + throw ApplicationFailure.nonRetryable( + `Corrupted session.json in workspace ${workspaceName}: ${errorMsg}`, + 'CorruptedSessionError' + ); + } + + // Validate URL matches + if (session.session.webUrl !== expectedUrl) { + throw ApplicationFailure.nonRetryable( + `URL mismatch with workspace\n Workspace URL: ${session.session.webUrl}\n Provided URL: ${expectedUrl}`, + 'URLMismatchError' + ); + } + + // Find completed agents (status === 'success' AND deliverable exists) + const completedAgents: string[] = []; + const agents = session.metrics.agents; + + for (const agentName of ALL_AGENTS) { + const agentData = agents[agentName]; + + // Skip if agent never ran or didn't succeed + if (!agentData || agentData.status !== 'success') { + continue; + } + + // Validate deliverable exists + const deliverablePath = getDeliverablePath(agentName, expectedRepoPath); + const deliverableExists = await fileExists(deliverablePath); + + if (!deliverableExists) { + console.log( + chalk.yellow(`Agent ${agentName} shows success but deliverable missing, will re-run`) + ); + continue; + } + + // Agent completed successfully and deliverable exists + completedAgents.push(agentName); + } + + // Find latest checkpoint from completed agents + const checkpoints = completedAgents + .map((name) => agents[name]?.checkpoint) + .filter((hash): hash is string => hash != null); + + if (checkpoints.length === 0) { + throw ApplicationFailure.nonRetryable( + `No successful agent checkpoints found in workspace ${workspaceName}`, + 'NoCheckpointsError' + ); + } + + // Find most recent commit among checkpoints + const checkpointHash = await findLatestCommit(expectedRepoPath, checkpoints); + + const originalWorkflowId = session.session.originalWorkflowId || session.session.id; + + console.log(chalk.cyan(`=== RESUME STATE ===`)); + console.log(`Workspace: ${workspaceName}`); + console.log(`Completed agents: ${completedAgents.length}`); + console.log(`Checkpoint: ${checkpointHash}`); + + return { + workspaceName, + originalUrl: session.session.webUrl, + completedAgents, + checkpointHash, + originalWorkflowId, + }; +} + +/** + * Find the most recent commit among a list of commit hashes. + * Uses git rev-list to determine which commit is newest. + */ +async function findLatestCommit(repoPath: string, commitHashes: string[]): Promise { + if (commitHashes.length === 1) { + const hash = commitHashes[0]; + if (!hash) { + throw new Error('Empty commit hash in array'); + } + return hash; + } + + // Use git rev-list to find the most recent commit among all hashes + const result = await executeGitCommandWithRetry( + ['git', 'rev-list', '--max-count=1', ...commitHashes], + repoPath, + 'find latest commit' + ); + + return result.stdout.trim(); +} + +/** + * Restore git workspace to a checkpoint and clean up partial deliverables. + * + * @param repoPath - Repository path + * @param checkpointHash - Git commit hash to reset to + * @param incompleteAgents - Agents that didn't complete (will have deliverables cleaned up) + */ +export async function restoreGitCheckpoint( + repoPath: string, + checkpointHash: string, + incompleteAgents: AgentName[] +): Promise { + console.log(chalk.blue(`Restoring git workspace to ${checkpointHash}...`)); + + // Git reset to checkpoint + await executeGitCommandWithRetry( + ['git', 'reset', '--hard', checkpointHash], + repoPath, + 'reset to checkpoint for resume' + ); + await executeGitCommandWithRetry( + ['git', 'clean', '-fd'], + repoPath, + 'clean untracked files for resume' + ); + + + // Explicitly delete deliverables for incomplete agents + for (const agentName of incompleteAgents) { + const deliverablePath = getDeliverablePath(agentName, repoPath); + try { + const exists = await fileExists(deliverablePath); + if (exists) { + console.log(chalk.yellow(`Cleaning partial deliverable: ${agentName}`)); + await fs.unlink(deliverablePath); + } + } catch (error) { + // Non-fatal, just log + console.log(chalk.gray(`Note: Failed to delete ${deliverablePath}: ${error}`)); + } + } + + console.log(chalk.green('Workspace restored to clean state')); +} + +/** + * Record a resume attempt in session.json. + * Tracks the new workflow ID, terminated workflows, and checkpoint hash. + */ +export async function recordResumeAttempt( + input: ActivityInput, + terminatedWorkflows: string[], + checkpointHash: string +): Promise { + const { webUrl, repoPath, outputPath, sessionId, workflowId } = input; + + const sessionMetadata: SessionMetadata = { + id: sessionId, + webUrl, + repoPath, + ...(outputPath && { outputPath }), + }; + + const auditSession = new AuditSession(sessionMetadata); + await auditSession.initialize(); + + await auditSession.addResumeAttempt(workflowId, terminatedWorkflows, checkpointHash); +} + /** * Log phase transition to the unified workflow log. * Called at phase boundaries for per-workflow logging. @@ -458,10 +678,10 @@ export async function logPhaseTransition( phase: string, event: 'start' | 'complete' ): Promise { - const { webUrl, repoPath, outputPath, workflowId } = input; + const { webUrl, repoPath, outputPath, sessionId } = input; const sessionMetadata: SessionMetadata = { - id: workflowId, + id: sessionId, webUrl, repoPath, ...(outputPath && { outputPath }), @@ -485,16 +705,16 @@ export async function logWorkflowComplete( input: ActivityInput, summary: WorkflowSummary ): Promise { - const { webUrl, repoPath, outputPath, workflowId } = input; + const { webUrl, repoPath, outputPath, sessionId, workflowId } = input; const sessionMetadata: SessionMetadata = { - id: workflowId, + id: sessionId, webUrl, repoPath, ...(outputPath && { outputPath }), }; const auditSession = new AuditSession(sessionMetadata); - await auditSession.initialize(); + await auditSession.initialize(workflowId); await auditSession.logWorkflowComplete(summary); } diff --git a/src/temporal/client.ts b/src/temporal/client.ts index 945af42..c1b47a1 100644 --- a/src/temporal/client.ts +++ b/src/temporal/client.ts @@ -26,19 +26,86 @@ * TEMPORAL_ADDRESS - Temporal server address (default: localhost:7233) */ -import { Connection, Client } from '@temporalio/client'; +import { Connection, Client, WorkflowNotFoundError } from '@temporalio/client'; import dotenv from 'dotenv'; import chalk from 'chalk'; import { displaySplashScreen } from '../splash-screen.js'; import { sanitizeHostname } from '../audit/utils.js'; +import { readJson, fileExists } from '../audit/utils.js'; +import path from 'path'; // Import types only - these don't pull in workflow runtime code import type { PipelineInput, PipelineState, PipelineProgress } from './shared.js'; +/** + * Session.json structure for resume validation + */ +interface SessionJson { + session: { + id: string; + webUrl: string; + originalWorkflowId?: string; + resumeAttempts?: Array<{ workflowId: string }>; + }; +} + dotenv.config(); // Query name must match the one defined in workflows.ts const PROGRESS_QUERY = 'getProgress'; +/** + * Terminate any running workflows associated with a workspace. + * Returns the list of terminated workflow IDs. + */ +async function terminateExistingWorkflows( + client: Client, + workspaceName: string +): Promise { + const sessionPath = path.join('./audit-logs', workspaceName, 'session.json'); + + if (!(await fileExists(sessionPath))) { + throw new Error( + `Workspace not found: ${workspaceName}\n` + + `Expected path: ${sessionPath}` + ); + } + + const session = await readJson(sessionPath); + + // Collect all workflow IDs associated with this workspace + const workflowIds = [ + session.session.originalWorkflowId || session.session.id, + ...(session.session.resumeAttempts?.map((r) => r.workflowId) || []), + ].filter((id): id is string => id != null); + + const terminated: string[] = []; + + for (const wfId of workflowIds) { + try { + const handle = client.workflow.getHandle(wfId); + const description = await handle.describe(); + + if (description.status.name === 'RUNNING') { + console.log(chalk.yellow(`Terminating running workflow: ${wfId}`)); + await handle.terminate('Superseded by resume workflow'); + terminated.push(wfId); + console.log(chalk.green(`Terminated: ${wfId}`)); + } else { + console.log(chalk.gray(`Workflow already ${description.status.name}: ${wfId}`)); + } + } catch (error) { + if (error instanceof WorkflowNotFoundError) { + console.log(chalk.gray(`Workflow not found (already cleaned up): ${wfId}`)); + } else { + console.log(chalk.red(`Failed to terminate ${wfId}: ${error}`)); + // Continue anyway - don't block resume on termination failure + } + } + } + + return terminated; +} + function showUsage(): void { console.log(chalk.cyan.bold('\nShannon Temporal Client')); console.log(chalk.gray('Start a pentest pipeline workflow\n')); @@ -50,6 +117,7 @@ function showUsage(): void { console.log(' --config Configuration file path'); console.log(' --output Output directory for audit logs'); console.log(' --pipeline-testing Use minimal prompts for fast testing'); + console.log(' --workspace Resume from existing workspace'); console.log( ' --workflow-id Custom workflow ID (default: shannon-)' ); @@ -78,6 +146,7 @@ async function startPipeline(): Promise { let pipelineTestingMode = false; let customWorkflowId: string | undefined; let waitForCompletion = false; + let resumeFromWorkspace: string | undefined; for (let i = 0; i < args.length; i++) { const arg = args[i]; @@ -107,6 +176,12 @@ async function startPipeline(): Promise { } } else if (arg === '--pipeline-testing') { pipelineTestingMode = true; + } else if (arg === '--workspace') { + const nextArg = args[i + 1]; + if (nextArg && !nextArg.startsWith('-')) { + resumeFromWorkspace = nextArg; + i++; + } } else if (arg === '--wait') { waitForCompletion = true; } else if (arg && !arg.startsWith('-')) { @@ -134,26 +209,67 @@ async function startPipeline(): Promise { const client = new Client({ connection }); try { - const hostname = sanitizeHostname(webUrl); - const workflowId = customWorkflowId || `${hostname}_shannon-${Date.now()}`; + let terminatedWorkflows: string[] = []; + let workflowId: string; + let sessionId: string; // Workspace name (persistent directory) + + // === Resume Mode === + if (resumeFromWorkspace) { + console.log(chalk.cyan('=== RESUME MODE ===')); + console.log(`Workspace: ${resumeFromWorkspace}\n`); + + // Terminate any running workflows for this workspace + terminatedWorkflows = await terminateExistingWorkflows(client, resumeFromWorkspace); + + if (terminatedWorkflows.length > 0) { + console.log(chalk.yellow(`Terminated ${terminatedWorkflows.length} previous workflow(s)\n`)); + } + + // Validate URL matches workspace + const sessionPath = path.join('./audit-logs', resumeFromWorkspace, 'session.json'); + const session = await readJson(sessionPath); + + if (session.session.webUrl !== webUrl) { + console.error(chalk.red('ERROR: URL mismatch with workspace')); + console.error(` Workspace URL: ${session.session.webUrl}`); + console.error(` Provided URL: ${webUrl}`); + process.exit(1); + } + + // Generate resume workflow ID + workflowId = `${resumeFromWorkspace}_resume_${Date.now()}`; + sessionId = resumeFromWorkspace; + } else { + // === New Workflow === + const hostname = sanitizeHostname(webUrl); + workflowId = customWorkflowId || `${hostname}_shannon-${Date.now()}`; + sessionId = workflowId; + } const input: PipelineInput = { webUrl, repoPath, + workflowId, // Add for audit correlation ...(configPath && { configPath }), ...(outputPath && { outputPath }), ...(pipelineTestingMode && { pipelineTestingMode }), + ...(resumeFromWorkspace && { resumeFromWorkspace }), + ...(terminatedWorkflows.length > 0 && { terminatedWorkflows }), }; - // Determine output directory for display + // Determine output directory for display (use sessionId for persistent directory) // Use displayOutputPath (host path) if provided, otherwise fall back to outputPath or default const effectiveDisplayPath = displayOutputPath || outputPath || './audit-logs'; - const outputDir = `${effectiveDisplayPath}/${workflowId}`; + const outputDir = `${effectiveDisplayPath}/${sessionId}`; console.log(chalk.green.bold(`✓ Workflow started: ${workflowId}`)); + if (resumeFromWorkspace) { + console.log(chalk.gray(` (Resuming workspace: ${sessionId})`)); + } console.log(); console.log(chalk.white(' Target: ') + chalk.cyan(webUrl)); console.log(chalk.white(' Repository: ') + chalk.cyan(repoPath)); + console.log(chalk.white(' Workspace: ') + chalk.cyan(sessionId)); if (configPath) { console.log(chalk.white(' Config: ') + chalk.cyan(configPath)); } diff --git a/src/temporal/shared.ts b/src/temporal/shared.ts index 9120bfc..3ab7f92 100644 --- a/src/temporal/shared.ts +++ b/src/temporal/shared.ts @@ -9,6 +9,16 @@ export interface PipelineInput { outputPath?: string; pipelineTestingMode?: boolean; workflowId?: string; // Added by client, used for audit correlation + resumeFromWorkspace?: string; // Workspace name to resume from + terminatedWorkflows?: string[]; // Workflows terminated during resume +} + +export interface ResumeState { + workspaceName: string; + originalUrl: string; + completedAgents: string[]; + checkpointHash: string; + originalWorkflowId: string; } export interface AgentMetrics { diff --git a/src/temporal/workflows.ts b/src/temporal/workflows.ts index d7d16b0..75b10e2 100644 --- a/src/temporal/workflows.ts +++ b/src/temporal/workflows.ts @@ -38,8 +38,11 @@ import { type PipelineSummary, type VulnExploitPipelineResult, type AgentMetrics, + type ResumeState, } from './shared.js'; import type { VulnType } from '../queue-validation.js'; +import type { AgentName } from '../types/agents.js'; +import { ALL_AGENTS } from '../types/agents.js'; // Retry configuration for production (long intervals for billing recovery) const PRODUCTION_RETRY = { @@ -127,10 +130,14 @@ export async function pentestPipelineWorkflow( // Build ActivityInput with required workflowId for audit correlation // Activities require workflowId (non-optional), PipelineInput has it optional // Use spread to conditionally include optional properties (exactOptionalPropertyTypes) + // sessionId is workspace name for resume, or workflowId for new runs + const sessionId = input.resumeFromWorkspace || workflowId; + const activityInput: ActivityInput = { webUrl: input.webUrl, repoPath: input.repoPath, workflowId, + sessionId, ...(input.configPath !== undefined && { configPath: input.configPath }), ...(input.outputPath !== undefined && { outputPath: input.outputPath }), ...(input.pipelineTestingMode !== undefined && { @@ -138,23 +145,70 @@ export async function pentestPipelineWorkflow( }), }; + // === RESUME LOGIC === + let resumeState: ResumeState | null = null; + + if (input.resumeFromWorkspace) { + // Load resume state from existing workspace + resumeState = await a.loadResumeState( + input.resumeFromWorkspace, + input.webUrl, + input.repoPath + ); + + // Restore git checkpoint and clean up partial deliverables + const incompleteAgents = ALL_AGENTS.filter( + (agentName) => !resumeState!.completedAgents.includes(agentName) + ) as AgentName[]; + + await a.restoreGitCheckpoint( + input.repoPath, + resumeState.checkpointHash, + incompleteAgents + ); + + // Record resume attempt in session.json + await a.recordResumeAttempt( + activityInput, + input.terminatedWorkflows || [], + resumeState.checkpointHash + ); + + console.log('Resume state loaded and workspace restored'); + } + + // Helper to check if an agent should be skipped + const shouldSkip = (agentName: string): boolean => { + return resumeState?.completedAgents.includes(agentName) ?? false; + }; + try { // === Phase 1: Pre-Reconnaissance === - state.currentPhase = 'pre-recon'; - state.currentAgent = 'pre-recon'; - await a.logPhaseTransition(activityInput, 'pre-recon', 'start'); - state.agentMetrics['pre-recon'] = - await a.runPreReconAgent(activityInput); - state.completedAgents.push('pre-recon'); - await a.logPhaseTransition(activityInput, 'pre-recon', 'complete'); + if (!shouldSkip('pre-recon')) { + state.currentPhase = 'pre-recon'; + state.currentAgent = 'pre-recon'; + await a.logPhaseTransition(activityInput, 'pre-recon', 'start'); + state.agentMetrics['pre-recon'] = + await a.runPreReconAgent(activityInput); + state.completedAgents.push('pre-recon'); + await a.logPhaseTransition(activityInput, 'pre-recon', 'complete'); + } else { + console.log('Skipping pre-recon (already complete)'); + state.completedAgents.push('pre-recon'); + } // === Phase 2: Reconnaissance === - state.currentPhase = 'recon'; - state.currentAgent = 'recon'; - await a.logPhaseTransition(activityInput, 'recon', 'start'); - state.agentMetrics['recon'] = await a.runReconAgent(activityInput); - state.completedAgents.push('recon'); - await a.logPhaseTransition(activityInput, 'recon', 'complete'); + if (!shouldSkip('recon')) { + state.currentPhase = 'recon'; + state.currentAgent = 'recon'; + await a.logPhaseTransition(activityInput, 'recon', 'start'); + state.agentMetrics['recon'] = await a.runReconAgent(activityInput); + state.completedAgents.push('recon'); + await a.logPhaseTransition(activityInput, 'recon', 'complete'); + } else { + console.log('Skipping recon (already complete)'); + state.completedAgents.push('recon'); + } // === Phases 3-4: Vulnerability Analysis + Exploitation (Pipelined) === // Each vuln type runs as an independent pipeline: @@ -165,22 +219,34 @@ export async function pentestPipelineWorkflow( state.currentAgent = 'pipelines'; await a.logPhaseTransition(activityInput, 'vulnerability-exploitation', 'start'); - // Helper: Run a single vuln→exploit pipeline + // Helper: Run a single vuln→exploit pipeline with skip logic async function runVulnExploitPipeline( vulnType: VulnType, runVulnAgent: () => Promise, runExploitAgent: () => Promise ): Promise { - // Step 1: Run vulnerability agent - const vulnMetrics = await runVulnAgent(); + const vulnAgentName = `${vulnType}-vuln`; + const exploitAgentName = `${vulnType}-exploit`; - // Step 2: Check exploitation queue (starts immediately after vuln) + // Step 1: Run vulnerability agent (or skip if completed) + let vulnMetrics: AgentMetrics | null = null; + if (!shouldSkip(vulnAgentName)) { + vulnMetrics = await runVulnAgent(); + } else { + console.log(`Skipping ${vulnAgentName} (already complete)`); + } + + // Step 2: Check exploitation queue (only if vuln agent ran or completed previously) const decision = await a.checkExploitationQueue(activityInput, vulnType); - // Step 3: Conditionally run exploit agent + // Step 3: Conditionally run exploit agent (skip if already completed) let exploitMetrics: AgentMetrics | null = null; if (decision.shouldExploit) { - exploitMetrics = await runExploitAgent(); + if (!shouldSkip(exploitAgentName)) { + exploitMetrics = await runExploitAgent(); + } else { + console.log(`Skipping ${exploitAgentName} (already complete)`); + } } return { @@ -195,35 +261,75 @@ export async function pentestPipelineWorkflow( }; } - // Run all 5 pipelines in parallel with graceful failure handling + // Determine which pipelines to run (skip if both vuln and exploit completed) + const pipelinesToRun: Array> = []; + + // Only run pipeline if at least one agent (vuln or exploit) is incomplete + const pipelineConfigs: Array<{ + vulnType: VulnType; + vulnAgent: string; + exploitAgent: string; + runVuln: () => Promise; + runExploit: () => Promise; + }> = [ + { + vulnType: 'injection', + vulnAgent: 'injection-vuln', + exploitAgent: 'injection-exploit', + runVuln: () => a.runInjectionVulnAgent(activityInput), + runExploit: () => a.runInjectionExploitAgent(activityInput), + }, + { + vulnType: 'xss', + vulnAgent: 'xss-vuln', + exploitAgent: 'xss-exploit', + runVuln: () => a.runXssVulnAgent(activityInput), + runExploit: () => a.runXssExploitAgent(activityInput), + }, + { + vulnType: 'auth', + vulnAgent: 'auth-vuln', + exploitAgent: 'auth-exploit', + runVuln: () => a.runAuthVulnAgent(activityInput), + runExploit: () => a.runAuthExploitAgent(activityInput), + }, + { + vulnType: 'ssrf', + vulnAgent: 'ssrf-vuln', + exploitAgent: 'ssrf-exploit', + runVuln: () => a.runSsrfVulnAgent(activityInput), + runExploit: () => a.runSsrfExploitAgent(activityInput), + }, + { + vulnType: 'authz', + vulnAgent: 'authz-vuln', + exploitAgent: 'authz-exploit', + runVuln: () => a.runAuthzVulnAgent(activityInput), + runExploit: () => a.runAuthzExploitAgent(activityInput), + }, + ]; + + for (const config of pipelineConfigs) { + const vulnComplete = shouldSkip(config.vulnAgent); + const exploitComplete = shouldSkip(config.exploitAgent); + + // Only run pipeline if at least one agent needs to run + if (!vulnComplete || !exploitComplete) { + pipelinesToRun.push( + runVulnExploitPipeline(config.vulnType, config.runVuln, config.runExploit) + ); + } else { + console.log( + `Skipping entire ${config.vulnType} pipeline (both agents complete)` + ); + // Still need to mark them as completed in state + state.completedAgents.push(config.vulnAgent, config.exploitAgent); + } + } + + // Run pipelines in parallel with graceful failure handling // Promise.allSettled ensures other pipelines continue if one fails - const pipelineResults = await Promise.allSettled([ - runVulnExploitPipeline( - 'injection', - () => a.runInjectionVulnAgent(activityInput), - () => a.runInjectionExploitAgent(activityInput) - ), - runVulnExploitPipeline( - 'xss', - () => a.runXssVulnAgent(activityInput), - () => a.runXssExploitAgent(activityInput) - ), - runVulnExploitPipeline( - 'auth', - () => a.runAuthVulnAgent(activityInput), - () => a.runAuthExploitAgent(activityInput) - ), - runVulnExploitPipeline( - 'ssrf', - () => a.runSsrfVulnAgent(activityInput), - () => a.runSsrfExploitAgent(activityInput) - ), - runVulnExploitPipeline( - 'authz', - () => a.runAuthzVulnAgent(activityInput), - () => a.runAuthzExploitAgent(activityInput) - ), - ]); + const pipelineResults = await Promise.allSettled(pipelinesToRun); // Aggregate results from all pipelines const failedPipelines: string[] = []; @@ -231,16 +337,24 @@ export async function pentestPipelineWorkflow( if (result.status === 'fulfilled') { const { vulnType, vulnMetrics, exploitMetrics } = result.value; - // Record vuln agent metrics + // Record vuln agent + const vulnAgentName = `${vulnType}-vuln`; if (vulnMetrics) { - state.agentMetrics[`${vulnType}-vuln`] = vulnMetrics; - state.completedAgents.push(`${vulnType}-vuln`); + state.agentMetrics[vulnAgentName] = vulnMetrics; + state.completedAgents.push(vulnAgentName); + } else if (shouldSkip(vulnAgentName)) { + // Agent was skipped because already complete + state.completedAgents.push(vulnAgentName); } - // Record exploit agent metrics (if it ran) + // Record exploit agent (if it ran) + const exploitAgentName = `${vulnType}-exploit`; if (exploitMetrics) { - state.agentMetrics[`${vulnType}-exploit`] = exploitMetrics; - state.completedAgents.push(`${vulnType}-exploit`); + state.agentMetrics[exploitAgentName] = exploitMetrics; + state.completedAgents.push(exploitAgentName); + } else if (shouldSkip(exploitAgentName)) { + // Agent was skipped because already complete + state.completedAgents.push(exploitAgentName); } } else { // Pipeline failed - log error but continue with others @@ -266,21 +380,26 @@ export async function pentestPipelineWorkflow( await a.logPhaseTransition(activityInput, 'vulnerability-exploitation', 'complete'); // === Phase 5: Reporting === - state.currentPhase = 'reporting'; - state.currentAgent = 'report'; - await a.logPhaseTransition(activityInput, 'reporting', 'start'); + if (!shouldSkip('report')) { + state.currentPhase = 'reporting'; + state.currentAgent = 'report'; + await a.logPhaseTransition(activityInput, 'reporting', 'start'); - // First, assemble the concatenated report from exploitation evidence files - await a.assembleReportActivity(activityInput); + // First, assemble the concatenated report from exploitation evidence files + await a.assembleReportActivity(activityInput); - // Then run the report agent to add executive summary and clean up - state.agentMetrics['report'] = await a.runReportAgent(activityInput); - state.completedAgents.push('report'); + // Then run the report agent to add executive summary and clean up + state.agentMetrics['report'] = await a.runReportAgent(activityInput); + state.completedAgents.push('report'); - // Inject model metadata into the final report - await a.injectReportMetadataActivity(activityInput); + // Inject model metadata into the final report + await a.injectReportMetadataActivity(activityInput); - await a.logPhaseTransition(activityInput, 'reporting', 'complete'); + await a.logPhaseTransition(activityInput, 'reporting', 'complete'); + } else { + console.log('Skipping report (already complete)'); + state.completedAgents.push('report'); + } // === Complete === state.status = 'completed'; diff --git a/src/types/agents.ts b/src/types/agents.ts index a47256f..481346a 100644 --- a/src/types/agents.ts +++ b/src/types/agents.ts @@ -8,20 +8,33 @@ * Agent type definitions */ -export type AgentName = - | 'pre-recon' - | 'recon' - | 'injection-vuln' - | 'xss-vuln' - | 'auth-vuln' - | 'ssrf-vuln' - | 'authz-vuln' - | 'injection-exploit' - | 'xss-exploit' - | 'auth-exploit' - | 'ssrf-exploit' - | 'authz-exploit' - | 'report'; +import path from 'path'; + +/** + * List of all agents in execution order. + * Used for iteration during resume state checking. + */ +export const ALL_AGENTS = [ + 'pre-recon', + 'recon', + 'injection-vuln', + 'xss-vuln', + 'auth-vuln', + 'ssrf-vuln', + 'authz-vuln', + 'injection-exploit', + 'xss-exploit', + 'auth-exploit', + 'ssrf-exploit', + 'authz-exploit', + 'report', +] as const; + +/** + * Agent name type derived from ALL_AGENTS. + * This ensures type safety and prevents drift between type and array. + */ +export type AgentName = typeof ALL_AGENTS[number]; export type PromptName = | 'pre-recon-code' @@ -82,3 +95,28 @@ export function getPromptNameForAgent(agentName: AgentName): PromptName { return mappings[agentName]; } + +/** + * Maps an agent name to its deliverable file path. + * Must match mcp-server/src/types/deliverables.ts:DELIVERABLE_FILENAMES + */ +export function getDeliverablePath(agentName: AgentName, repoPath: string): string { + const deliverableMap: Record = { + 'pre-recon': 'code_analysis_deliverable.md', + 'recon': 'recon_deliverable.md', + 'injection-vuln': 'injection_analysis_deliverable.md', + 'xss-vuln': 'xss_analysis_deliverable.md', + 'auth-vuln': 'auth_analysis_deliverable.md', + 'ssrf-vuln': 'ssrf_analysis_deliverable.md', + 'authz-vuln': 'authz_analysis_deliverable.md', + 'injection-exploit': 'injection_exploitation_evidence.md', + 'xss-exploit': 'xss_exploitation_evidence.md', + 'auth-exploit': 'auth_exploitation_evidence.md', + 'ssrf-exploit': 'ssrf_exploitation_evidence.md', + 'authz-exploit': 'authz_exploitation_evidence.md', + 'report': 'comprehensive_security_assessment_report.md', + }; + + const filename = deliverableMap[agentName]; + return path.join(repoPath, 'deliverables', filename); +}