diff --git a/src/ai/audit-logger.ts b/src/ai/audit-logger.ts new file mode 100644 index 0000000..e7d4491 --- /dev/null +++ b/src/ai/audit-logger.ts @@ -0,0 +1,79 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +// Null Object pattern for audit logging - callers never check for null + +import type { AuditSession } from '../audit/index.js'; +import { formatTimestamp } from '../utils/formatting.js'; + +export interface AuditLogger { + logLlmResponse(turn: number, content: string): Promise; + logToolStart(toolName: string, parameters: unknown): Promise; + logToolEnd(result: unknown): Promise; + logError(error: Error, duration: number, turns: number): Promise; +} + +class RealAuditLogger implements AuditLogger { + private auditSession: AuditSession; + + constructor(auditSession: AuditSession) { + this.auditSession = auditSession; + } + + async logLlmResponse(turn: number, content: string): Promise { + await this.auditSession.logEvent('llm_response', { + turn, + content, + timestamp: formatTimestamp(), + }); + } + + async logToolStart(toolName: string, parameters: unknown): Promise { + await this.auditSession.logEvent('tool_start', { + toolName, + parameters, + timestamp: formatTimestamp(), + }); + } + + async logToolEnd(result: unknown): Promise { + await this.auditSession.logEvent('tool_end', { + result, + timestamp: formatTimestamp(), + }); + } + + async logError(error: Error, duration: number, turns: number): Promise { + await this.auditSession.logEvent('error', { + message: error.message, + errorType: error.constructor.name, + stack: error.stack, + duration, + turns, + timestamp: formatTimestamp(), + }); + } +} + +/** Null Object implementation - all methods are safe no-ops */ +class NullAuditLogger implements AuditLogger { + async logLlmResponse(_turn: number, _content: string): Promise {} + + async logToolStart(_toolName: string, _parameters: unknown): Promise {} + + async logToolEnd(_result: unknown): Promise {} + + async logError(_error: Error, _duration: number, _turns: number): Promise {} +} + +// Returns no-op when auditSession is null +export function createAuditLogger(auditSession: AuditSession | null): AuditLogger { + if (auditSession) { + return new RealAuditLogger(auditSession); + } + + return new NullAuditLogger(); +} diff --git a/src/ai/claude-executor.ts b/src/ai/claude-executor.ts index c8fa5f2..d676379 100644 --- a/src/ai/claude-executor.ts +++ b/src/ai/claude-executor.ts @@ -4,34 +4,33 @@ // it under the terms of the GNU Affero General Public License version 3 // as published by the Free Software Foundation. -import { $, fs, path } from 'zx'; +// Production Claude agent execution with retry, git checkpoints, and audit logging + +import { fs, path } from 'zx'; import chalk, { type ChalkInstance } from 'chalk'; import { query } from '@anthropic-ai/claude-agent-sdk'; -import { fileURLToPath } from 'url'; -import { dirname } from 'path'; import { isRetryableError, getRetryDelay, PentestError } from '../error-handling.js'; -import { ProgressIndicator } from '../progress-indicator.js'; -import { timingResults, costResults, Timer } from '../utils/metrics.js'; -import { formatDuration } from '../audit/utils.js'; -import { createGitCheckpoint, commitGitSuccess, rollbackGitWorkspace } from '../utils/git-manager.js'; +import { timingResults, Timer } from '../utils/metrics.js'; +import { formatTimestamp } from '../utils/formatting.js'; +import { createGitCheckpoint, commitGitSuccess, rollbackGitWorkspace, getGitCommitHash } from '../utils/git-manager.js'; import { AGENT_VALIDATORS, MCP_AGENT_MAPPING } from '../constants.js'; -import { filterJsonToolCalls, getAgentPrefix } from '../utils/output-formatter.js'; import { generateSessionLogPath } from '../session-manager.js'; import { AuditSession } from '../audit/index.js'; import { createShannonHelperServer } from '../../mcp-server/dist/index.js'; import type { SessionMetadata } from '../audit/utils.js'; -import type { PromptName } from '../types/index.js'; +import { getPromptNameForAgent } from '../types/agents.js'; +import type { AgentName } from '../types/index.js'; + +import { dispatchMessage } from './message-handlers.js'; +import { detectExecutionContext, formatErrorOutput, formatCompletionMessage } from './output-formatters.js'; +import { createProgressManager } from './progress-manager.js'; +import { createAuditLogger } from './audit-logger.js'; -// Extend global for loader flag declare global { var SHANNON_DISABLE_LOADER: boolean | undefined; } -const __filename = fileURLToPath(import.meta.url); -const __dirname = dirname(__filename); - -// Result types interface ClaudePromptResult { result?: string | null; success: boolean; @@ -47,7 +46,6 @@ interface ClaudePromptResult { retryable?: boolean; } -// MCP Server types interface StdioMcpServer { type: 'stdio'; command: string; @@ -57,157 +55,29 @@ interface StdioMcpServer { type McpServer = ReturnType | StdioMcpServer; -/** - * Convert agent name to prompt name for MCP_AGENT_MAPPING lookup - */ -function agentNameToPromptName(agentName: string): PromptName { - // Special cases - if (agentName === 'pre-recon') return 'pre-recon-code'; - if (agentName === 'report') return 'report-executive'; - if (agentName === 'recon') return 'recon'; - - // Pattern: {type}-vuln โ†’ vuln-{type} - const vulnMatch = agentName.match(/^(.+)-vuln$/); - if (vulnMatch) { - return `vuln-${vulnMatch[1]}` as PromptName; - } - - // Pattern: {type}-exploit โ†’ exploit-{type} - const exploitMatch = agentName.match(/^(.+)-exploit$/); - if (exploitMatch) { - return `exploit-${exploitMatch[1]}` as PromptName; - } - - // Default: return as-is - return agentName as PromptName; -} - -// Simplified validation using direct agent name mapping -async function validateAgentOutput( - result: ClaudePromptResult, - agentName: string | null, - sourceDir: string -): Promise { - console.log(chalk.blue(` ๐Ÿ” Validating ${agentName} agent output`)); - - try { - // Check if agent completed successfully - if (!result.success || !result.result) { - console.log(chalk.red(` โŒ Validation failed: Agent execution was unsuccessful`)); - return false; - } - - // Get validator function for this agent - const validator = agentName ? AGENT_VALIDATORS[agentName as keyof typeof AGENT_VALIDATORS] : undefined; - - if (!validator) { - console.log(chalk.yellow(` โš ๏ธ No validator found for agent "${agentName}" - assuming success`)); - console.log(chalk.green(` โœ… Validation passed: Unknown agent with successful result`)); - return true; - } - - console.log(chalk.blue(` ๐Ÿ“‹ Using validator for agent: ${agentName}`)); - console.log(chalk.blue(` ๐Ÿ“‚ Source directory: ${sourceDir}`)); - - // Apply validation function - const validationResult = await validator(sourceDir); - - if (validationResult) { - console.log(chalk.green(` โœ… Validation passed: Required files/structure present`)); - } else { - console.log(chalk.red(` โŒ Validation failed: Missing required deliverable files`)); - } - - return validationResult; - - } catch (error) { - const errMsg = error instanceof Error ? error.message : String(error); - console.log(chalk.red(` โŒ Validation failed with error: ${errMsg}`)); - return false; // Assume invalid on validation error - } -} - -// Pure function: Run Claude Code with SDK - Maximum Autonomy -// WARNING: This is a low-level function. Use runClaudePromptWithRetry() for agent execution -async function runClaudePrompt( - prompt: string, +// Configures MCP servers for agent execution, with Docker-specific Chromium handling +function buildMcpServers( sourceDir: string, - _allowedTools: string = 'Read', - context: string = '', - description: string = 'Claude analysis', - agentName: string | null = null, - colorFn: ChalkInstance = chalk.cyan, - sessionMetadata: SessionMetadata | null = null, - auditSession: AuditSession | null = null, - attemptNumber: number = 1 -): Promise { - const timer = new Timer(`agent-${description.toLowerCase().replace(/\s+/g, '-')}`); - const fullPrompt = context ? `${context}\n\n${prompt}` : prompt; - let totalCost = 0; - let partialCost = 0; // Track partial cost for crash safety + agentName: string | null +): Record { + const shannonHelperServer = createShannonHelperServer(sourceDir); - // Auto-detect execution mode to adjust logging behavior - const isParallelExecution = description.includes('vuln agent') || description.includes('exploit agent'); - const useCleanOutput = description.includes('Pre-recon agent') || - description.includes('Recon agent') || - description.includes('Executive Summary and Report Cleanup') || - description.includes('vuln agent') || - description.includes('exploit agent'); + const mcpServers: Record = { + 'shannon-helper': shannonHelperServer, + }; - // Disable status manager - using simple JSON filtering for all agents now - const statusManager = null; + if (agentName) { + const promptName = getPromptNameForAgent(agentName as AgentName); + const playwrightMcpName = MCP_AGENT_MAPPING[promptName as keyof typeof MCP_AGENT_MAPPING] || null; - // Setup progress indicator for clean output agents (unless disabled via flag) - let progressIndicator: ProgressIndicator | null = null; - if (useCleanOutput && !global.SHANNON_DISABLE_LOADER) { - const agentType = description.includes('Pre-recon') ? 'pre-reconnaissance' : - description.includes('Recon') ? 'reconnaissance' : - description.includes('Report') ? 'report generation' : 'analysis'; - progressIndicator = new ProgressIndicator(`Running ${agentType}...`); - } - - // NOTE: Logging now handled by AuditSession (append-only, crash-safe) - let logFilePath: string | null = null; - if (sessionMetadata && sessionMetadata.webUrl && sessionMetadata.id) { - const timestamp = new Date().toISOString().replace(/T/, '_').replace(/[:.]/g, '-').slice(0, 19); - const agentKey = description.toLowerCase().replace(/\s+/g, '-'); - const logDir = generateSessionLogPath(sessionMetadata.webUrl, sessionMetadata.id); - logFilePath = path.join(logDir, `${timestamp}_${agentKey}_attempt-${attemptNumber}.log`); - } else { - console.log(chalk.blue(` ๐Ÿค– Running Claude Code: ${description}...`)); - } - - // Declare variables that need to be accessible in both try and catch blocks - let turnCount = 0; - - try { - // Create MCP server with target directory context - const shannonHelperServer = createShannonHelperServer(sourceDir); - - // Look up agent's assigned Playwright MCP server - let playwrightMcpName: string | null = null; - if (agentName) { - const promptName = agentNameToPromptName(agentName); - playwrightMcpName = MCP_AGENT_MAPPING[promptName as keyof typeof MCP_AGENT_MAPPING] || null; - - if (playwrightMcpName) { - console.log(chalk.gray(` ๐ŸŽญ Assigned ${agentName} โ†’ ${playwrightMcpName}`)); - } - } - - // Configure MCP servers: shannon-helper (SDK) + playwright-agentN (stdio) - const mcpServers: Record = { - 'shannon-helper': shannonHelperServer, - }; - - // Add Playwright MCP server if this agent needs browser automation if (playwrightMcpName) { + console.log(chalk.gray(` Assigned ${agentName} -> ${playwrightMcpName}`)); + const userDataDir = `/tmp/${playwrightMcpName}`; - // Detect if running in Docker via explicit environment variable + // Docker uses system Chromium; local dev uses Playwright's bundled browsers const isDocker = process.env.SHANNON_DOCKER === 'true'; - // Build args array - conditionally add --executable-path for Docker const mcpArgs: string[] = [ '@playwright/mcp@latest', '--isolated', @@ -220,7 +90,6 @@ async function runClaudePrompt( mcpArgs.push('--browser', 'chromium'); } - // Filter out undefined env values for type safety const envVars: Record = Object.fromEntries( Object.entries({ ...process.env, @@ -236,256 +105,169 @@ async function runClaudePrompt( env: envVars, }; } + } - const options = { - model: 'claude-sonnet-4-5-20250929', // Use latest Claude 4.5 Sonnet - maxTurns: 10_000, // Maximum turns for autonomous work - cwd: sourceDir, // Set working directory using SDK option - permissionMode: 'bypassPermissions' as const, // Bypass all permission checks for pentesting - mcpServers, + return mcpServers; +} + +function outputLines(lines: string[]): void { + for (const line of lines) { + console.log(line); + } +} + +async function writeErrorLog( + err: Error & { code?: string; status?: number }, + sourceDir: string, + fullPrompt: string, + duration: number +): Promise { + try { + const errorLog = { + timestamp: formatTimestamp(), + agent: 'claude-executor', + error: { + name: err.constructor.name, + message: err.message, + code: err.code, + status: err.status, + stack: err.stack + }, + context: { + sourceDir, + prompt: fullPrompt.slice(0, 200) + '...', + retryable: isRetryableError(err) + }, + duration }; + const logPath = path.join(sourceDir, 'error.log'); + await fs.appendFile(logPath, JSON.stringify(errorLog) + '\n'); + } catch (logError) { + const logErrMsg = logError instanceof Error ? logError.message : String(logError); + console.log(chalk.gray(` (Failed to write error log: ${logErrMsg})`)); + } +} - // SDK Options only shown for verbose agents (not clean output) - if (!useCleanOutput) { - console.log(chalk.gray(` SDK Options: maxTurns=${options.maxTurns}, cwd=${sourceDir}, permissions=BYPASS`)); +async function validateAgentOutput( + result: ClaudePromptResult, + agentName: string | null, + sourceDir: string +): Promise { + console.log(chalk.blue(` Validating ${agentName} agent output`)); + + try { + // Check if agent completed successfully + if (!result.success || !result.result) { + console.log(chalk.red(` Validation failed: Agent execution was unsuccessful`)); + return false; } - let result: string | null = null; - const messages: string[] = []; - let apiErrorDetected = false; + // Get validator function for this agent + const validator = agentName ? AGENT_VALIDATORS[agentName as keyof typeof AGENT_VALIDATORS] : undefined; - // Start progress indicator for clean output agents - if (progressIndicator) { - progressIndicator.start(); + if (!validator) { + console.log(chalk.yellow(` No validator found for agent "${agentName}" - assuming success`)); + console.log(chalk.green(` Validation passed: Unknown agent with successful result`)); + return true; } - let lastHeartbeat = Date.now(); - const HEARTBEAT_INTERVAL = 30000; // 30 seconds + console.log(chalk.blue(` Using validator for agent: ${agentName}`)); + console.log(chalk.blue(` Source directory: ${sourceDir}`)); - try { - for await (const message of query({ prompt: fullPrompt, options })) { - // Periodic heartbeat for long-running agents (only when loader is disabled) - const now = Date.now(); - if (global.SHANNON_DISABLE_LOADER && now - lastHeartbeat > HEARTBEAT_INTERVAL) { - console.log(chalk.blue(` โฑ๏ธ [${Math.floor((now - timer.startTime) / 1000)}s] ${description} running... (Turn ${turnCount})`)); - lastHeartbeat = now; - } + // Apply validation function + const validationResult = await validator(sourceDir); - if (message.type === "assistant") { - turnCount++; - - const messageContent = message.message as { content: unknown }; - const content = Array.isArray(messageContent.content) - ? messageContent.content.map((c: { text?: string }) => c.text || JSON.stringify(c)).join('\n') - : String(messageContent.content); - - if (statusManager) { - // Smart status updates for parallel execution - disabled - } else if (useCleanOutput) { - // Clean output for all agents: filter JSON tool calls but show meaningful text - const cleanedContent = filterJsonToolCalls(content); - if (cleanedContent.trim()) { - // Temporarily stop progress indicator to show output - if (progressIndicator) { - progressIndicator.stop(); - } - - if (isParallelExecution) { - // Compact output for parallel agents with prefixes - const prefix = getAgentPrefix(description); - console.log(colorFn(`${prefix} ${cleanedContent}`)); - } else { - // Full turn output for single agents - console.log(colorFn(`\n ๐Ÿค– Turn ${turnCount} (${description}):`)); - console.log(colorFn(` ${cleanedContent}`)); - } - - // Restart progress indicator after output - if (progressIndicator) { - progressIndicator.start(); - } - } - } else { - // Full streaming output - show complete messages with specialist color - console.log(colorFn(`\n ๐Ÿค– Turn ${turnCount} (${description}):`)); - console.log(colorFn(` ${content}`)); - } - - // Log to audit system (crash-safe, append-only) - if (auditSession) { - await auditSession.logEvent('llm_response', { - turn: turnCount, - content, - timestamp: new Date().toISOString() - }); - } - - messages.push(content); - - // Check for API error patterns in assistant message content - if (content && typeof content === 'string') { - const lowerContent = content.toLowerCase(); - if (lowerContent.includes('session limit reached')) { - throw new PentestError('Session limit reached', 'billing', false); - } - if (lowerContent.includes('api error') || lowerContent.includes('terminated')) { - apiErrorDetected = true; - console.log(chalk.red(` โš ๏ธ API Error detected in assistant response: ${content.trim()}`)); - } - } - - } else if (message.type === "system" && (message as { subtype?: string }).subtype === "init") { - // Show useful system info only for verbose agents - if (!useCleanOutput) { - const initMsg = message as { model?: string; permissionMode?: string; mcp_servers?: Array<{ name: string; status: string }> }; - console.log(chalk.blue(` โ„น๏ธ Model: ${initMsg.model}, Permission: ${initMsg.permissionMode}`)); - if (initMsg.mcp_servers && initMsg.mcp_servers.length > 0) { - const mcpStatus = initMsg.mcp_servers.map(s => `${s.name}(${s.status})`).join(', '); - console.log(chalk.blue(` ๐Ÿ“ฆ MCP: ${mcpStatus}`)); - } - } - - } else if (message.type === "user") { - // Skip user messages (these are our own inputs echoed back) - continue; - - } else if ((message.type as string) === "tool_use") { - const toolMsg = message as unknown as { name: string; input?: Record }; - console.log(chalk.yellow(`\n ๐Ÿ”ง Using Tool: ${toolMsg.name}`)); - if (toolMsg.input && Object.keys(toolMsg.input).length > 0) { - console.log(chalk.gray(` Input: ${JSON.stringify(toolMsg.input, null, 2)}`)); - } - - // Log tool start event - if (auditSession) { - await auditSession.logEvent('tool_start', { - toolName: toolMsg.name, - parameters: toolMsg.input, - timestamp: new Date().toISOString() - }); - } - } else if ((message.type as string) === "tool_result") { - const resultMsg = message as unknown as { content?: unknown }; - console.log(chalk.green(` โœ… Tool Result:`)); - if (resultMsg.content) { - // Show tool results but truncate if too long - const resultStr = typeof resultMsg.content === 'string' ? resultMsg.content : JSON.stringify(resultMsg.content, null, 2); - if (resultStr.length > 500) { - console.log(chalk.gray(` ${resultStr.slice(0, 500)}...\n [Result truncated - ${resultStr.length} total chars]`)); - } else { - console.log(chalk.gray(` ${resultStr}`)); - } - } - - // Log tool end event - if (auditSession) { - await auditSession.logEvent('tool_end', { - result: resultMsg.content, - timestamp: new Date().toISOString() - }); - } - } else if (message.type === "result") { - const resultMessage = message as { - result?: string; - total_cost_usd?: number; - duration_ms?: number; - subtype?: string; - permission_denials?: unknown[]; - }; - result = resultMessage.result || null; - - if (!statusManager) { - if (useCleanOutput) { - // Clean completion output - just duration and cost - console.log(chalk.magenta(`\n ๐Ÿ COMPLETED:`)); - const cost = resultMessage.total_cost_usd || 0; - console.log(chalk.gray(` โฑ๏ธ Duration: ${((resultMessage.duration_ms || 0)/1000).toFixed(1)}s, Cost: $${cost.toFixed(4)}`)); - - if (resultMessage.subtype === "error_max_turns") { - console.log(chalk.red(` โš ๏ธ Stopped: Hit maximum turns limit`)); - } else if (resultMessage.subtype === "error_during_execution") { - console.log(chalk.red(` โŒ Stopped: Execution error`)); - } - - if (resultMessage.permission_denials && resultMessage.permission_denials.length > 0) { - console.log(chalk.yellow(` ๐Ÿšซ ${resultMessage.permission_denials.length} permission denials`)); - } - } else { - // Full completion output for agents without clean output - console.log(chalk.magenta(`\n ๐Ÿ COMPLETED:`)); - const cost = resultMessage.total_cost_usd || 0; - console.log(chalk.gray(` โฑ๏ธ Duration: ${((resultMessage.duration_ms || 0)/1000).toFixed(1)}s, Cost: $${cost.toFixed(4)}`)); - - if (resultMessage.subtype === "error_max_turns") { - console.log(chalk.red(` โš ๏ธ Stopped: Hit maximum turns limit`)); - } else if (resultMessage.subtype === "error_during_execution") { - console.log(chalk.red(` โŒ Stopped: Execution error`)); - } - - if (resultMessage.permission_denials && resultMessage.permission_denials.length > 0) { - console.log(chalk.yellow(` ๐Ÿšซ ${resultMessage.permission_denials.length} permission denials`)); - } - - // Show result content (if it's reasonable length) - if (result && typeof result === 'string') { - if (result.length > 1000) { - console.log(chalk.magenta(` ๐Ÿ“„ ${result.slice(0, 1000)}... [${result.length} total chars]`)); - } else { - console.log(chalk.magenta(` ๐Ÿ“„ ${result}`)); - } - } - } - } - - // Track cost for all agents - const cost = resultMessage.total_cost_usd || 0; - const agentKey = description.toLowerCase().replace(/\s+/g, '-'); - costResults.agents[agentKey] = cost; - costResults.total += cost; - - // Store cost for return value and partial tracking - totalCost = cost; - partialCost = cost; - break; - } else { - // Log any other message types we might not be handling - console.log(chalk.gray(` ๐Ÿ’ฌ ${message.type}: ${JSON.stringify(message, null, 2)}`)); - } - } - } catch (queryError) { - throw queryError; // Re-throw to outer catch + if (validationResult) { + console.log(chalk.green(` Validation passed: Required files/structure present`)); + } else { + console.log(chalk.red(` Validation failed: Missing required deliverable files`)); } + return validationResult; + + } catch (error) { + const errMsg = error instanceof Error ? error.message : String(error); + console.log(chalk.red(` Validation failed with error: ${errMsg}`)); + return false; + } +} + +// Low-level SDK execution. Handles message streaming, progress, and audit logging. +async function runClaudePrompt( + prompt: string, + sourceDir: string, + context: string = '', + description: string = 'Claude analysis', + agentName: string | null = null, + colorFn: ChalkInstance = chalk.cyan, + sessionMetadata: SessionMetadata | null = null, + auditSession: AuditSession | null = null, + attemptNumber: number = 1 +): Promise { + const timer = new Timer(`agent-${description.toLowerCase().replace(/\s+/g, '-')}`); + const fullPrompt = context ? `${context}\n\n${prompt}` : prompt; + + const execContext = detectExecutionContext(description); + const progress = createProgressManager( + { description, useCleanOutput: execContext.useCleanOutput }, + global.SHANNON_DISABLE_LOADER ?? false + ); + const auditLogger = createAuditLogger(auditSession); + + const logFilePath = buildLogFilePath(sessionMetadata, execContext.agentKey, attemptNumber); + if (!logFilePath) { + console.log(chalk.blue(` Running Claude Code: ${description}...`)); + } + + const mcpServers = buildMcpServers(sourceDir, agentName); + const options = { + model: 'claude-sonnet-4-5-20250929', + maxTurns: 10_000, + cwd: sourceDir, + permissionMode: 'bypassPermissions' as const, + mcpServers, + }; + + if (!execContext.useCleanOutput) { + console.log(chalk.gray(` SDK Options: maxTurns=${options.maxTurns}, cwd=${sourceDir}, permissions=BYPASS`)); + } + + let turnCount = 0; + let result: string | null = null; + let apiErrorDetected = false; + let totalCost = 0; + + progress.start(); + + try { + const messageLoopResult = await processMessageStream( + fullPrompt, + options, + { execContext, description, colorFn, progress, auditLogger }, + timer + ); + + turnCount = messageLoopResult.turnCount; + result = messageLoopResult.result; + apiErrorDetected = messageLoopResult.apiErrorDetected; + totalCost = messageLoopResult.cost; + const duration = timer.stop(); - const agentKey = description.toLowerCase().replace(/\s+/g, '-'); - timingResults.agents[agentKey] = duration; + timingResults.agents[execContext.agentKey] = duration; - // API error detection is logged but not immediately failed if (apiErrorDetected) { - console.log(chalk.yellow(` โš ๏ธ API Error detected in ${description} - will validate deliverables before failing`)); + console.log(chalk.yellow(` API Error detected in ${description} - will validate deliverables before failing`)); } - // Show completion messages based on agent type - if (progressIndicator) { - const agentType = description.includes('Pre-recon') ? 'Pre-recon analysis' : - description.includes('Recon') ? 'Reconnaissance' : - description.includes('Report') ? 'Report generation' : 'Analysis'; - progressIndicator.finish(`${agentType} complete! (${turnCount} turns, ${formatDuration(duration)})`); - } else if (isParallelExecution) { - const prefix = getAgentPrefix(description); - console.log(chalk.green(`${prefix} โœ… Complete (${turnCount} turns, ${formatDuration(duration)})`)); - } else if (!useCleanOutput) { - console.log(chalk.green(` โœ… Claude Code completed: ${description} (${turnCount} turns) in ${formatDuration(duration)}`)); - } + progress.finish(formatCompletionMessage(execContext, description, turnCount, duration)); - // Return result with log file path for all agents const returnData: ClaudePromptResult = { result, success: true, duration, turns: turnCount, cost: totalCost, - partialCost, + partialCost: totalCost, apiErrorDetected }; if (logFilePath) { @@ -495,76 +277,14 @@ async function runClaudePrompt( } catch (error) { const duration = timer.stop(); - const agentKey = description.toLowerCase().replace(/\s+/g, '-'); - timingResults.agents[agentKey] = duration; + timingResults.agents[execContext.agentKey] = duration; - const err = error as Error & { code?: string; status?: number; duration?: number; cost?: number }; + const err = error as Error & { code?: string; status?: number }; - // Log error to audit system - if (auditSession) { - await auditSession.logEvent('error', { - message: err.message, - errorType: err.constructor.name, - stack: err.stack, - duration, - turns: turnCount, - timestamp: new Date().toISOString() - }); - } - - // Show error messages based on agent type - if (progressIndicator) { - progressIndicator.stop(); - const agentType = description.includes('Pre-recon') ? 'Pre-recon analysis' : - description.includes('Recon') ? 'Reconnaissance' : - description.includes('Report') ? 'Report generation' : 'Analysis'; - console.log(chalk.red(`โŒ ${agentType} failed (${formatDuration(duration)})`)); - } else if (isParallelExecution) { - const prefix = getAgentPrefix(description); - console.log(chalk.red(`${prefix} โŒ Failed (${formatDuration(duration)})`)); - } else if (!useCleanOutput) { - console.log(chalk.red(` โŒ Claude Code failed: ${description} (${formatDuration(duration)})`)); - } - console.log(chalk.red(` Error Type: ${err.constructor.name}`)); - console.log(chalk.red(` Message: ${err.message}`)); - console.log(chalk.gray(` Agent: ${description}`)); - console.log(chalk.gray(` Working Directory: ${sourceDir}`)); - console.log(chalk.gray(` Retryable: ${isRetryableError(err) ? 'Yes' : 'No'}`)); - - // Log additional context if available - if (err.code) { - console.log(chalk.gray(` Error Code: ${err.code}`)); - } - if (err.status) { - console.log(chalk.gray(` HTTP Status: ${err.status}`)); - } - - // Save detailed error to log file for debugging - try { - const errorLog = { - timestamp: new Date().toISOString(), - agent: description, - error: { - name: err.constructor.name, - message: err.message, - code: err.code, - status: err.status, - stack: err.stack - }, - context: { - sourceDir, - prompt: fullPrompt.slice(0, 200) + '...', - retryable: isRetryableError(err) - }, - duration - }; - - const logPath = path.join(sourceDir, 'error.log'); - await fs.appendFile(logPath, JSON.stringify(errorLog) + '\n'); - } catch (logError) { - const logErrMsg = logError instanceof Error ? logError.message : String(logError); - console.log(chalk.gray(` (Failed to write error log: ${logErrMsg})`)); - } + await auditLogger.logError(err, duration, turnCount); + progress.stop(); + outputLines(formatErrorOutput(err, execContext, description, duration, sourceDir, isRetryableError(err))); + await writeErrorLog(err, sourceDir, fullPrompt, duration); return { error: err.message, @@ -572,17 +292,97 @@ async function runClaudePrompt( prompt: fullPrompt.slice(0, 100) + '...', success: false, duration, - cost: partialCost, + cost: totalCost, retryable: isRetryableError(err) }; } } -// PREFERRED: Production-ready Claude agent execution with full orchestration +function buildLogFilePath( + sessionMetadata: SessionMetadata | null, + agentKey: string, + attemptNumber: number +): string | null { + if (!sessionMetadata || !sessionMetadata.webUrl || !sessionMetadata.id) { + return null; + } + const timestamp = formatTimestamp().replace(/T/, '_').replace(/[:.]/g, '-').slice(0, 19); + const logDir = generateSessionLogPath(sessionMetadata.webUrl, sessionMetadata.id); + return path.join(logDir, `${timestamp}_${agentKey}_attempt-${attemptNumber}.log`); +} + +interface MessageLoopResult { + turnCount: number; + result: string | null; + apiErrorDetected: boolean; + cost: number; +} + +interface MessageLoopDeps { + execContext: ReturnType; + description: string; + colorFn: ChalkInstance; + progress: ReturnType; + auditLogger: ReturnType; +} + +async function processMessageStream( + fullPrompt: string, + options: NonNullable[0]['options']>, + deps: MessageLoopDeps, + timer: Timer +): Promise { + const { execContext, description, colorFn, progress, auditLogger } = deps; + const HEARTBEAT_INTERVAL = 30000; + + let turnCount = 0; + let result: string | null = null; + let apiErrorDetected = false; + let cost = 0; + let lastHeartbeat = Date.now(); + + for await (const message of query({ prompt: fullPrompt, options })) { + // Heartbeat logging when loader is disabled + const now = Date.now(); + if (global.SHANNON_DISABLE_LOADER && now - lastHeartbeat > HEARTBEAT_INTERVAL) { + console.log(chalk.blue(` [${Math.floor((now - timer.startTime) / 1000)}s] ${description} running... (Turn ${turnCount})`)); + lastHeartbeat = now; + } + + // Increment turn count for assistant messages + if (message.type === 'assistant') { + turnCount++; + } + + const dispatchResult = await dispatchMessage( + message as { type: string; subtype?: string }, + turnCount, + { execContext, description, colorFn, progress, auditLogger } + ); + + if (dispatchResult.type === 'throw') { + throw dispatchResult.error; + } + + if (dispatchResult.type === 'complete') { + result = dispatchResult.result; + cost = dispatchResult.cost; + break; + } + + if (dispatchResult.type === 'continue' && dispatchResult.apiErrorDetected) { + apiErrorDetected = true; + } + } + + return { turnCount, result, apiErrorDetected, cost }; +} + +// Main entry point for agent execution. Handles retries, git checkpoints, and validation. export async function runClaudePromptWithRetry( prompt: string, sourceDir: string, - allowedTools: string = 'Read', + _allowedTools: string = 'Read', context: string = '', description: string = 'Claude analysis', agentName: string | null = null, @@ -593,9 +393,8 @@ export async function runClaudePromptWithRetry( let lastError: Error | undefined; let retryContext = context; - console.log(chalk.cyan(`๐Ÿš€ Starting ${description} with ${maxRetries} max attempts`)); + console.log(chalk.cyan(`Starting ${description} with ${maxRetries} max attempts`)); - // Initialize audit session (crash-safe logging) let auditSession: AuditSession | null = null; if (sessionMetadata && agentName) { auditSession = new AuditSession(sessionMetadata); @@ -603,29 +402,27 @@ export async function runClaudePromptWithRetry( } for (let attempt = 1; attempt <= maxRetries; attempt++) { - // Create checkpoint before each attempt await createGitCheckpoint(sourceDir, description, attempt); - // Start agent tracking in audit system (saves prompt snapshot automatically) if (auditSession && agentName) { const fullPrompt = retryContext ? `${retryContext}\n\n${prompt}` : prompt; await auditSession.startAgent(agentName, fullPrompt, attempt); } try { - const result = await runClaudePrompt(prompt, sourceDir, allowedTools, retryContext, description, agentName, colorFn, sessionMetadata, auditSession, attempt); + const result = await runClaudePrompt( + prompt, sourceDir, retryContext, + description, agentName, colorFn, sessionMetadata, auditSession, attempt + ); - // Validate output after successful run if (result.success) { const validationPassed = await validateAgentOutput(result, agentName, sourceDir); if (validationPassed) { - // Check if API error was detected but validation passed if (result.apiErrorDetected) { - console.log(chalk.yellow(`๐Ÿ“‹ Validation: Ready for exploitation despite API error warnings`)); + console.log(chalk.yellow(`Validation: Ready for exploitation despite API error warnings`)); } - // Record successful attempt in audit system if (auditSession && agentName) { const commitHash = await getGitCommitHash(sourceDir); const endResult: { @@ -646,15 +443,13 @@ export async function runClaudePromptWithRetry( await auditSession.endAgent(agentName, endResult); } - // Commit successful changes (will include the snapshot) await commitGitSuccess(sourceDir, description); - console.log(chalk.green.bold(`๐ŸŽ‰ ${description} completed successfully on attempt ${attempt}/${maxRetries}`)); + console.log(chalk.green.bold(`${description} completed successfully on attempt ${attempt}/${maxRetries}`)); return result; + // Validation failure is retryable - agent might succeed on retry with cleaner workspace } else { - // Agent completed but output validation failed - console.log(chalk.yellow(`โš ๏ธ ${description} completed but output validation failed`)); + console.log(chalk.yellow(`${description} completed but output validation failed`)); - // Record failed validation attempt in audit system if (auditSession && agentName) { await auditSession.endAgent(agentName, { attemptNumber: attempt, @@ -666,20 +461,17 @@ export async function runClaudePromptWithRetry( }); } - // If API error detected AND validation failed, this is a retryable error if (result.apiErrorDetected) { - console.log(chalk.yellow(`โš ๏ธ API Error detected with validation failure - treating as retryable`)); + console.log(chalk.yellow(`API Error detected with validation failure - treating as retryable`)); lastError = new Error('API Error: terminated with validation failure'); } else { lastError = new Error('Output validation failed'); } if (attempt < maxRetries) { - // Rollback contaminated workspace await rollbackGitWorkspace(sourceDir, 'validation failure'); continue; } else { - // FAIL FAST - Don't continue with broken pipeline throw new PentestError( `Agent ${description} failed output validation after ${maxRetries} attempts. Required deliverable files were not created.`, 'validation', @@ -694,7 +486,6 @@ export async function runClaudePromptWithRetry( const err = error as Error & { duration?: number; cost?: number; partialResults?: unknown }; lastError = err; - // Record failed attempt in audit system if (auditSession && agentName) { await auditSession.endAgent(agentName, { attemptNumber: attempt, @@ -706,24 +497,21 @@ export async function runClaudePromptWithRetry( }); } - // Check if error is retryable if (!isRetryableError(err)) { - console.log(chalk.red(`โŒ ${description} failed with non-retryable error: ${err.message}`)); + console.log(chalk.red(`${description} failed with non-retryable error: ${err.message}`)); await rollbackGitWorkspace(sourceDir, 'non-retryable error cleanup'); throw err; } if (attempt < maxRetries) { - // Rollback for clean retry await rollbackGitWorkspace(sourceDir, 'retryable error cleanup'); const delay = getRetryDelay(err, attempt); const delaySeconds = (delay / 1000).toFixed(1); - console.log(chalk.yellow(`โš ๏ธ ${description} failed (attempt ${attempt}/${maxRetries})`)); + console.log(chalk.yellow(`${description} failed (attempt ${attempt}/${maxRetries})`)); console.log(chalk.gray(` Error: ${err.message}`)); console.log(chalk.gray(` Workspace rolled back, retrying in ${delaySeconds}s...`)); - // Preserve any partial results for next retry if (err.partialResults) { retryContext = `${context}\n\nPrevious partial results: ${JSON.stringify(err.partialResults)}`; } @@ -731,7 +519,7 @@ export async function runClaudePromptWithRetry( await new Promise(resolve => setTimeout(resolve, delay)); } else { await rollbackGitWorkspace(sourceDir, 'final failure cleanup'); - console.log(chalk.red(`โŒ ${description} failed after ${maxRetries} attempts`)); + console.log(chalk.red(`${description} failed after ${maxRetries} attempts`)); console.log(chalk.red(` Final error: ${err.message}`)); } } @@ -739,13 +527,3 @@ export async function runClaudePromptWithRetry( throw lastError; } - -// Helper function to get git commit hash -async function getGitCommitHash(sourceDir: string): Promise { - try { - const result = await $`cd ${sourceDir} && git rev-parse HEAD`; - return result.stdout.trim(); - } catch { - return null; - } -} diff --git a/src/ai/message-handlers.ts b/src/ai/message-handlers.ts new file mode 100644 index 0000000..098d239 --- /dev/null +++ b/src/ai/message-handlers.ts @@ -0,0 +1,244 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +// Pure functions for processing SDK message types + +import { PentestError } from '../error-handling.js'; +import { filterJsonToolCalls } from '../utils/output-formatter.js'; +import { formatTimestamp } from '../utils/formatting.js'; +import chalk from 'chalk'; +import { + formatAssistantOutput, + formatResultOutput, + formatToolUseOutput, + formatToolResultOutput, +} from './output-formatters.js'; +import { costResults } from '../utils/metrics.js'; +import type { AuditLogger } from './audit-logger.js'; +import type { ProgressManager } from './progress-manager.js'; +import type { + AssistantMessage, + ResultMessage, + ToolUseMessage, + ToolResultMessage, + AssistantResult, + ResultData, + ToolUseData, + ToolResultData, + ApiErrorDetection, + ContentBlock, + SystemInitMessage, + ExecutionContext, +} from './types.js'; +import type { ChalkInstance } from 'chalk'; + +// Handles both array and string content formats from SDK +export function extractMessageContent(message: AssistantMessage): string { + const messageContent = message.message; + + if (Array.isArray(messageContent.content)) { + return messageContent.content + .map((c: ContentBlock) => c.text || JSON.stringify(c)) + .join('\n'); + } + + return String(messageContent.content); +} + +export function detectApiError(content: string): ApiErrorDetection { + if (!content || typeof content !== 'string') { + return { detected: false }; + } + + const lowerContent = content.toLowerCase(); + + // Fatal error - should throw immediately + if (lowerContent.includes('session limit reached')) { + return { + detected: true, + shouldThrow: new PentestError('Session limit reached', 'billing', false), + }; + } + + // Non-fatal API errors - detected but continue + if (lowerContent.includes('api error') || lowerContent.includes('terminated')) { + return { detected: true }; + } + + return { detected: false }; +} + +export function handleAssistantMessage( + message: AssistantMessage, + turnCount: number +): AssistantResult { + const content = extractMessageContent(message); + const cleanedContent = filterJsonToolCalls(content); + const errorDetection = detectApiError(content); + + const result: AssistantResult = { + content, + cleanedContent, + apiErrorDetected: errorDetection.detected, + logData: { + turn: turnCount, + content, + timestamp: formatTimestamp(), + }, + }; + + // Only add shouldThrow if it exists (exactOptionalPropertyTypes compliance) + if (errorDetection.shouldThrow) { + result.shouldThrow = errorDetection.shouldThrow; + } + + return result; +} + +// Final message of a query with cost/duration info +export function handleResultMessage(message: ResultMessage): ResultData { + const result: ResultData = { + result: message.result || null, + cost: message.total_cost_usd || 0, + duration_ms: message.duration_ms || 0, + permissionDenials: message.permission_denials?.length || 0, + }; + + // Only add subtype if it exists (exactOptionalPropertyTypes compliance) + if (message.subtype) { + result.subtype = message.subtype; + } + + return result; +} + +export function handleToolUseMessage(message: ToolUseMessage): ToolUseData { + return { + toolName: message.name, + parameters: message.input || {}, + timestamp: formatTimestamp(), + }; +} + +// Truncates long results for display (500 char limit), preserves full content for logging +export function handleToolResultMessage(message: ToolResultMessage): ToolResultData { + const content = message.content; + const contentStr = + typeof content === 'string' ? content : JSON.stringify(content, null, 2); + + const displayContent = + contentStr.length > 500 + ? `${contentStr.slice(0, 500)}...\n[Result truncated - ${contentStr.length} total chars]` + : contentStr; + + return { + content, + displayContent, + timestamp: formatTimestamp(), + }; +} + +// Output helper for console logging +function outputLines(lines: string[]): void { + for (const line of lines) { + console.log(line); + } +} + +// Message dispatch result types +export type MessageDispatchAction = + | { type: 'continue'; apiErrorDetected?: boolean } + | { type: 'complete'; result: string | null; cost: number } + | { type: 'throw'; error: Error }; + +export interface MessageDispatchDeps { + execContext: ExecutionContext; + description: string; + colorFn: ChalkInstance; + progress: ProgressManager; + auditLogger: AuditLogger; +} + +// Dispatches SDK messages to appropriate handlers and formatters +export async function dispatchMessage( + message: { type: string; subtype?: string }, + turnCount: number, + deps: MessageDispatchDeps +): Promise { + const { execContext, description, colorFn, progress, auditLogger } = deps; + + switch (message.type) { + case 'assistant': { + const assistantResult = handleAssistantMessage(message as AssistantMessage, turnCount); + + if (assistantResult.shouldThrow) { + return { type: 'throw', error: assistantResult.shouldThrow }; + } + + if (assistantResult.cleanedContent.trim()) { + progress.stop(); + outputLines(formatAssistantOutput( + assistantResult.cleanedContent, + execContext, + turnCount, + description, + colorFn + )); + progress.start(); + } + + await auditLogger.logLlmResponse(turnCount, assistantResult.content); + + if (assistantResult.apiErrorDetected) { + console.log(chalk.red(` API Error detected in assistant response`)); + return { type: 'continue', apiErrorDetected: true }; + } + + return { type: 'continue' }; + } + + case 'system': { + if (message.subtype === 'init' && !execContext.useCleanOutput) { + const initMsg = message as SystemInitMessage; + console.log(chalk.blue(` Model: ${initMsg.model}, Permission: ${initMsg.permissionMode}`)); + if (initMsg.mcp_servers && initMsg.mcp_servers.length > 0) { + const mcpStatus = initMsg.mcp_servers.map(s => `${s.name}(${s.status})`).join(', '); + console.log(chalk.blue(` MCP: ${mcpStatus}`)); + } + } + return { type: 'continue' }; + } + + case 'user': + return { type: 'continue' }; + + case 'tool_use': { + const toolData = handleToolUseMessage(message as unknown as ToolUseMessage); + outputLines(formatToolUseOutput(toolData.toolName, toolData.parameters)); + await auditLogger.logToolStart(toolData.toolName, toolData.parameters); + return { type: 'continue' }; + } + + case 'tool_result': { + const toolResultData = handleToolResultMessage(message as unknown as ToolResultMessage); + outputLines(formatToolResultOutput(toolResultData.displayContent)); + await auditLogger.logToolEnd(toolResultData.content); + return { type: 'continue' }; + } + + case 'result': { + const resultData = handleResultMessage(message as ResultMessage); + outputLines(formatResultOutput(resultData, !execContext.useCleanOutput)); + costResults.agents[execContext.agentKey] = resultData.cost; + costResults.total += resultData.cost; + return { type: 'complete', result: resultData.result, cost: resultData.cost }; + } + + default: + console.log(chalk.gray(` ${message.type}: ${JSON.stringify(message, null, 2)}`)); + return { type: 'continue' }; + } +} diff --git a/src/ai/output-formatters.ts b/src/ai/output-formatters.ts new file mode 100644 index 0000000..833c71c --- /dev/null +++ b/src/ai/output-formatters.ts @@ -0,0 +1,169 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +// Pure functions for formatting console output + +import chalk from 'chalk'; +import { extractAgentType, formatDuration } from '../utils/formatting.js'; +import { getAgentPrefix } from '../utils/output-formatter.js'; +import type { ExecutionContext, ResultData } from './types.js'; + +export function detectExecutionContext(description: string): ExecutionContext { + const isParallelExecution = + description.includes('vuln agent') || description.includes('exploit agent'); + + const useCleanOutput = + description.includes('Pre-recon agent') || + description.includes('Recon agent') || + description.includes('Executive Summary and Report Cleanup') || + description.includes('vuln agent') || + description.includes('exploit agent'); + + const agentType = extractAgentType(description); + + const agentKey = description.toLowerCase().replace(/\s+/g, '-'); + + return { isParallelExecution, useCleanOutput, agentType, agentKey }; +} + +export function formatAssistantOutput( + cleanedContent: string, + context: ExecutionContext, + turnCount: number, + description: string, + colorFn: typeof chalk.cyan = chalk.cyan +): string[] { + if (!cleanedContent.trim()) { + return []; + } + + const lines: string[] = []; + + if (context.isParallelExecution) { + // Compact output for parallel agents with prefixes + const prefix = getAgentPrefix(description); + lines.push(colorFn(`${prefix} ${cleanedContent}`)); + } else { + // Full turn output for sequential agents + lines.push(colorFn(`\n Turn ${turnCount} (${description}):`)); + lines.push(colorFn(` ${cleanedContent}`)); + } + + return lines; +} + +export function formatResultOutput(data: ResultData, showFullResult: boolean): string[] { + const lines: string[] = []; + + lines.push(chalk.magenta(`\n COMPLETED:`)); + lines.push( + chalk.gray( + ` Duration: ${(data.duration_ms / 1000).toFixed(1)}s, Cost: $${data.cost.toFixed(4)}` + ) + ); + + if (data.subtype === 'error_max_turns') { + lines.push(chalk.red(` Stopped: Hit maximum turns limit`)); + } else if (data.subtype === 'error_during_execution') { + lines.push(chalk.red(` Stopped: Execution error`)); + } + + if (data.permissionDenials > 0) { + lines.push(chalk.yellow(` ${data.permissionDenials} permission denials`)); + } + + if (showFullResult && data.result && typeof data.result === 'string') { + if (data.result.length > 1000) { + lines.push(chalk.magenta(` ${data.result.slice(0, 1000)}... [${data.result.length} total chars]`)); + } else { + lines.push(chalk.magenta(` ${data.result}`)); + } + } + + return lines; +} + +export function formatErrorOutput( + error: Error & { code?: string; status?: number }, + context: ExecutionContext, + description: string, + duration: number, + sourceDir: string, + isRetryable: boolean +): string[] { + const lines: string[] = []; + + if (context.isParallelExecution) { + const prefix = getAgentPrefix(description); + lines.push(chalk.red(`${prefix} Failed (${formatDuration(duration)})`)); + } else if (context.useCleanOutput) { + lines.push(chalk.red(`${context.agentType} failed (${formatDuration(duration)})`)); + } else { + lines.push(chalk.red(` Claude Code failed: ${description} (${formatDuration(duration)})`)); + } + + lines.push(chalk.red(` Error Type: ${error.constructor.name}`)); + lines.push(chalk.red(` Message: ${error.message}`)); + lines.push(chalk.gray(` Agent: ${description}`)); + lines.push(chalk.gray(` Working Directory: ${sourceDir}`)); + lines.push(chalk.gray(` Retryable: ${isRetryable ? 'Yes' : 'No'}`)); + + if (error.code) { + lines.push(chalk.gray(` Error Code: ${error.code}`)); + } + if (error.status) { + lines.push(chalk.gray(` HTTP Status: ${error.status}`)); + } + + return lines; +} + +export function formatCompletionMessage( + context: ExecutionContext, + description: string, + turnCount: number, + duration: number +): string { + if (context.isParallelExecution) { + const prefix = getAgentPrefix(description); + return chalk.green(`${prefix} Complete (${turnCount} turns, ${formatDuration(duration)})`); + } + + if (context.useCleanOutput) { + return chalk.green( + `${context.agentType.charAt(0).toUpperCase() + context.agentType.slice(1)} complete! (${turnCount} turns, ${formatDuration(duration)})` + ); + } + + return chalk.green( + ` Claude Code completed: ${description} (${turnCount} turns) in ${formatDuration(duration)}` + ); +} + +export function formatToolUseOutput( + toolName: string, + input: Record | undefined +): string[] { + const lines: string[] = []; + + lines.push(chalk.yellow(`\n Using Tool: ${toolName}`)); + if (input && Object.keys(input).length > 0) { + lines.push(chalk.gray(` Input: ${JSON.stringify(input, null, 2)}`)); + } + + return lines; +} + +export function formatToolResultOutput(displayContent: string): string[] { + const lines: string[] = []; + + lines.push(chalk.green(` Tool Result:`)); + if (displayContent) { + lines.push(chalk.gray(` ${displayContent}`)); + } + + return lines; +} diff --git a/src/ai/progress-manager.ts b/src/ai/progress-manager.ts new file mode 100644 index 0000000..ceee32d --- /dev/null +++ b/src/ai/progress-manager.ts @@ -0,0 +1,76 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +// Null Object pattern for progress indicator - callers never check for null + +import { ProgressIndicator } from '../progress-indicator.js'; +import { extractAgentType } from '../utils/formatting.js'; + +export interface ProgressContext { + description: string; + useCleanOutput: boolean; +} + +export interface ProgressManager { + start(): void; + stop(): void; + finish(message: string): void; + isActive(): boolean; +} + +class RealProgressManager implements ProgressManager { + private indicator: ProgressIndicator; + private active: boolean = false; + + constructor(message: string) { + this.indicator = new ProgressIndicator(message); + } + + start(): void { + this.indicator.start(); + this.active = true; + } + + stop(): void { + this.indicator.stop(); + this.active = false; + } + + finish(message: string): void { + this.indicator.finish(message); + this.active = false; + } + + isActive(): boolean { + return this.active; + } +} + +/** Null Object implementation - all methods are safe no-ops */ +class NullProgressManager implements ProgressManager { + start(): void {} + + stop(): void {} + + finish(_message: string): void {} + + isActive(): boolean { + return false; + } +} + +// Returns no-op when disabled +export function createProgressManager( + context: ProgressContext, + disableLoader: boolean +): ProgressManager { + if (!context.useCleanOutput || disableLoader) { + return new NullProgressManager(); + } + + const agentType = extractAgentType(context.description); + return new RealProgressManager(`Running ${agentType}...`); +} diff --git a/src/ai/types.ts b/src/ai/types.ts new file mode 100644 index 0000000..b754d0c --- /dev/null +++ b/src/ai/types.ts @@ -0,0 +1,134 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +// Type definitions for Claude executor message processing pipeline + +export interface ExecutionContext { + isParallelExecution: boolean; + useCleanOutput: boolean; + agentType: string; + agentKey: string; +} + +export interface ProcessingState { + turnCount: number; + result: string | null; + apiErrorDetected: boolean; + totalCost: number; + partialCost: number; + lastHeartbeat: number; +} + +export interface ProcessingResult { + result: string | null; + turnCount: number; + apiErrorDetected: boolean; + totalCost: number; +} + +export interface AssistantResult { + content: string; + cleanedContent: string; + apiErrorDetected: boolean; + shouldThrow?: Error; + logData: { + turn: number; + content: string; + timestamp: string; + }; +} + +export interface ResultData { + result: string | null; + cost: number; + duration_ms: number; + subtype?: string; + permissionDenials: number; +} + +export interface ToolUseData { + toolName: string; + parameters: Record; + timestamp: string; +} + +export interface ToolResultData { + content: unknown; + displayContent: string; + timestamp: string; +} + +export interface ContentBlock { + type?: string; + text?: string; +} + +export interface AssistantMessage { + type: 'assistant'; + message: { + content: ContentBlock[] | string; + }; +} + +export interface ResultMessage { + type: 'result'; + result?: string; + total_cost_usd?: number; + duration_ms?: number; + subtype?: string; + permission_denials?: unknown[]; +} + +export interface ToolUseMessage { + type: 'tool_use'; + name: string; + input?: Record; +} + +export interface ToolResultMessage { + type: 'tool_result'; + content?: unknown; +} + +export interface ApiErrorDetection { + detected: boolean; + shouldThrow?: Error; +} + +// Message types from SDK stream +export type SdkMessage = + | AssistantMessage + | ResultMessage + | ToolUseMessage + | ToolResultMessage + | SystemInitMessage + | UserMessage; + +export interface SystemInitMessage { + type: 'system'; + subtype: 'init'; + model?: string; + permissionMode?: string; + mcp_servers?: Array<{ name: string; status: string }>; +} + +export interface UserMessage { + type: 'user'; +} + +// Dispatch result types for message processing +export type MessageDispatchResult = + | { action: 'continue' } + | { action: 'break'; result: string | null; cost: number } + | { action: 'throw'; error: Error }; + +export interface MessageDispatchContext { + turnCount: number; + execContext: ExecutionContext; + description: string; + colorFn: (text: string) => string; + useCleanOutput: boolean; +} diff --git a/src/audit/audit-session.ts b/src/audit/audit-session.ts index b3540a7..ccb6636 100644 --- a/src/audit/audit-session.ts +++ b/src/audit/audit-session.ts @@ -13,7 +13,8 @@ import { AgentLogger } from './logger.js'; import { MetricsTracker } from './metrics-tracker.js'; -import { initializeAuditStructure, formatTimestamp, type SessionMetadata } from './utils.js'; +import { initializeAuditStructure, type SessionMetadata } from './utils.js'; +import { formatTimestamp } from '../utils/formatting.js'; import { SessionMutex } from '../utils/concurrency.js'; // Global mutex instance @@ -145,7 +146,7 @@ export class AuditSession { // Mutex-protected update to session.json const unlock = await sessionMutex.lock(this.sessionId); try { - // Reload metrics (in case of parallel updates) + // Reload inside mutex to prevent lost updates during parallel exploitation phase await this.metricsTracker.reload(); // Update metrics diff --git a/src/audit/logger.ts b/src/audit/logger.ts index 281563a..c8e902d 100644 --- a/src/audit/logger.ts +++ b/src/audit/logger.ts @@ -15,10 +15,10 @@ import fs from 'fs'; import { generateLogPath, generatePromptPath, - atomicWrite, - formatTimestamp, type SessionMetadata, } from './utils.js'; +import { atomicWrite } from '../utils/file-io.js'; +import { formatTimestamp } from '../utils/formatting.js'; interface LogEvent { type: string; @@ -96,22 +96,13 @@ export class AgentLogger { return; } - // Write and flush immediately (crash-safe) const needsDrain = !this.stream.write(text, 'utf8', (error) => { - if (error) { - reject(error); - } + if (error) reject(error); }); if (needsDrain) { - // Buffer is full, wait for drain - const drainHandler = (): void => { - this.stream!.removeListener('drain', drainHandler); - resolve(); - }; - this.stream.once('drain', drainHandler); + this.stream.once('drain', resolve); } else { - // Buffer has space, resolve immediately resolve(); } }); diff --git a/src/audit/metrics-tracker.ts b/src/audit/metrics-tracker.ts index 54ec973..3e552ef 100644 --- a/src/audit/metrics-tracker.ts +++ b/src/audit/metrics-tracker.ts @@ -13,13 +13,12 @@ import { generateSessionJsonPath, - atomicWrite, - readJson, - fileExists, - formatTimestamp, - calculatePercentage, type SessionMetadata, } from './utils.js'; +import { atomicWrite, readJson, fileExists } from '../utils/file-io.js'; +import { formatTimestamp, calculatePercentage } from '../utils/formatting.js'; +import { AGENT_PHASE_MAP, type PhaseName } from '../session-manager.js'; +import type { AgentName } from '../types/index.js'; interface AttemptData { attempt_number: number; @@ -152,16 +151,14 @@ export class MetricsTracker { } // Initialize agent metrics if not exists - if (!this.data.metrics.agents[agentName]) { - this.data.metrics.agents[agentName] = { - status: 'in-progress', - attempts: [], - final_duration_ms: 0, - total_cost_usd: 0, - }; - } - - const agent = this.data.metrics.agents[agentName]!; + const existingAgent = this.data.metrics.agents[agentName]; + const agent = existingAgent ?? { + status: 'in-progress' as const, + attempts: [], + final_duration_ms: 0, + total_cost_usd: 0, + }; + this.data.metrics.agents[agentName] = agent; // Add attempt to array const attempt: AttemptData = { @@ -255,36 +252,19 @@ export class MetricsTracker { private calculatePhaseMetrics( successfulAgents: Array<[string, AgentMetrics]> ): Record { - const phases: Record = { + const phases: Record = { 'pre-recon': [], - recon: [], + 'recon': [], 'vulnerability-analysis': [], - exploitation: [], - reporting: [], + 'exploitation': [], + 'reporting': [], }; - // Map agents to phases - const agentPhaseMap: Record = { - 'pre-recon': 'pre-recon', - recon: 'recon', - 'injection-vuln': 'vulnerability-analysis', - 'xss-vuln': 'vulnerability-analysis', - 'auth-vuln': 'vulnerability-analysis', - 'authz-vuln': 'vulnerability-analysis', - 'ssrf-vuln': 'vulnerability-analysis', - 'injection-exploit': 'exploitation', - 'xss-exploit': 'exploitation', - 'auth-exploit': 'exploitation', - 'authz-exploit': 'exploitation', - 'ssrf-exploit': 'exploitation', - report: 'reporting', - }; - - // Group agents by phase + // Group agents by phase using imported AGENT_PHASE_MAP for (const [agentName, agentData] of successfulAgents) { - const phase = agentPhaseMap[agentName]; - if (phase && phases[phase]) { - phases[phase]!.push(agentData); + const phase = AGENT_PHASE_MAP[agentName as AgentName]; + if (phase) { + phases[phase].push(agentData); } } @@ -296,7 +276,6 @@ export class MetricsTracker { if (agentList.length === 0) continue; const phaseDuration = agentList.reduce((sum, agent) => sum + agent.final_duration_ms, 0); - const phaseCost = agentList.reduce((sum, agent) => sum + agent.total_cost_usd, 0); phaseMetrics[phaseName] = { diff --git a/src/error-handling.ts b/src/error-handling.ts index c2b5766..9dd6831 100644 --- a/src/error-handling.ts +++ b/src/error-handling.ts @@ -37,11 +37,11 @@ export class PentestError extends Error { } // Centralized error logging function -export const logError = async ( +export async function logError( error: Error & { type?: PentestErrorType; retryable?: boolean; context?: PentestErrorContext }, contextMsg: string, sourceDir: string | null = null -): Promise => { +): Promise { const timestamp = new Date().toISOString(); const logEntry: LogEntry = { timestamp, @@ -80,13 +80,13 @@ export const logError = async ( } return logEntry; -}; +} // Handle tool execution errors -export const handleToolError = ( +export function handleToolError( toolName: string, error: Error & { code?: string } -): ToolErrorResult => { +): ToolErrorResult { const isRetryable = error.code === 'ECONNRESET' || error.code === 'ETIMEDOUT' || @@ -105,13 +105,13 @@ export const handleToolError = ( { toolName, originalError: error.message, errorCode: error.code } ), }; -}; +} // Handle prompt loading errors -export const handlePromptError = ( +export function handlePromptError( promptName: string, error: Error -): PromptErrorResult => { +): PromptErrorResult { return { success: false, error: new PentestError( @@ -121,78 +121,63 @@ export const handlePromptError = ( { promptName, originalError: error.message } ), }; -}; +} -// Check if an error should trigger a retry for Claude agents -export const isRetryableError = (error: Error): boolean => { +// Patterns that indicate retryable errors +const RETRYABLE_PATTERNS = [ + // Network and connection errors + 'network', + 'connection', + 'timeout', + 'econnreset', + 'enotfound', + 'econnrefused', + // Rate limiting + 'rate limit', + '429', + 'too many requests', + // Server errors + 'server error', + '5xx', + 'internal server error', + 'service unavailable', + 'bad gateway', + // Claude API errors + 'mcp server', + 'model unavailable', + 'service temporarily unavailable', + 'api error', + 'terminated', + // Max turns + 'max turns', + 'maximum turns', +]; + +// Patterns that indicate non-retryable errors (checked before default) +const NON_RETRYABLE_PATTERNS = [ + 'authentication', + 'invalid prompt', + 'out of memory', + 'permission denied', + 'session limit reached', + 'invalid api key', +]; + +// Conservative retry classification - unknown errors don't retry (fail-safe default) +export function isRetryableError(error: Error): boolean { const message = error.message.toLowerCase(); - // Network and connection errors - always retryable - if ( - message.includes('network') || - message.includes('connection') || - message.includes('timeout') || - message.includes('econnreset') || - message.includes('enotfound') || - message.includes('econnrefused') - ) { - return true; - } - - // Rate limiting - retryable with longer backoff - if ( - message.includes('rate limit') || - message.includes('429') || - message.includes('too many requests') - ) { - return true; - } - - // Server errors - retryable - if ( - message.includes('server error') || - message.includes('5xx') || - message.includes('internal server error') || - message.includes('service unavailable') || - message.includes('bad gateway') - ) { - return true; - } - - // Claude API specific errors - retryable - if ( - message.includes('mcp server') || - message.includes('model unavailable') || - message.includes('service temporarily unavailable') || - message.includes('api error') || - message.includes('terminated') - ) { - return true; - } - - // Max turns without completion - retryable once - if (message.includes('max turns') || message.includes('maximum turns')) { - return true; - } - - // Non-retryable errors - if ( - message.includes('authentication') || - message.includes('invalid prompt') || - message.includes('out of memory') || - message.includes('permission denied') || - message.includes('session limit reached') || - message.includes('invalid api key') - ) { + // Check for explicit non-retryable patterns first + if (NON_RETRYABLE_PATTERNS.some((pattern) => message.includes(pattern))) { return false; } - // Default to non-retryable for unknown errors - return false; -}; + // Check for retryable patterns + return RETRYABLE_PATTERNS.some((pattern) => message.includes(pattern)); +} -// Get retry delay based on error type and attempt number -export const getRetryDelay = (error: Error, attempt: number): number => { +// Rate limit errors get longer base delay (30s) vs standard exponential backoff (2s) +export function getRetryDelay(error: Error, attempt: number): number { const message = error.message.toLowerCase(); // Rate limiting gets longer delays @@ -204,4 +189,4 @@ export const getRetryDelay = (error: Error, attempt: number): number => { const baseDelay = Math.pow(2, attempt) * 1000; // 2s, 4s, 8s const jitter = Math.random() * 1000; // 0-1s random return Math.min(baseDelay + jitter, 30000); // Max 30s -}; +} diff --git a/src/phases/pre-recon.ts b/src/phases/pre-recon.ts index 34f1580..5430029 100644 --- a/src/phases/pre-recon.ts +++ b/src/phases/pre-recon.ts @@ -7,7 +7,7 @@ import { $, fs, path } from 'zx'; import chalk from 'chalk'; import { Timer } from '../utils/metrics.js'; -import { formatDuration } from '../audit/utils.js'; +import { formatDuration } from '../utils/formatting.js'; import { handleToolError, PentestError } from '../error-handling.js'; import { AGENTS } from '../session-manager.js'; import { runClaudePromptWithRetry } from '../ai/claude-executor.js'; @@ -40,11 +40,17 @@ interface PromptVariables { repoPath: string; } +// Discriminated union for Wave1 tool results - clearer than loose union types +type Wave1ToolResult = + | { kind: 'scan'; result: TerminalScanResult } + | { kind: 'skipped'; message: string } + | { kind: 'agent'; result: AgentResult }; + interface Wave1Results { - nmap: TerminalScanResult | string | AgentResult; - subfinder: TerminalScanResult | string | AgentResult; - whatweb: TerminalScanResult | string | AgentResult; - naabu?: TerminalScanResult | string | AgentResult; + nmap: Wave1ToolResult; + subfinder: Wave1ToolResult; + whatweb: Wave1ToolResult; + naabu?: Wave1ToolResult; codeAnalysis: AgentResult; } @@ -57,7 +63,7 @@ interface PreReconResult { report: string; } -// Pure function: Run terminal scanning tools +// Runs external security tools (nmap, whatweb, etc). Schemathesis requires schemas from code analysis. async function runTerminalScan(tool: ToolName, target: string, sourceDir: string | null = null): Promise { const timer = new Timer(`command-${tool}`); try { @@ -89,7 +95,7 @@ async function runTerminalScan(tool: ToolName, target: string, sourceDir: string return { tool: 'whatweb', output: result.stdout, status: 'success', duration: whatwebDuration }; } case 'schemathesis': { - // Only run if API schemas found + // Schemathesis depends on code analysis output - skip if no schemas found const schemasDir = path.join(sourceDir || '.', 'outputs', 'schemas'); if (await fs.pathExists(schemasDir)) { const schemaFiles = await fs.readdir(schemasDir) as string[]; @@ -146,6 +152,8 @@ async function runPreReconWave1( const operations: Promise[] = []; + const skippedResult = (message: string): Wave1ToolResult => ({ kind: 'skipped', message }); + // Skip external commands in pipeline testing mode if (pipelineTestingMode) { console.log(chalk.gray(' โญ๏ธ Skipping external tools (pipeline testing mode)')); @@ -163,9 +171,9 @@ async function runPreReconWave1( ); const [codeAnalysis] = await Promise.all(operations); return { - nmap: 'Skipped (pipeline testing mode)', - subfinder: 'Skipped (pipeline testing mode)', - whatweb: 'Skipped (pipeline testing mode)', + nmap: skippedResult('Skipped (pipeline testing mode)'), + subfinder: skippedResult('Skipped (pipeline testing mode)'), + whatweb: skippedResult('Skipped (pipeline testing mode)'), codeAnalysis: codeAnalysis as AgentResult }; } else { @@ -192,9 +200,9 @@ async function runPreReconWave1( const [nmap, subfinder, whatweb, codeAnalysis] = await Promise.all(operations); return { - nmap: nmap as TerminalScanResult, - subfinder: subfinder as TerminalScanResult, - whatweb: whatweb as TerminalScanResult, + nmap: { kind: 'scan', result: nmap as TerminalScanResult }, + subfinder: { kind: 'scan', result: subfinder as TerminalScanResult }, + whatweb: { kind: 'scan', result: whatweb as TerminalScanResult }, codeAnalysis: codeAnalysis as AgentResult }; } @@ -250,17 +258,21 @@ async function runPreReconWave2( return response; } -// Helper type for stitching results -interface StitchableResult { - status?: string; - output?: string; - tool?: string; +// Extracts status and output from a Wave1 tool result +function extractResult(r: Wave1ToolResult | undefined): { status: string; output: string } { + if (!r) return { status: 'Skipped', output: 'No output' }; + switch (r.kind) { + case 'scan': + return { status: r.result.status || 'Skipped', output: r.result.output || 'No output' }; + case 'skipped': + return { status: 'Skipped', output: r.message }; + case 'agent': + return { status: r.result.success ? 'success' : 'error', output: 'See agent output' }; + } } -// Pure function: Stitch together pre-recon outputs and save to file -async function stitchPreReconOutputs(outputs: (StitchableResult | string | undefined)[], sourceDir: string): Promise { - const [nmap, subfinder, whatweb, naabu, codeAnalysis, ...additionalScans] = outputs; - +// Combines tool outputs into single deliverable. Falls back to reference if file missing. +async function stitchPreReconOutputs(wave1: Wave1Results, additionalScans: TerminalScanResult[], sourceDir: string): Promise { // Try to read the code analysis deliverable file let codeAnalysisContent = 'No analysis available'; try { @@ -269,62 +281,45 @@ async function stitchPreReconOutputs(outputs: (StitchableResult | string | undef } catch (error) { const err = error as Error; console.log(chalk.yellow(`โš ๏ธ Could not read code analysis deliverable: ${err.message}`)); - // Fallback message if file doesn't exist codeAnalysisContent = 'Analysis located in deliverables/code_analysis_deliverable.md'; } - // Build additional scans section let additionalSection = ''; - if (additionalScans && additionalScans.length > 0) { + if (additionalScans.length > 0) { additionalSection = '\n## Authenticated Scans\n'; - additionalScans.forEach(scan => { - const s = scan as StitchableResult; - if (s && s.tool) { - additionalSection += ` -### ${s.tool.toUpperCase()} -Status: ${s.status} -${s.output} + for (const scan of additionalScans) { + additionalSection += ` +### ${scan.tool.toUpperCase()} +Status: ${scan.status} +${scan.output} `; - } - }); + } } - const nmapResult = nmap as StitchableResult | string | undefined; - const subfinderResult = subfinder as StitchableResult | string | undefined; - const whatwebResult = whatweb as StitchableResult | string | undefined; - const naabuResult = naabu as StitchableResult | string | undefined; - - const getStatus = (r: StitchableResult | string | undefined): string => { - if (!r) return 'Skipped'; - if (typeof r === 'string') return 'Skipped'; - return r.status || 'Skipped'; - }; - - const getOutput = (r: StitchableResult | string | undefined): string => { - if (!r) return 'No output'; - if (typeof r === 'string') return r; - return r.output || 'No output'; - }; + const nmap = extractResult(wave1.nmap); + const subfinder = extractResult(wave1.subfinder); + const whatweb = extractResult(wave1.whatweb); + const naabu = extractResult(wave1.naabu); const report = ` # Pre-Reconnaissance Report ## Port Discovery (naabu) -Status: ${getStatus(naabuResult)} -${getOutput(naabuResult)} +Status: ${naabu.status} +${naabu.output} ## Network Scanning (nmap) -Status: ${getStatus(nmapResult)} -${getOutput(nmapResult)} +Status: ${nmap.status} +${nmap.output} ## Subdomain Discovery (subfinder) -Status: ${getStatus(subfinderResult)} -${getOutput(subfinderResult)} +Status: ${subfinder.status} +${subfinder.output} ## Technology Detection (whatweb) -Status: ${getStatus(whatwebResult)} -${getOutput(whatwebResult)} +Status: ${whatweb.status} +${whatweb.output} ## Code Analysis ${codeAnalysisContent} ${additionalSection} @@ -375,16 +370,8 @@ export async function executePreReconPhase( console.log(chalk.green(' โœ… Wave 2 operations completed')); console.log(chalk.blue('๐Ÿ“ Stitching pre-recon outputs...')); - // Combine wave 1 and wave 2 results for stitching - const allResults: (StitchableResult | string | undefined)[] = [ - wave1Results.nmap as StitchableResult | string, - wave1Results.subfinder as StitchableResult | string, - wave1Results.whatweb as StitchableResult | string, - wave1Results.naabu as StitchableResult | string | undefined, - wave1Results.codeAnalysis as unknown as StitchableResult, - ...(wave2Results.schemathesis ? [wave2Results.schemathesis as StitchableResult] : []) - ]; - const preReconReport = await stitchPreReconOutputs(allResults, sourceDir); + const additionalScans = wave2Results.schemathesis ? [wave2Results.schemathesis] : []; + const preReconReport = await stitchPreReconOutputs(wave1Results, additionalScans, sourceDir); const duration = timer.stop(); console.log(chalk.green(`โœ… Pre-reconnaissance complete in ${formatDuration(duration)}`)); diff --git a/src/queue-validation.ts b/src/queue-validation.ts index 1f84a1e..ce21e1d 100644 --- a/src/queue-validation.ts +++ b/src/queue-validation.ts @@ -6,6 +6,7 @@ import { fs, path } from 'zx'; import { PentestError } from './error-handling.js'; +import { asyncPipe } from './utils/functional.js'; export type VulnType = 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz'; @@ -16,9 +17,11 @@ interface VulnTypeConfigItem { type VulnTypeConfig = Record; +type ErrorMessageResolver = string | ((existence: FileExistence) => string); + interface ValidationRule { predicate: (existence: FileExistence) => boolean; - errorMessage: string; + errorMessage: ErrorMessageResolver; retryable: boolean; } @@ -94,40 +97,36 @@ const VULN_TYPE_CONFIG: VulnTypeConfig = Object.freeze({ }), }) as VulnTypeConfig; -// Functional composition utilities - async pipe for promise chain -type PipeFunction = (x: any) => any | Promise; - -const pipe = - (...fns: PipeFunction[]) => - (x: any): Promise => - fns.reduce(async (v, f) => f(await v), Promise.resolve(x)); - // Pure function to create validation rule -const createValidationRule = ( +function createValidationRule( predicate: (existence: FileExistence) => boolean, - errorMessage: string, + errorMessage: ErrorMessageResolver, retryable: boolean = true -): ValidationRule => Object.freeze({ predicate, errorMessage, retryable }); +): ValidationRule { + return Object.freeze({ predicate, errorMessage, retryable }); +} -// Validation rules for file existence (following QUEUE_VALIDATION_FLOW.md) +// Symmetric deliverable rules: queue and deliverable must exist together (prevents partial analysis from triggering exploitation) const fileExistenceRules: readonly ValidationRule[] = Object.freeze([ - // Rule 1: Neither deliverable nor queue exists createValidationRule( - ({ deliverableExists, queueExists }) => deliverableExists || queueExists, - 'Analysis failed: Neither deliverable nor queue file exists. Analysis agent must create both files.' - ), - // Rule 2: Queue doesn't exist but deliverable exists - createValidationRule( - ({ deliverableExists, queueExists }) => !(!queueExists && deliverableExists), - 'Analysis incomplete: Deliverable exists but queue file missing. Analysis agent must create both files.' - ), - // Rule 3: Queue exists but deliverable doesn't exist - createValidationRule( - ({ deliverableExists, queueExists }) => !(queueExists && !deliverableExists), - 'Analysis incomplete: Queue exists but deliverable file missing. Analysis agent must create both files.' + ({ deliverableExists, queueExists }) => deliverableExists && queueExists, + getExistenceErrorMessage ), ]); +// Generate appropriate error message based on which files are missing +function getExistenceErrorMessage(existence: FileExistence): string { + const { deliverableExists, queueExists } = existence; + + if (!deliverableExists && !queueExists) { + return 'Analysis failed: Neither deliverable nor queue file exists. Analysis agent must create both files.'; + } + if (!queueExists) { + return 'Analysis incomplete: Deliverable exists but queue file missing. Analysis agent must create both files.'; + } + return 'Analysis incomplete: Queue exists but deliverable file missing. Analysis agent must create both files.'; +} + // Pure function to create file paths const createPaths = ( vulnType: VulnType, @@ -170,7 +169,7 @@ const checkFileExistence = async ( }); }; -// Pure function to validate existence rules +// Validates deliverable/queue symmetry - both must exist or neither const validateExistenceRules = ( pathsWithExistence: PathsWithExistence | PathsWithError ): PathsWithExistence | PathsWithError => { @@ -182,9 +181,14 @@ const validateExistenceRules = ( const failedRule = fileExistenceRules.find((rule) => !rule.predicate(existence)); if (failedRule) { + const message = + typeof failedRule.errorMessage === 'function' + ? failedRule.errorMessage(existence) + : failedRule.errorMessage; + return { error: new PentestError( - `${failedRule.errorMessage} (${vulnType})`, + `${message} (${vulnType})`, 'validation', failedRule.retryable, { @@ -224,7 +228,7 @@ const validateQueueStructure = (content: string): QueueValidationResult => { } }; -// Pure function to read and validate queue content +// Queue parse failures are retryable - agent can fix malformed JSON on retry const validateQueueContent = async ( pathsWithExistence: PathsWithExistence | PathsWithError ): Promise => { @@ -273,7 +277,7 @@ const validateQueueContent = async ( } }; -// Pure function to determine exploitation decision +// Final decision: skip if queue says no vulns, proceed if vulns found, error otherwise const determineExploitationDecision = ( validatedData: PathsWithQueue | PathsWithError ): ExploitationDecision => { @@ -294,17 +298,18 @@ const determineExploitationDecision = ( }; // Main functional validation pipeline -export const validateQueueAndDeliverable = async ( +export async function validateQueueAndDeliverable( vulnType: VulnType, sourceDir: string -): Promise => - (await pipe( - () => createPaths(vulnType, sourceDir), +): Promise { + return asyncPipe( + createPaths(vulnType, sourceDir), checkFileExistence, validateExistenceRules, validateQueueContent, determineExploitationDecision - )(() => createPaths(vulnType, sourceDir))) as ExploitationDecision; + ); +} // Pure function to safely validate (returns result instead of throwing) export const safeValidateQueueAndDeliverable = async ( diff --git a/src/session-manager.ts b/src/session-manager.ts index fbf6d4d..7a31d1a 100644 --- a/src/session-manager.ts +++ b/src/session-manager.ts @@ -106,6 +106,26 @@ export const getParallelGroups = (): Readonly<{ vuln: AgentName[]; exploit: Agen exploit: ['injection-exploit', 'xss-exploit', 'auth-exploit', 'ssrf-exploit', 'authz-exploit'] }); +// Phase names for metrics aggregation +export type PhaseName = 'pre-recon' | 'recon' | 'vulnerability-analysis' | 'exploitation' | 'reporting'; + +// Map agents to their corresponding phases (single source of truth) +export const AGENT_PHASE_MAP: Readonly> = Object.freeze({ + 'pre-recon': 'pre-recon', + 'recon': 'recon', + 'injection-vuln': 'vulnerability-analysis', + 'xss-vuln': 'vulnerability-analysis', + 'auth-vuln': 'vulnerability-analysis', + 'authz-vuln': 'vulnerability-analysis', + 'ssrf-vuln': 'vulnerability-analysis', + 'injection-exploit': 'exploitation', + 'xss-exploit': 'exploitation', + 'auth-exploit': 'exploitation', + 'authz-exploit': 'exploitation', + 'ssrf-exploit': 'exploitation', + 'report': 'reporting', +}); + // Generate a session-based log folder path (used by claude-executor.ts) export const generateSessionLogPath = (webUrl: string, sessionId: string): string => { const hostname = new URL(webUrl).hostname.replace(/[^a-zA-Z0-9-]/g, '-'); diff --git a/src/shannon.ts b/src/shannon.ts index e493047..8eddcdb 100644 --- a/src/shannon.ts +++ b/src/shannon.ts @@ -17,6 +17,7 @@ import { checkToolAvailability, handleMissingTools } from './tool-checker.js'; // Session import { AGENTS, getParallelGroups } from './session-manager.js'; +import { getPromptNameForAgent } from './types/agents.js'; import type { AgentName, PromptName } from './types/index.js'; // Setup and Deliverables @@ -32,7 +33,8 @@ import { assembleFinalReport } from './phases/reporting.js'; // Utils import { timingResults, displayTimingSummary, Timer } from './utils/metrics.js'; -import { formatDuration, generateAuditPath } from './audit/utils.js'; +import { formatDuration } from './utils/formatting.js'; +import { generateAuditPath } from './audit/utils.js'; import type { SessionMetadata } from './audit/utils.js'; import { AuditSession } from './audit/audit-session.js'; @@ -86,6 +88,7 @@ async function saveSessions(store: SessionStore): Promise { await fs.writeJson(STORE_PATH, store, { spaces: 2 }); } +// Session prevents concurrent runs on same repo - different repos can run in parallel async function createSession(webUrl: string, repoPath: string): Promise { const store = await loadSessions(); @@ -155,32 +158,26 @@ interface ParallelAgentResult { error?: string | undefined; } +type VulnType = 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz'; + +interface ParallelAgentConfig { + phaseType: 'vuln' | 'exploit'; + headerText: string; + specialistLabel: string; +} + +interface AgentExecutionContext { + sourceDir: string; + variables: PromptVariables; + distributedConfig: DistributedConfig | null; + pipelineTestingMode: boolean; + sessionMetadata: SessionMetadata; +} + // Configure zx to disable timeouts (let tools run as long as needed) $.timeout = 0; -// Helper function to get prompt name from agent name -const getPromptName = (agentName: AgentName): PromptName => { - const mappings: Record = { - 'pre-recon': 'pre-recon-code', - 'recon': 'recon', - 'injection-vuln': 'vuln-injection', - 'xss-vuln': 'vuln-xss', - 'auth-vuln': 'vuln-auth', - 'ssrf-vuln': 'vuln-ssrf', - 'authz-vuln': 'vuln-authz', - 'injection-exploit': 'exploit-injection', - 'xss-exploit': 'exploit-xss', - 'auth-exploit': 'exploit-auth', - 'ssrf-exploit': 'exploit-ssrf', - 'authz-exploit': 'exploit-authz', - 'report': 'report-executive' - }; - - return mappings[agentName] || agentName as PromptName; -}; - -// Get color function for agent -const getAgentColor = (agentName: AgentName): ChalkInstance => { +function getAgentColor(agentName: AgentName): ChalkInstance { const colorMap: Partial> = { 'injection-vuln': chalk.red, 'injection-exploit': chalk.red, @@ -194,11 +191,9 @@ const getAgentColor = (agentName: AgentName): ChalkInstance => { 'authz-exploit': chalk.green }; return colorMap[agentName] || chalk.cyan; -}; +} -/** - * Consolidate deliverables from target repo into the session folder - */ +// Non-fatal copy - failure logs warning but doesn't halt pipeline async function consolidateOutputs(sourceDir: string, sessionPath: string): Promise { const srcDeliverables = path.join(sourceDir, 'deliverables'); const destDeliverables = path.join(sessionPath, 'deliverables'); @@ -228,7 +223,7 @@ async function runAgent( sessionMetadata: SessionMetadata ): Promise { const agent = AGENTS[agentName]; - const promptName = getPromptName(agentName); + const promptName = getPromptNameForAgent(agentName); const prompt = await loadPrompt(promptName, variables, distributedConfig, pipelineTestingMode); return await runClaudePromptWithRetry( @@ -244,85 +239,68 @@ async function runAgent( } /** - * Run vulnerability agents in parallel + * Execute a single agent with retry logic */ -async function runParallelVuln( - sourceDir: string, - variables: PromptVariables, - distributedConfig: DistributedConfig | null, - pipelineTestingMode: boolean, - sessionMetadata: SessionMetadata -): Promise { - const { vuln: vulnAgents } = getParallelGroups(); +async function executeAgentWithRetry( + agentName: AgentName, + context: AgentExecutionContext, + onSuccess?: (agentName: AgentName) => Promise +): Promise { + const { sourceDir, variables, distributedConfig, pipelineTestingMode, sessionMetadata } = context; + const maxAttempts = 3; + let lastError: Error | undefined; + let attempts = 0; - console.log(chalk.cyan(`\nStarting ${vulnAgents.length} vulnerability analysis specialists in parallel...`)); - console.log(chalk.gray(' Specialists: ' + vulnAgents.join(', '))); - console.log(); + while (attempts < maxAttempts) { + attempts++; + try { + const result = await runAgent( + agentName, + sourceDir, + variables, + distributedConfig, + pipelineTestingMode, + sessionMetadata + ); - const startTime = Date.now(); - - const results = await Promise.allSettled( - vulnAgents.map(async (agentName, index) => { - // Add 2-second stagger to prevent API overwhelm - await new Promise(resolve => setTimeout(resolve, index * 2000)); - - let lastError: Error | undefined; - let attempts = 0; - const maxAttempts = 3; - - while (attempts < maxAttempts) { - attempts++; - try { - const result = await runAgent( - agentName, - sourceDir, - variables, - distributedConfig, - pipelineTestingMode, - sessionMetadata - ); - - // Validate vulnerability analysis results - const vulnType = agentName.replace('-vuln', ''); - try { - const validation = await safeValidateQueueAndDeliverable(vulnType as 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz', sourceDir); - - if (validation.success && validation.data) { - console.log(chalk.blue(`${agentName}: ${validation.data.shouldExploit ? `Ready for exploitation (${validation.data.vulnerabilityCount} vulnerabilities)` : 'No vulnerabilities found'}`)); - } - } catch { - // Validation failure is non-critical - } - - return { - agentName, - success: result.success, - timing: result.duration, - cost: result.cost, - attempts - }; - } catch (error) { - lastError = error as Error; - if (attempts < maxAttempts) { - console.log(chalk.yellow(`Warning: ${agentName} failed attempt ${attempts}/${maxAttempts}, retrying...`)); - await new Promise(resolve => setTimeout(resolve, 5000)); - } - } + if (onSuccess) { + await onSuccess(agentName); } return { agentName, - success: false, - attempts, - error: lastError?.message || 'Unknown error' + success: result.success, + timing: result.duration, + cost: result.cost, + attempts }; - }) - ); + } catch (error) { + lastError = error as Error; + if (attempts < maxAttempts) { + console.log(chalk.yellow(`Warning: ${agentName} failed attempt ${attempts}/${maxAttempts}, retrying...`)); + await new Promise(resolve => setTimeout(resolve, 5000)); + } + } + } - const totalDuration = Date.now() - startTime; + return { + agentName, + success: false, + attempts, + error: lastError?.message || 'Unknown error' + }; +} - // Process and display results - console.log(chalk.cyan('\nVulnerability Analysis Results')); +/** + * Display results table for parallel agent execution + */ +function displayParallelResults( + results: PromiseSettledResult[], + agents: AgentName[], + headerText: string, + totalDuration: number +): ParallelAgentResult[] { + console.log(chalk.cyan(`\n${headerText}`)); console.log(chalk.gray('-'.repeat(80))); console.log(chalk.bold('Agent Status Attempt Duration Cost')); console.log(chalk.gray('-'.repeat(80))); @@ -330,7 +308,7 @@ async function runParallelVuln( const processedResults: ParallelAgentResult[] = []; results.forEach((result, index) => { - const agentName = vulnAgents[index]!; + const agentName = agents[index]!; const agentDisplay = agentName.padEnd(22); if (result.status === 'fulfilled') { @@ -371,159 +349,90 @@ async function runParallelVuln( console.log(chalk.gray('-'.repeat(80))); const successCount = processedResults.filter(r => r.success).length; - console.log(chalk.cyan(`Summary: ${successCount}/${vulnAgents.length} succeeded in ${formatDuration(totalDuration)}`)); + console.log(chalk.cyan(`Summary: ${successCount}/${agents.length} succeeded in ${formatDuration(totalDuration)}`)); return processedResults; } /** - * Run exploitation agents in parallel + * Run agents in parallel with retry logic and result display */ -async function runParallelExploit( - sourceDir: string, - variables: PromptVariables, - distributedConfig: DistributedConfig | null, - pipelineTestingMode: boolean, - sessionMetadata: SessionMetadata +async function runParallelAgents( + context: AgentExecutionContext, + config: ParallelAgentConfig ): Promise { - const { exploit: exploitAgents, vuln: vulnAgents } = getParallelGroups(); + const { sourceDir } = context; + const { phaseType, headerText, specialistLabel } = config; + const parallelGroups = getParallelGroups(); + const allAgents = parallelGroups[phaseType]; - // Load validation module - const { safeValidateQueueAndDeliverable } = await import('./queue-validation.js'); + // For exploit phase, filter to only eligible agents + let agents: AgentName[]; + if (phaseType === 'exploit') { + const eligibilityChecks = await Promise.all( + allAgents.map(async (agentName) => { + const vulnAgentName = agentName.replace('-exploit', '-vuln') as AgentName; + const vulnType = vulnAgentName.replace('-vuln', '') as VulnType; - // Check eligibility - const eligibilityChecks = await Promise.all( - exploitAgents.map(async (agentName) => { - const vulnAgentName = agentName.replace('-exploit', '-vuln') as AgentName; - const vulnType = vulnAgentName.replace('-vuln', '') as 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz'; + const validation = await safeValidateQueueAndDeliverable(vulnType, sourceDir); - const validation = await safeValidateQueueAndDeliverable(vulnType, sourceDir); + if (!validation.success || !validation.data?.shouldExploit) { + console.log(chalk.gray(`Skipping ${agentName} (no vulnerabilities found in ${vulnAgentName})`)); + return { agentName, eligible: false }; + } - if (!validation.success || !validation.data?.shouldExploit) { - console.log(chalk.gray(`Skipping ${agentName} (no vulnerabilities found in ${vulnAgentName})`)); - return { agentName, eligible: false }; - } + console.log(chalk.blue(`${agentName} eligible (${validation.data.vulnerabilityCount} vulnerabilities from ${vulnAgentName})`)); + return { agentName, eligible: true }; + }) + ); - console.log(chalk.blue(`${agentName} eligible (${validation.data.vulnerabilityCount} vulnerabilities from ${vulnAgentName})`)); - return { agentName, eligible: true }; - }) - ); + agents = eligibilityChecks + .filter(check => check.eligible) + .map(check => check.agentName); - const eligibleAgents = eligibilityChecks - .filter(check => check.eligible) - .map(check => check.agentName); - - if (eligibleAgents.length === 0) { - console.log(chalk.gray('No exploitation agents eligible (no vulnerabilities found)')); - return []; + if (agents.length === 0) { + console.log(chalk.gray('No exploitation agents eligible (no vulnerabilities found)')); + return []; + } + } else { + agents = allAgents; } - console.log(chalk.cyan(`\nStarting ${eligibleAgents.length} exploitation specialists in parallel...`)); - console.log(chalk.gray(' Specialists: ' + eligibleAgents.join(', '))); + console.log(chalk.cyan(`\nStarting ${agents.length} ${specialistLabel} in parallel...`)); + console.log(chalk.gray(' Specialists: ' + agents.join(', '))); console.log(); const startTime = Date.now(); - const results = await Promise.allSettled( - eligibleAgents.map(async (agentName, index) => { - await new Promise(resolve => setTimeout(resolve, index * 2000)); - - let lastError: Error | undefined; - let attempts = 0; - const maxAttempts = 3; - - while (attempts < maxAttempts) { - attempts++; + // Build onSuccess callback for vuln phase (validation logging) + const onSuccess = phaseType === 'vuln' + ? async (agentName: AgentName): Promise => { + const vulnType = agentName.replace('-vuln', '') as VulnType; try { - const result = await runAgent( - agentName, - sourceDir, - variables, - distributedConfig, - pipelineTestingMode, - sessionMetadata - ); - - return { - agentName, - success: result.success, - timing: result.duration, - cost: result.cost, - attempts - }; - } catch (error) { - lastError = error as Error; - if (attempts < maxAttempts) { - console.log(chalk.yellow(`Warning: ${agentName} failed attempt ${attempts}/${maxAttempts}, retrying...`)); - await new Promise(resolve => setTimeout(resolve, 5000)); + const validation = await safeValidateQueueAndDeliverable(vulnType, sourceDir); + if (validation.success && validation.data) { + const message = validation.data.shouldExploit + ? `Ready for exploitation (${validation.data.vulnerabilityCount} vulnerabilities)` + : 'No vulnerabilities found'; + console.log(chalk.blue(`${agentName}: ${message}`)); } + } catch { + // Validation failure is non-critical } } + : undefined; - return { - agentName, - success: false, - attempts, - error: lastError?.message || 'Unknown error' - }; + const results = await Promise.allSettled( + agents.map(async (agentName, index) => { + // Add 2-second stagger to prevent API overwhelm + await new Promise(resolve => setTimeout(resolve, index * 2000)); + return executeAgentWithRetry(agentName, context, onSuccess); }) ); const totalDuration = Date.now() - startTime; - // Process and display results - console.log(chalk.cyan('\nExploitation Results')); - console.log(chalk.gray('-'.repeat(80))); - console.log(chalk.bold('Agent Status Attempt Duration Cost')); - console.log(chalk.gray('-'.repeat(80))); - - const processedResults: ParallelAgentResult[] = []; - - results.forEach((result, index) => { - const agentName = eligibleAgents[index]!; - const agentDisplay = agentName.padEnd(22); - - if (result.status === 'fulfilled') { - const data = result.value; - processedResults.push(data); - - if (data.success) { - const duration = formatDuration(data.timing || 0); - const cost = `$${(data.cost || 0).toFixed(4)}`; - - console.log( - `${chalk.green(agentDisplay)} ${chalk.green('Success')} ` + - `${data.attempts}/3 ${duration.padEnd(11)} ${cost}` - ); - } else { - console.log( - `${chalk.red(agentDisplay)} ${chalk.red('Failed ')} ` + - `${data.attempts}/3 - -` - ); - if (data.error) { - console.log(chalk.gray(` Error: ${data.error.substring(0, 60)}...`)); - } - } - } else { - processedResults.push({ - agentName, - success: false, - attempts: 3, - error: String(result.reason) - }); - - console.log( - `${chalk.red(agentDisplay)} ${chalk.red('Failed ')} ` + - `3/3 - -` - ); - } - }); - - console.log(chalk.gray('-'.repeat(80))); - const successCount = processedResults.filter(r => r.success).length; - console.log(chalk.cyan(`Summary: ${successCount}/${eligibleAgents.length} succeeded in ${formatDuration(totalDuration)}`)); - - return processedResults; + return displayParallelResults(results, agents, headerText, totalDuration); } // Setup graceful cleanup on process signals @@ -677,13 +586,19 @@ async function main( const vulnTimer = new Timer('phase-3-vulnerability-analysis'); console.log(chalk.red.bold('\n๐Ÿšจ PHASE 3: VULNERABILITY ANALYSIS')); - const vulnResults = await runParallelVuln( + const executionContext: AgentExecutionContext = { sourceDir, variables, distributedConfig, pipelineTestingMode, sessionMetadata - ); + }; + + const vulnResults = await runParallelAgents(executionContext, { + phaseType: 'vuln', + headerText: 'Vulnerability Analysis Results', + specialistLabel: 'vulnerability analysis specialists' + }); const vulnDuration = vulnTimer.stop(); console.log(chalk.green(`โœ… Vulnerability analysis phase complete in ${formatDuration(vulnDuration)}`)); @@ -692,13 +607,11 @@ async function main( const exploitTimer = new Timer('phase-4-exploitation'); console.log(chalk.red.bold('\n๐Ÿ’ฅ PHASE 4: EXPLOITATION')); - const exploitResults = await runParallelExploit( - sourceDir, - variables, - distributedConfig, - pipelineTestingMode, - sessionMetadata - ); + const exploitResults = await runParallelAgents(executionContext, { + phaseType: 'exploit', + headerText: 'Exploitation Results', + specialistLabel: 'exploitation specialists' + }); const exploitDuration = exploitTimer.stop(); console.log(chalk.green(`โœ… Exploitation phase complete in ${formatDuration(exploitDuration)}`)); diff --git a/src/types/agents.ts b/src/types/agents.ts index d358095..a47256f 100644 --- a/src/types/agents.ts +++ b/src/types/agents.ts @@ -47,10 +47,6 @@ export type PlaywrightAgent = export type AgentValidator = (sourceDir: string) => Promise; -export type AgentValidatorMap = Record; - -export type McpAgentMapping = Record; - export type AgentStatus = | 'pending' | 'in_progress' @@ -63,3 +59,26 @@ export interface AgentDefinition { displayName: string; prerequisites: AgentName[]; } + +/** + * Maps an agent name to its corresponding prompt file name. + */ +export function getPromptNameForAgent(agentName: AgentName): PromptName { + const mappings: Record = { + 'pre-recon': 'pre-recon-code', + 'recon': 'recon', + 'injection-vuln': 'vuln-injection', + 'xss-vuln': 'vuln-xss', + 'auth-vuln': 'vuln-auth', + 'ssrf-vuln': 'vuln-ssrf', + 'authz-vuln': 'vuln-authz', + 'injection-exploit': 'exploit-injection', + 'xss-exploit': 'exploit-xss', + 'auth-exploit': 'exploit-auth', + 'ssrf-exploit': 'exploit-ssrf', + 'authz-exploit': 'exploit-authz', + 'report': 'report-executive', + }; + + return mappings[agentName]; +} diff --git a/src/utils/concurrency.ts b/src/utils/concurrency.ts index e10de45..1edf03b 100644 --- a/src/utils/concurrency.ts +++ b/src/utils/concurrency.ts @@ -31,13 +31,12 @@ type UnlockFunction = () => void; * } * ``` */ +// Promise-based mutex with queue semantics - safe for parallel agents on same session export class SessionMutex { // Map of sessionId -> Promise (represents active lock) private locks: Map> = new Map(); - /** - * Acquire lock for a session - */ + // Wait for existing lock, then acquire. Queue ensures FIFO ordering. async lock(sessionId: string): Promise { if (this.locks.has(sessionId)) { // Wait for existing lock to be released diff --git a/src/utils/file-io.ts b/src/utils/file-io.ts new file mode 100644 index 0000000..0f35c83 --- /dev/null +++ b/src/utils/file-io.ts @@ -0,0 +1,73 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +/** + * File I/O Utilities + * + * Core utility functions for file operations including atomic writes, + * directory creation, and JSON file handling. + */ + +import fs from 'fs/promises'; + +/** + * Ensure directory exists (idempotent, race-safe) + */ +export async function ensureDirectory(dirPath: string): Promise { + try { + await fs.mkdir(dirPath, { recursive: true }); + } catch (error) { + // Ignore EEXIST errors (race condition safe) + if ((error as NodeJS.ErrnoException).code !== 'EEXIST') { + throw error; + } + } +} + +/** + * Atomic write using temp file + rename pattern + * Guarantees no partial writes or corruption on crash + */ +export async function atomicWrite(filePath: string, data: object | string): Promise { + const tempPath = `${filePath}.tmp`; + const content = typeof data === 'string' ? data : JSON.stringify(data, null, 2); + + try { + // Write to temp file + await fs.writeFile(tempPath, content, 'utf8'); + + // Atomic rename (POSIX guarantee: atomic on same filesystem) + await fs.rename(tempPath, filePath); + } catch (error) { + // Clean up temp file on failure + try { + await fs.unlink(tempPath); + } catch { + // Ignore cleanup errors + } + throw error; + } +} + +/** + * Read and parse JSON file + */ +export async function readJson(filePath: string): Promise { + const content = await fs.readFile(filePath, 'utf8'); + return JSON.parse(content) as T; +} + +/** + * Check if file exists + */ +export async function fileExists(filePath: string): Promise { + try { + await fs.access(filePath); + return true; + } catch { + return false; + } +} diff --git a/src/utils/formatting.ts b/src/utils/formatting.ts new file mode 100644 index 0000000..3f60d20 --- /dev/null +++ b/src/utils/formatting.ts @@ -0,0 +1,60 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +/** + * Formatting Utilities + * + * Generic formatting functions for durations, timestamps, and percentages. + */ + +/** + * Format duration in milliseconds to human-readable string + */ +export function formatDuration(ms: number): string { + if (ms < 1000) { + return `${ms}ms`; + } + + const seconds = ms / 1000; + if (seconds < 60) { + return `${seconds.toFixed(1)}s`; + } + + const minutes = Math.floor(seconds / 60); + const remainingSeconds = Math.floor(seconds % 60); + return `${minutes}m ${remainingSeconds}s`; +} + +/** + * Format timestamp to ISO 8601 string + */ +export function formatTimestamp(timestamp: number = Date.now()): string { + return new Date(timestamp).toISOString(); +} + +/** + * Calculate percentage + */ +export function calculatePercentage(part: number, total: number): number { + if (total === 0) return 0; + return (part / total) * 100; +} + +/** + * Extract agent type from description string for display purposes + */ +export function extractAgentType(description: string): string { + if (description.includes('Pre-recon')) { + return 'pre-reconnaissance'; + } + if (description.includes('Recon')) { + return 'reconnaissance'; + } + if (description.includes('Report')) { + return 'report generation'; + } + return 'analysis'; +} diff --git a/src/utils/functional.ts b/src/utils/functional.ts new file mode 100644 index 0000000..ee1dac7 --- /dev/null +++ b/src/utils/functional.ts @@ -0,0 +1,29 @@ +// Copyright (C) 2025 Keygraph, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License version 3 +// as published by the Free Software Foundation. + +/** + * Functional Programming Utilities + * + * Generic functional composition patterns for async operations. + */ + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type PipelineFunction = (x: any) => any | Promise; + +/** + * Async pipeline that passes result through a series of functions. + * Clearer than reduce-based pipe and easier to debug. + */ +export async function asyncPipe( + initial: unknown, + ...fns: PipelineFunction[] +): Promise { + let result = initial; + for (const fn of fns) { + result = await fn(result); + } + return result as TResult; +} diff --git a/src/utils/git-manager.ts b/src/utils/git-manager.ts index b48ad96..969e811 100644 --- a/src/utils/git-manager.ts +++ b/src/utils/git-manager.ts @@ -13,7 +13,57 @@ interface GitOperationResult { error?: Error; } -// Global git operations semaphore to prevent index.lock conflicts during parallel execution +/** + * Get list of changed files from git status --porcelain output + */ +async function getChangedFiles( + sourceDir: string, + operationDescription: string +): Promise { + const status = await executeGitCommandWithRetry( + ['git', 'status', '--porcelain'], + sourceDir, + operationDescription + ); + return status.stdout + .trim() + .split('\n') + .filter((line) => line.length > 0); +} + +/** + * Log a summary of changed files with truncation for long lists + */ +function logChangeSummary( + changes: string[], + messageWithChanges: string, + messageWithoutChanges: string, + color: typeof chalk.green, + maxToShow: number = 5 +): void { + if (changes.length > 0) { + console.log(color(messageWithChanges.replace('{count}', String(changes.length)))); + changes.slice(0, maxToShow).forEach((change) => console.log(chalk.gray(` ${change}`))); + if (changes.length > maxToShow) { + console.log(chalk.gray(` ... and ${changes.length - maxToShow} more files`)); + } + } else { + console.log(color(messageWithoutChanges)); + } +} + +/** + * Convert unknown error to GitOperationResult + */ +function toErrorResult(error: unknown): GitOperationResult { + const errMsg = error instanceof Error ? error.message : String(error); + return { + success: false, + error: error instanceof Error ? error : new Error(errMsg), + }; +} + +// Serializes git operations to prevent index.lock conflicts during parallel agent execution class GitSemaphore { private queue: Array<() => void> = []; private running: boolean = false; @@ -41,33 +91,38 @@ class GitSemaphore { const gitSemaphore = new GitSemaphore(); -// Execute git commands with retry logic for index.lock conflicts -export const executeGitCommandWithRetry = async ( +const GIT_LOCK_ERROR_PATTERNS = [ + 'index.lock', + 'unable to lock', + 'Another git process', + 'fatal: Unable to create', + 'fatal: index file', +]; + +function isGitLockError(errorMessage: string): boolean { + return GIT_LOCK_ERROR_PATTERNS.some((pattern) => errorMessage.includes(pattern)); +} + +// Retries git commands on lock conflicts with exponential backoff +export async function executeGitCommandWithRetry( commandArgs: string[], sourceDir: string, description: string, maxRetries: number = 5 -): Promise<{ stdout: string; stderr: string }> => { +): Promise<{ stdout: string; stderr: string }> { await gitSemaphore.acquire(); try { for (let attempt = 1; attempt <= maxRetries; attempt++) { try { - // For arrays like ['git', 'status', '--porcelain'], execute parts separately const [cmd, ...args] = commandArgs; const result = await $`cd ${sourceDir} && ${cmd} ${args}`; return result; } catch (error) { const errMsg = error instanceof Error ? error.message : String(error); - const isLockError = - errMsg.includes('index.lock') || - errMsg.includes('unable to lock') || - errMsg.includes('Another git process') || - errMsg.includes('fatal: Unable to create') || - errMsg.includes('fatal: index file'); - if (isLockError && attempt < maxRetries) { - const delay = Math.pow(2, attempt - 1) * 1000; // Exponential backoff: 1s, 2s, 4s, 8s, 16s + if (isGitLockError(errMsg) && attempt < maxRetries) { + const delay = Math.pow(2, attempt - 1) * 1000; console.log( chalk.yellow( ` โš ๏ธ Git lock conflict during ${description} (attempt ${attempt}/${maxRetries}). Retrying in ${delay}ms...` @@ -80,84 +135,69 @@ export const executeGitCommandWithRetry = async ( throw error; } } - // Should never reach here but TypeScript needs a return throw new Error(`Git command failed after ${maxRetries} retries`); } finally { gitSemaphore.release(); } -}; +} -// Pure functions for Git workspace management -const cleanWorkspace = async ( +// Two-phase reset: hard reset (tracked files) + clean (untracked files) +export async function rollbackGitWorkspace( sourceDir: string, - reason: string = 'clean start' -): Promise => { - console.log(chalk.blue(` ๐Ÿงน Cleaning workspace for ${reason}`)); + reason: string = 'retry preparation' +): Promise { + console.log(chalk.yellow(` ๐Ÿ”„ Rolling back workspace for ${reason}`)); try { - // Check for uncommitted changes - const status = await $`cd ${sourceDir} && git status --porcelain`; - const hasChanges = status.stdout.trim().length > 0; + const changes = await getChangedFiles(sourceDir, 'status check for rollback'); - if (hasChanges) { - // Show what we're about to remove - const changes = status.stdout - .trim() - .split('\n') - .filter((line) => line.length > 0); - console.log(chalk.yellow(` ๐Ÿ”„ Rolling back workspace for ${reason}`)); + await executeGitCommandWithRetry( + ['git', 'reset', '--hard', 'HEAD'], + sourceDir, + 'hard reset for rollback' + ); + await executeGitCommandWithRetry( + ['git', 'clean', '-fd'], + sourceDir, + 'cleaning untracked files for rollback' + ); - await $`cd ${sourceDir} && git reset --hard HEAD`; - await $`cd ${sourceDir} && git clean -fd`; - - console.log( - chalk.yellow(` โœ… Rollback completed - removed ${changes.length} contaminated changes:`) - ); - changes.slice(0, 3).forEach((change) => console.log(chalk.gray(` ${change}`))); - if (changes.length > 3) { - console.log(chalk.gray(` ... and ${changes.length - 3} more files`)); - } - } else { - console.log(chalk.blue(` โœ… Workspace already clean (no changes to remove)`)); - } - return { success: true, hadChanges: hasChanges }; + logChangeSummary( + changes, + ' โœ… Rollback completed - removed {count} contaminated changes:', + ' โœ… Rollback completed - no changes to remove', + chalk.yellow, + 3 + ); + return { success: true }; } catch (error) { - const errMsg = error instanceof Error ? error.message : String(error); - console.log(chalk.yellow(` โš ๏ธ Workspace cleanup failed: ${errMsg}`)); - return { success: false, error: error instanceof Error ? error : new Error(errMsg) }; + const result = toErrorResult(error); + console.log(chalk.red(` โŒ Rollback failed after retries: ${result.error?.message}`)); + return result; } -}; +} -export const createGitCheckpoint = async ( +// Creates checkpoint before each attempt. First attempt preserves workspace; retries clean it. +export async function createGitCheckpoint( sourceDir: string, description: string, attempt: number -): Promise => { +): Promise { console.log(chalk.blue(` ๐Ÿ“ Creating checkpoint for ${description} (attempt ${attempt})`)); try { - // Only clean workspace on retry attempts (attempt > 1), not on first attempts - // This preserves deliverables between agents while still cleaning on actual retries + // First attempt: preserve existing deliverables. Retries: clean workspace to prevent pollution if (attempt > 1) { - const cleanResult = await cleanWorkspace(sourceDir, `${description} (retry cleanup)`); + const cleanResult = await rollbackGitWorkspace(sourceDir, `${description} (retry cleanup)`); if (!cleanResult.success) { - const errMsg = cleanResult.error?.message || 'Unknown error'; console.log( - chalk.yellow(` โš ๏ธ Workspace cleanup failed, continuing anyway: ${errMsg}`) + chalk.yellow(` โš ๏ธ Workspace cleanup failed, continuing anyway: ${cleanResult.error?.message}`) ); } } - // Check for uncommitted changes with retry logic - const status = await executeGitCommandWithRetry( - ['git', 'status', '--porcelain'], - sourceDir, - 'status check' - ); - const hasChanges = status.stdout.trim().length > 0; + const changes = await getChangedFiles(sourceDir, 'status check'); + const hasChanges = changes.length > 0; - // Stage changes with retry logic await executeGitCommandWithRetry(['git', 'add', '-A'], sourceDir, 'staging changes'); - - // Create commit with retry logic await executeGitCommandWithRetry( ['git', 'commit', '-m', `๐Ÿ“ Checkpoint: ${description} (attempt ${attempt})`, '--allow-empty'], sourceDir, @@ -171,106 +211,54 @@ export const createGitCheckpoint = async ( } return { success: true }; } catch (error) { - const errMsg = error instanceof Error ? error.message : String(error); - console.log(chalk.yellow(` โš ๏ธ Checkpoint creation failed after retries: ${errMsg}`)); - return { success: false, error: error instanceof Error ? error : new Error(errMsg) }; + const result = toErrorResult(error); + console.log(chalk.yellow(` โš ๏ธ Checkpoint creation failed after retries: ${result.error?.message}`)); + return result; } -}; +} -export const commitGitSuccess = async ( +export async function commitGitSuccess( sourceDir: string, description: string -): Promise => { +): Promise { console.log(chalk.green(` ๐Ÿ’พ Committing successful results for ${description}`)); try { - // Check what we're about to commit with retry logic - const status = await executeGitCommandWithRetry( - ['git', 'status', '--porcelain'], - sourceDir, - 'status check for success commit' - ); - const changes = status.stdout - .trim() - .split('\n') - .filter((line) => line.length > 0); + const changes = await getChangedFiles(sourceDir, 'status check for success commit'); - // Stage changes with retry logic await executeGitCommandWithRetry( ['git', 'add', '-A'], sourceDir, 'staging changes for success commit' ); - - // Create success commit with retry logic await executeGitCommandWithRetry( ['git', 'commit', '-m', `โœ… ${description}: completed successfully`, '--allow-empty'], sourceDir, 'creating success commit' ); - if (changes.length > 0) { - console.log(chalk.green(` โœ… Success commit created with ${changes.length} file changes:`)); - changes.slice(0, 5).forEach((change) => console.log(chalk.gray(` ${change}`))); - if (changes.length > 5) { - console.log(chalk.gray(` ... and ${changes.length - 5} more files`)); - } - } else { - console.log(chalk.green(` โœ… Empty success commit created (agent made no file changes)`)); - } + logChangeSummary( + changes, + ' โœ… Success commit created with {count} file changes:', + ' โœ… Empty success commit created (agent made no file changes)', + chalk.green, + 5 + ); return { success: true }; } catch (error) { - const errMsg = error instanceof Error ? error.message : String(error); - console.log(chalk.yellow(` โš ๏ธ Success commit failed after retries: ${errMsg}`)); - return { success: false, error: error instanceof Error ? error : new Error(errMsg) }; + const result = toErrorResult(error); + console.log(chalk.yellow(` โš ๏ธ Success commit failed after retries: ${result.error?.message}`)); + return result; } -}; +} -export const rollbackGitWorkspace = async ( - sourceDir: string, - reason: string = 'retry preparation' -): Promise => { - console.log(chalk.yellow(` ๐Ÿ”„ Rolling back workspace for ${reason}`)); +/** + * Get current git commit hash + */ +export async function getGitCommitHash(sourceDir: string): Promise { try { - // Show what we're about to remove with retry logic - const status = await executeGitCommandWithRetry( - ['git', 'status', '--porcelain'], - sourceDir, - 'status check for rollback' - ); - const changes = status.stdout - .trim() - .split('\n') - .filter((line) => line.length > 0); - - // Reset to HEAD with retry logic - await executeGitCommandWithRetry( - ['git', 'reset', '--hard', 'HEAD'], - sourceDir, - 'hard reset for rollback' - ); - - // Clean untracked files with retry logic - await executeGitCommandWithRetry( - ['git', 'clean', '-fd'], - sourceDir, - 'cleaning untracked files for rollback' - ); - - if (changes.length > 0) { - console.log( - chalk.yellow(` โœ… Rollback completed - removed ${changes.length} contaminated changes:`) - ); - changes.slice(0, 3).forEach((change) => console.log(chalk.gray(` ${change}`))); - if (changes.length > 3) { - console.log(chalk.gray(` ... and ${changes.length - 3} more files`)); - } - } else { - console.log(chalk.yellow(` โœ… Rollback completed - no changes to remove`)); - } - return { success: true }; - } catch (error) { - const errMsg = error instanceof Error ? error.message : String(error); - console.log(chalk.red(` โŒ Rollback failed after retries: ${errMsg}`)); - return { success: false, error: error instanceof Error ? error : new Error(errMsg) }; + const result = await $`cd ${sourceDir} && git rev-parse HEAD`; + return result.stdout.trim(); + } catch { + return null; } -}; +} diff --git a/src/utils/metrics.ts b/src/utils/metrics.ts index 93ec456..01cf79c 100644 --- a/src/utils/metrics.ts +++ b/src/utils/metrics.ts @@ -5,7 +5,7 @@ // as published by the Free Software Foundation. import chalk from 'chalk'; -import { formatDuration } from '../audit/utils.js'; +import { formatDuration } from './formatting.js'; // Timing utilities