refactor: modularize claude-executor and extract shared utilities

- Extract message handling into src/ai/message-handlers.ts with pure functions
- Extract output formatting into src/ai/output-formatters.ts
- Extract progress management into src/ai/progress-manager.ts
- Add audit-logger.ts with Null Object pattern for optional logging
- Add shared utilities: formatting.ts, file-io.ts, functional.ts
- Consolidate getPromptNameForAgent into src/types/agents.ts
This commit is contained in:
ajmallesh
2026-01-12 12:14:49 -08:00
parent bc52d67dd5
commit f84414d5ca
21 changed files with 1636 additions and 1107 deletions

79
src/ai/audit-logger.ts Normal file
View File

@@ -0,0 +1,79 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Null Object pattern for audit logging - callers never check for null
import type { AuditSession } from '../audit/index.js';
import { formatTimestamp } from '../utils/formatting.js';
/**
 * Logging contract used by the executor for audit events.
 *
 * Two implementations live in this file: RealAuditLogger forwards every call
 * to an AuditSession, NullAuditLogger ignores every call. Obtain an instance
 * through createAuditLogger so callers never have to null-check a session.
 */
export interface AuditLogger {
/** Record the text content of one LLM response together with its turn number. */
logLlmResponse(turn: number, content: string): Promise<void>;
/** Record that a tool invocation started, with the tool's name and its input parameters. */
logToolStart(toolName: string, parameters: unknown): Promise<void>;
/** Record the result payload of a finished tool invocation. */
logToolEnd(result: unknown): Promise<void>;
/**
 * Record a failed run: the error itself, the elapsed duration as measured by
 * the caller (unit is not fixed here - confirm against the caller's timer),
 * and the number of turns completed before the failure.
 */
logError(error: Error, duration: number, turns: number): Promise<void>;
}
/**
 * AuditLogger backed by a live AuditSession.
 *
 * Each method forwards exactly one event to `auditSession.logEvent`. The
 * payload carries the method's own fields followed by a `timestamp` field
 * obtained from `formatTimestamp()` at call time.
 */
class RealAuditLogger implements AuditLogger {
  constructor(private readonly auditSession: AuditSession) {}

  /** Emit an `llm_response` event for the given turn. */
  async logLlmResponse(turn: number, content: string): Promise<void> {
    await this.auditSession.logEvent('llm_response', this.stamp({ turn, content }));
  }

  /** Emit a `tool_start` event with the tool name and its input parameters. */
  async logToolStart(toolName: string, parameters: unknown): Promise<void> {
    await this.auditSession.logEvent('tool_start', this.stamp({ toolName, parameters }));
  }

  /** Emit a `tool_end` event with the tool's result payload. */
  async logToolEnd(result: unknown): Promise<void> {
    await this.auditSession.logEvent('tool_end', this.stamp({ result }));
  }

  /** Emit an `error` event describing the failure plus run statistics. */
  async logError(error: Error, duration: number, turns: number): Promise<void> {
    const details = {
      message: error.message,
      errorType: error.constructor.name,
      stack: error.stack,
      duration,
      turns,
    };
    await this.auditSession.logEvent('error', this.stamp(details));
  }

  /** Copy `fields` and append the current timestamp as the last key. */
  private stamp<T extends object>(fields: T): T & { timestamp: string } {
    return { ...fields, timestamp: formatTimestamp() };
  }
}
/**
 * Null Object implementation of AuditLogger.
 *
 * Every method ignores its arguments and returns an already-resolved promise,
 * so callers can log unconditionally when no audit session exists.
 */
class NullAuditLogger implements AuditLogger {
  logLlmResponse(_turn: number, _content: string): Promise<void> {
    return Promise.resolve();
  }

  logToolStart(_toolName: string, _parameters: unknown): Promise<void> {
    return Promise.resolve();
  }

  logToolEnd(_result: unknown): Promise<void> {
    return Promise.resolve();
  }

  logError(_error: Error, _duration: number, _turns: number): Promise<void> {
    return Promise.resolve();
  }
}
/**
 * Factory for AuditLogger instances.
 *
 * @param auditSession - Session to log into, or `null` when auditing is off.
 * @returns A RealAuditLogger wrapping the session when one is supplied,
 *          otherwise a NullAuditLogger whose methods do nothing.
 */
export function createAuditLogger(auditSession: AuditSession | null): AuditLogger {
  return auditSession ? new RealAuditLogger(auditSession) : new NullAuditLogger();
}

View File

@@ -4,34 +4,33 @@
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
import { $, fs, path } from 'zx';
// Production Claude agent execution with retry, git checkpoints, and audit logging
import { fs, path } from 'zx';
import chalk, { type ChalkInstance } from 'chalk';
import { query } from '@anthropic-ai/claude-agent-sdk';
import { fileURLToPath } from 'url';
import { dirname } from 'path';
import { isRetryableError, getRetryDelay, PentestError } from '../error-handling.js';
import { ProgressIndicator } from '../progress-indicator.js';
import { timingResults, costResults, Timer } from '../utils/metrics.js';
import { formatDuration } from '../audit/utils.js';
import { createGitCheckpoint, commitGitSuccess, rollbackGitWorkspace } from '../utils/git-manager.js';
import { timingResults, Timer } from '../utils/metrics.js';
import { formatTimestamp } from '../utils/formatting.js';
import { createGitCheckpoint, commitGitSuccess, rollbackGitWorkspace, getGitCommitHash } from '../utils/git-manager.js';
import { AGENT_VALIDATORS, MCP_AGENT_MAPPING } from '../constants.js';
import { filterJsonToolCalls, getAgentPrefix } from '../utils/output-formatter.js';
import { generateSessionLogPath } from '../session-manager.js';
import { AuditSession } from '../audit/index.js';
import { createShannonHelperServer } from '../../mcp-server/dist/index.js';
import type { SessionMetadata } from '../audit/utils.js';
import type { PromptName } from '../types/index.js';
import { getPromptNameForAgent } from '../types/agents.js';
import type { AgentName } from '../types/index.js';
import { dispatchMessage } from './message-handlers.js';
import { detectExecutionContext, formatErrorOutput, formatCompletionMessage } from './output-formatters.js';
import { createProgressManager } from './progress-manager.js';
import { createAuditLogger } from './audit-logger.js';
// Declare the optional global loader flag so `global.SHANNON_DISABLE_LOADER`
// type-checks. This file reads it when creating the progress manager and to
// decide whether periodic heartbeat lines are printed during a run.
declare global {
var SHANNON_DISABLE_LOADER: boolean | undefined;
}
// ES modules have no built-in __filename/__dirname; derive them from import.meta.url.
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
// Result types
interface ClaudePromptResult {
result?: string | null;
success: boolean;
@@ -47,7 +46,6 @@ interface ClaudePromptResult {
retryable?: boolean;
}
// MCP Server types
interface StdioMcpServer {
type: 'stdio';
command: string;
@@ -57,157 +55,29 @@ interface StdioMcpServer {
type McpServer = ReturnType<typeof createShannonHelperServer> | StdioMcpServer;
/**
 * Translate an agent name into the prompt name used as the
 * MCP_AGENT_MAPPING lookup key.
 *
 * - `pre-recon` and `report` map to dedicated prompt names.
 * - `{type}-vuln` becomes `vuln-{type}`; `{type}-exploit` becomes `exploit-{type}`.
 * - Every other name (including `recon`) is returned unchanged.
 */
function agentNameToPromptName(agentName: string): PromptName {
  switch (agentName) {
    case 'pre-recon':
      return 'pre-recon-code';
    case 'report':
      return 'report-executive';
    case 'recon':
      return 'recon';
    default:
      break;
  }

  // The anchored suffix is either "-vuln" or "-exploit", never both, so one
  // pattern covers both flips: "<subject>-<phase>" -> "<phase>-<subject>".
  const flipped = /^(.+)-(vuln|exploit)$/.exec(agentName);
  if (flipped !== null) {
    const [, subject, phase] = flipped;
    return `${phase}-${subject}` as PromptName;
  }

  return agentName as PromptName;
}
/**
 * Decide whether an agent run produced acceptable output.
 *
 * A run that was unsuccessful or has no result fails immediately. Otherwise
 * the validator registered for `agentName` in AGENT_VALIDATORS is run against
 * `sourceDir`; when no validator is registered the run is accepted. Any
 * exception raised while validating is reported and treated as a failure.
 */
async function validateAgentOutput(
  result: ClaudePromptResult,
  agentName: string | null,
  sourceDir: string
): Promise<boolean> {
  console.log(chalk.blue(` 🔍 Validating ${agentName} agent output`));

  try {
    // Nothing to validate unless the agent both succeeded and returned a result.
    const hasUsableResult = result.success && result.result;
    if (!hasUsableResult) {
      console.log(chalk.red(` ❌ Validation failed: Agent execution was unsuccessful`));
      return false;
    }

    const checkDeliverables = agentName
      ? AGENT_VALIDATORS[agentName as keyof typeof AGENT_VALIDATORS]
      : undefined;

    if (checkDeliverables) {
      console.log(chalk.blue(` 📋 Using validator for agent: ${agentName}`));
      console.log(chalk.blue(` 📂 Source directory: ${sourceDir}`));

      const deliverablesOk = await checkDeliverables(sourceDir);
      const verdictLine = deliverablesOk
        ? chalk.green(` ✅ Validation passed: Required files/structure present`)
        : chalk.red(` ❌ Validation failed: Missing required deliverable files`);
      console.log(verdictLine);
      return deliverablesOk;
    }

    // Agents without a registered validator are accepted on a successful result.
    console.log(chalk.yellow(` ⚠️ No validator found for agent "${agentName}" - assuming success`));
    console.log(chalk.green(` ✅ Validation passed: Unknown agent with successful result`));
    return true;
  } catch (error) {
    const reason = error instanceof Error ? error.message : String(error);
    console.log(chalk.red(` ❌ Validation failed with error: ${reason}`));
    // A validator that throws counts as invalid output.
    return false;
  }
}
// Pure function: Run Claude Code with SDK - Maximum Autonomy
// WARNING: This is a low-level function. Use runClaudePromptWithRetry() for agent execution
async function runClaudePrompt(
prompt: string,
// Configures MCP servers for agent execution, with Docker-specific Chromium handling
function buildMcpServers(
sourceDir: string,
_allowedTools: string = 'Read',
context: string = '',
description: string = 'Claude analysis',
agentName: string | null = null,
colorFn: ChalkInstance = chalk.cyan,
sessionMetadata: SessionMetadata | null = null,
auditSession: AuditSession | null = null,
attemptNumber: number = 1
): Promise<ClaudePromptResult> {
const timer = new Timer(`agent-${description.toLowerCase().replace(/\s+/g, '-')}`);
const fullPrompt = context ? `${context}\n\n${prompt}` : prompt;
let totalCost = 0;
let partialCost = 0; // Track partial cost for crash safety
agentName: string | null
): Record<string, McpServer> {
const shannonHelperServer = createShannonHelperServer(sourceDir);
// Auto-detect execution mode to adjust logging behavior
const isParallelExecution = description.includes('vuln agent') || description.includes('exploit agent');
const useCleanOutput = description.includes('Pre-recon agent') ||
description.includes('Recon agent') ||
description.includes('Executive Summary and Report Cleanup') ||
description.includes('vuln agent') ||
description.includes('exploit agent');
const mcpServers: Record<string, McpServer> = {
'shannon-helper': shannonHelperServer,
};
// Disable status manager - using simple JSON filtering for all agents now
const statusManager = null;
if (agentName) {
const promptName = getPromptNameForAgent(agentName as AgentName);
const playwrightMcpName = MCP_AGENT_MAPPING[promptName as keyof typeof MCP_AGENT_MAPPING] || null;
// Setup progress indicator for clean output agents (unless disabled via flag)
let progressIndicator: ProgressIndicator | null = null;
if (useCleanOutput && !global.SHANNON_DISABLE_LOADER) {
const agentType = description.includes('Pre-recon') ? 'pre-reconnaissance' :
description.includes('Recon') ? 'reconnaissance' :
description.includes('Report') ? 'report generation' : 'analysis';
progressIndicator = new ProgressIndicator(`Running ${agentType}...`);
}
// NOTE: Logging now handled by AuditSession (append-only, crash-safe)
let logFilePath: string | null = null;
if (sessionMetadata && sessionMetadata.webUrl && sessionMetadata.id) {
const timestamp = new Date().toISOString().replace(/T/, '_').replace(/[:.]/g, '-').slice(0, 19);
const agentKey = description.toLowerCase().replace(/\s+/g, '-');
const logDir = generateSessionLogPath(sessionMetadata.webUrl, sessionMetadata.id);
logFilePath = path.join(logDir, `${timestamp}_${agentKey}_attempt-${attemptNumber}.log`);
} else {
console.log(chalk.blue(` 🤖 Running Claude Code: ${description}...`));
}
// Declare variables that need to be accessible in both try and catch blocks
let turnCount = 0;
try {
// Create MCP server with target directory context
const shannonHelperServer = createShannonHelperServer(sourceDir);
// Look up agent's assigned Playwright MCP server
let playwrightMcpName: string | null = null;
if (agentName) {
const promptName = agentNameToPromptName(agentName);
playwrightMcpName = MCP_AGENT_MAPPING[promptName as keyof typeof MCP_AGENT_MAPPING] || null;
if (playwrightMcpName) {
console.log(chalk.gray(` 🎭 Assigned ${agentName}${playwrightMcpName}`));
}
}
// Configure MCP servers: shannon-helper (SDK) + playwright-agentN (stdio)
const mcpServers: Record<string, McpServer> = {
'shannon-helper': shannonHelperServer,
};
// Add Playwright MCP server if this agent needs browser automation
if (playwrightMcpName) {
console.log(chalk.gray(` Assigned ${agentName} -> ${playwrightMcpName}`));
const userDataDir = `/tmp/${playwrightMcpName}`;
// Detect if running in Docker via explicit environment variable
// Docker uses system Chromium; local dev uses Playwright's bundled browsers
const isDocker = process.env.SHANNON_DOCKER === 'true';
// Build args array - conditionally add --executable-path for Docker
const mcpArgs: string[] = [
'@playwright/mcp@latest',
'--isolated',
@@ -220,7 +90,6 @@ async function runClaudePrompt(
mcpArgs.push('--browser', 'chromium');
}
// Filter out undefined env values for type safety
const envVars: Record<string, string> = Object.fromEntries(
Object.entries({
...process.env,
@@ -236,256 +105,169 @@ async function runClaudePrompt(
env: envVars,
};
}
}
const options = {
model: 'claude-sonnet-4-5-20250929', // Use latest Claude 4.5 Sonnet
maxTurns: 10_000, // Maximum turns for autonomous work
cwd: sourceDir, // Set working directory using SDK option
permissionMode: 'bypassPermissions' as const, // Bypass all permission checks for pentesting
mcpServers,
return mcpServers;
}
/** Print every entry of `lines` with console.log, one call per entry, in order. */
function outputLines(lines: string[]): void {
  lines.forEach((line) => {
    console.log(line);
  });
}
/**
 * Best-effort: append one JSON line describing a failed run to
 * `<sourceDir>/error.log`.
 *
 * The entry holds a timestamp, the fixed agent label `claude-executor`, the
 * error's name/message/code/status/stack, the source directory, the first 200
 * characters of the prompt followed by `...`, the retryable verdict from
 * isRetryableError, and the duration. Failures while building or writing the
 * entry are reported on the console and never propagate to the caller.
 */
async function writeErrorLog(
  err: Error & { code?: string; status?: number },
  sourceDir: string,
  fullPrompt: string,
  duration: number
): Promise<void> {
  try {
    const timestamp = formatTimestamp();
    const errorDetails = {
      name: err.constructor.name,
      message: err.message,
      code: err.code,
      status: err.status,
      stack: err.stack
    };
    const runContext = {
      sourceDir,
      prompt: `${fullPrompt.slice(0, 200)}...`,
      retryable: isRetryableError(err)
    };
    const entry = {
      timestamp,
      agent: 'claude-executor',
      error: errorDetails,
      context: runContext,
      duration
    };

    const errorLogFile = path.join(sourceDir, 'error.log');
    await fs.appendFile(errorLogFile, `${JSON.stringify(entry)}\n`);
  } catch (logError) {
    // Logging must never mask the original failure; just note that it failed.
    const reason = logError instanceof Error ? logError.message : String(logError);
    console.log(chalk.gray(` (Failed to write error log: ${reason})`));
  }
}
// SDK Options only shown for verbose agents (not clean output)
if (!useCleanOutput) {
console.log(chalk.gray(` SDK Options: maxTurns=${options.maxTurns}, cwd=${sourceDir}, permissions=BYPASS`));
async function validateAgentOutput(
result: ClaudePromptResult,
agentName: string | null,
sourceDir: string
): Promise<boolean> {
console.log(chalk.blue(` Validating ${agentName} agent output`));
try {
// Check if agent completed successfully
if (!result.success || !result.result) {
console.log(chalk.red(` Validation failed: Agent execution was unsuccessful`));
return false;
}
let result: string | null = null;
const messages: string[] = [];
let apiErrorDetected = false;
// Get validator function for this agent
const validator = agentName ? AGENT_VALIDATORS[agentName as keyof typeof AGENT_VALIDATORS] : undefined;
// Start progress indicator for clean output agents
if (progressIndicator) {
progressIndicator.start();
if (!validator) {
console.log(chalk.yellow(` No validator found for agent "${agentName}" - assuming success`));
console.log(chalk.green(` Validation passed: Unknown agent with successful result`));
return true;
}
let lastHeartbeat = Date.now();
const HEARTBEAT_INTERVAL = 30000; // 30 seconds
console.log(chalk.blue(` Using validator for agent: ${agentName}`));
console.log(chalk.blue(` Source directory: ${sourceDir}`));
try {
for await (const message of query({ prompt: fullPrompt, options })) {
// Periodic heartbeat for long-running agents (only when loader is disabled)
const now = Date.now();
if (global.SHANNON_DISABLE_LOADER && now - lastHeartbeat > HEARTBEAT_INTERVAL) {
console.log(chalk.blue(` ⏱️ [${Math.floor((now - timer.startTime) / 1000)}s] ${description} running... (Turn ${turnCount})`));
lastHeartbeat = now;
}
// Apply validation function
const validationResult = await validator(sourceDir);
if (message.type === "assistant") {
turnCount++;
const messageContent = message.message as { content: unknown };
const content = Array.isArray(messageContent.content)
? messageContent.content.map((c: { text?: string }) => c.text || JSON.stringify(c)).join('\n')
: String(messageContent.content);
if (statusManager) {
// Smart status updates for parallel execution - disabled
} else if (useCleanOutput) {
// Clean output for all agents: filter JSON tool calls but show meaningful text
const cleanedContent = filterJsonToolCalls(content);
if (cleanedContent.trim()) {
// Temporarily stop progress indicator to show output
if (progressIndicator) {
progressIndicator.stop();
}
if (isParallelExecution) {
// Compact output for parallel agents with prefixes
const prefix = getAgentPrefix(description);
console.log(colorFn(`${prefix} ${cleanedContent}`));
} else {
// Full turn output for single agents
console.log(colorFn(`\n 🤖 Turn ${turnCount} (${description}):`));
console.log(colorFn(` ${cleanedContent}`));
}
// Restart progress indicator after output
if (progressIndicator) {
progressIndicator.start();
}
}
} else {
// Full streaming output - show complete messages with specialist color
console.log(colorFn(`\n 🤖 Turn ${turnCount} (${description}):`));
console.log(colorFn(` ${content}`));
}
// Log to audit system (crash-safe, append-only)
if (auditSession) {
await auditSession.logEvent('llm_response', {
turn: turnCount,
content,
timestamp: new Date().toISOString()
});
}
messages.push(content);
// Check for API error patterns in assistant message content
if (content && typeof content === 'string') {
const lowerContent = content.toLowerCase();
if (lowerContent.includes('session limit reached')) {
throw new PentestError('Session limit reached', 'billing', false);
}
if (lowerContent.includes('api error') || lowerContent.includes('terminated')) {
apiErrorDetected = true;
console.log(chalk.red(` ⚠️ API Error detected in assistant response: ${content.trim()}`));
}
}
} else if (message.type === "system" && (message as { subtype?: string }).subtype === "init") {
// Show useful system info only for verbose agents
if (!useCleanOutput) {
const initMsg = message as { model?: string; permissionMode?: string; mcp_servers?: Array<{ name: string; status: string }> };
console.log(chalk.blue(` Model: ${initMsg.model}, Permission: ${initMsg.permissionMode}`));
if (initMsg.mcp_servers && initMsg.mcp_servers.length > 0) {
const mcpStatus = initMsg.mcp_servers.map(s => `${s.name}(${s.status})`).join(', ');
console.log(chalk.blue(` 📦 MCP: ${mcpStatus}`));
}
}
} else if (message.type === "user") {
// Skip user messages (these are our own inputs echoed back)
continue;
} else if ((message.type as string) === "tool_use") {
const toolMsg = message as unknown as { name: string; input?: Record<string, unknown> };
console.log(chalk.yellow(`\n 🔧 Using Tool: ${toolMsg.name}`));
if (toolMsg.input && Object.keys(toolMsg.input).length > 0) {
console.log(chalk.gray(` Input: ${JSON.stringify(toolMsg.input, null, 2)}`));
}
// Log tool start event
if (auditSession) {
await auditSession.logEvent('tool_start', {
toolName: toolMsg.name,
parameters: toolMsg.input,
timestamp: new Date().toISOString()
});
}
} else if ((message.type as string) === "tool_result") {
const resultMsg = message as unknown as { content?: unknown };
console.log(chalk.green(` ✅ Tool Result:`));
if (resultMsg.content) {
// Show tool results but truncate if too long
const resultStr = typeof resultMsg.content === 'string' ? resultMsg.content : JSON.stringify(resultMsg.content, null, 2);
if (resultStr.length > 500) {
console.log(chalk.gray(` ${resultStr.slice(0, 500)}...\n [Result truncated - ${resultStr.length} total chars]`));
} else {
console.log(chalk.gray(` ${resultStr}`));
}
}
// Log tool end event
if (auditSession) {
await auditSession.logEvent('tool_end', {
result: resultMsg.content,
timestamp: new Date().toISOString()
});
}
} else if (message.type === "result") {
const resultMessage = message as {
result?: string;
total_cost_usd?: number;
duration_ms?: number;
subtype?: string;
permission_denials?: unknown[];
};
result = resultMessage.result || null;
if (!statusManager) {
if (useCleanOutput) {
// Clean completion output - just duration and cost
console.log(chalk.magenta(`\n 🏁 COMPLETED:`));
const cost = resultMessage.total_cost_usd || 0;
console.log(chalk.gray(` ⏱️ Duration: ${((resultMessage.duration_ms || 0)/1000).toFixed(1)}s, Cost: $${cost.toFixed(4)}`));
if (resultMessage.subtype === "error_max_turns") {
console.log(chalk.red(` ⚠️ Stopped: Hit maximum turns limit`));
} else if (resultMessage.subtype === "error_during_execution") {
console.log(chalk.red(` ❌ Stopped: Execution error`));
}
if (resultMessage.permission_denials && resultMessage.permission_denials.length > 0) {
console.log(chalk.yellow(` 🚫 ${resultMessage.permission_denials.length} permission denials`));
}
} else {
// Full completion output for agents without clean output
console.log(chalk.magenta(`\n 🏁 COMPLETED:`));
const cost = resultMessage.total_cost_usd || 0;
console.log(chalk.gray(` ⏱️ Duration: ${((resultMessage.duration_ms || 0)/1000).toFixed(1)}s, Cost: $${cost.toFixed(4)}`));
if (resultMessage.subtype === "error_max_turns") {
console.log(chalk.red(` ⚠️ Stopped: Hit maximum turns limit`));
} else if (resultMessage.subtype === "error_during_execution") {
console.log(chalk.red(` ❌ Stopped: Execution error`));
}
if (resultMessage.permission_denials && resultMessage.permission_denials.length > 0) {
console.log(chalk.yellow(` 🚫 ${resultMessage.permission_denials.length} permission denials`));
}
// Show result content (if it's reasonable length)
if (result && typeof result === 'string') {
if (result.length > 1000) {
console.log(chalk.magenta(` 📄 ${result.slice(0, 1000)}... [${result.length} total chars]`));
} else {
console.log(chalk.magenta(` 📄 ${result}`));
}
}
}
}
// Track cost for all agents
const cost = resultMessage.total_cost_usd || 0;
const agentKey = description.toLowerCase().replace(/\s+/g, '-');
costResults.agents[agentKey] = cost;
costResults.total += cost;
// Store cost for return value and partial tracking
totalCost = cost;
partialCost = cost;
break;
} else {
// Log any other message types we might not be handling
console.log(chalk.gray(` 💬 ${message.type}: ${JSON.stringify(message, null, 2)}`));
}
}
} catch (queryError) {
throw queryError; // Re-throw to outer catch
if (validationResult) {
console.log(chalk.green(` Validation passed: Required files/structure present`));
} else {
console.log(chalk.red(` Validation failed: Missing required deliverable files`));
}
return validationResult;
} catch (error) {
const errMsg = error instanceof Error ? error.message : String(error);
console.log(chalk.red(` Validation failed with error: ${errMsg}`));
return false;
}
}
// Low-level SDK execution. Handles message streaming, progress, and audit logging.
async function runClaudePrompt(
prompt: string,
sourceDir: string,
context: string = '',
description: string = 'Claude analysis',
agentName: string | null = null,
colorFn: ChalkInstance = chalk.cyan,
sessionMetadata: SessionMetadata | null = null,
auditSession: AuditSession | null = null,
attemptNumber: number = 1
): Promise<ClaudePromptResult> {
const timer = new Timer(`agent-${description.toLowerCase().replace(/\s+/g, '-')}`);
const fullPrompt = context ? `${context}\n\n${prompt}` : prompt;
const execContext = detectExecutionContext(description);
const progress = createProgressManager(
{ description, useCleanOutput: execContext.useCleanOutput },
global.SHANNON_DISABLE_LOADER ?? false
);
const auditLogger = createAuditLogger(auditSession);
const logFilePath = buildLogFilePath(sessionMetadata, execContext.agentKey, attemptNumber);
if (!logFilePath) {
console.log(chalk.blue(` Running Claude Code: ${description}...`));
}
const mcpServers = buildMcpServers(sourceDir, agentName);
const options = {
model: 'claude-sonnet-4-5-20250929',
maxTurns: 10_000,
cwd: sourceDir,
permissionMode: 'bypassPermissions' as const,
mcpServers,
};
if (!execContext.useCleanOutput) {
console.log(chalk.gray(` SDK Options: maxTurns=${options.maxTurns}, cwd=${sourceDir}, permissions=BYPASS`));
}
let turnCount = 0;
let result: string | null = null;
let apiErrorDetected = false;
let totalCost = 0;
progress.start();
try {
const messageLoopResult = await processMessageStream(
fullPrompt,
options,
{ execContext, description, colorFn, progress, auditLogger },
timer
);
turnCount = messageLoopResult.turnCount;
result = messageLoopResult.result;
apiErrorDetected = messageLoopResult.apiErrorDetected;
totalCost = messageLoopResult.cost;
const duration = timer.stop();
const agentKey = description.toLowerCase().replace(/\s+/g, '-');
timingResults.agents[agentKey] = duration;
timingResults.agents[execContext.agentKey] = duration;
// API error detection is logged but not immediately failed
if (apiErrorDetected) {
console.log(chalk.yellow(` ⚠️ API Error detected in ${description} - will validate deliverables before failing`));
console.log(chalk.yellow(` API Error detected in ${description} - will validate deliverables before failing`));
}
// Show completion messages based on agent type
if (progressIndicator) {
const agentType = description.includes('Pre-recon') ? 'Pre-recon analysis' :
description.includes('Recon') ? 'Reconnaissance' :
description.includes('Report') ? 'Report generation' : 'Analysis';
progressIndicator.finish(`${agentType} complete! (${turnCount} turns, ${formatDuration(duration)})`);
} else if (isParallelExecution) {
const prefix = getAgentPrefix(description);
console.log(chalk.green(`${prefix} ✅ Complete (${turnCount} turns, ${formatDuration(duration)})`));
} else if (!useCleanOutput) {
console.log(chalk.green(` ✅ Claude Code completed: ${description} (${turnCount} turns) in ${formatDuration(duration)}`));
}
progress.finish(formatCompletionMessage(execContext, description, turnCount, duration));
// Return result with log file path for all agents
const returnData: ClaudePromptResult = {
result,
success: true,
duration,
turns: turnCount,
cost: totalCost,
partialCost,
partialCost: totalCost,
apiErrorDetected
};
if (logFilePath) {
@@ -495,76 +277,14 @@ async function runClaudePrompt(
} catch (error) {
const duration = timer.stop();
const agentKey = description.toLowerCase().replace(/\s+/g, '-');
timingResults.agents[agentKey] = duration;
timingResults.agents[execContext.agentKey] = duration;
const err = error as Error & { code?: string; status?: number; duration?: number; cost?: number };
const err = error as Error & { code?: string; status?: number };
// Log error to audit system
if (auditSession) {
await auditSession.logEvent('error', {
message: err.message,
errorType: err.constructor.name,
stack: err.stack,
duration,
turns: turnCount,
timestamp: new Date().toISOString()
});
}
// Show error messages based on agent type
if (progressIndicator) {
progressIndicator.stop();
const agentType = description.includes('Pre-recon') ? 'Pre-recon analysis' :
description.includes('Recon') ? 'Reconnaissance' :
description.includes('Report') ? 'Report generation' : 'Analysis';
console.log(chalk.red(`${agentType} failed (${formatDuration(duration)})`));
} else if (isParallelExecution) {
const prefix = getAgentPrefix(description);
console.log(chalk.red(`${prefix} ❌ Failed (${formatDuration(duration)})`));
} else if (!useCleanOutput) {
console.log(chalk.red(` ❌ Claude Code failed: ${description} (${formatDuration(duration)})`));
}
console.log(chalk.red(` Error Type: ${err.constructor.name}`));
console.log(chalk.red(` Message: ${err.message}`));
console.log(chalk.gray(` Agent: ${description}`));
console.log(chalk.gray(` Working Directory: ${sourceDir}`));
console.log(chalk.gray(` Retryable: ${isRetryableError(err) ? 'Yes' : 'No'}`));
// Log additional context if available
if (err.code) {
console.log(chalk.gray(` Error Code: ${err.code}`));
}
if (err.status) {
console.log(chalk.gray(` HTTP Status: ${err.status}`));
}
// Save detailed error to log file for debugging
try {
const errorLog = {
timestamp: new Date().toISOString(),
agent: description,
error: {
name: err.constructor.name,
message: err.message,
code: err.code,
status: err.status,
stack: err.stack
},
context: {
sourceDir,
prompt: fullPrompt.slice(0, 200) + '...',
retryable: isRetryableError(err)
},
duration
};
const logPath = path.join(sourceDir, 'error.log');
await fs.appendFile(logPath, JSON.stringify(errorLog) + '\n');
} catch (logError) {
const logErrMsg = logError instanceof Error ? logError.message : String(logError);
console.log(chalk.gray(` (Failed to write error log: ${logErrMsg})`));
}
await auditLogger.logError(err, duration, turnCount);
progress.stop();
outputLines(formatErrorOutput(err, execContext, description, duration, sourceDir, isRetryableError(err)));
await writeErrorLog(err, sourceDir, fullPrompt, duration);
return {
error: err.message,
@@ -572,17 +292,97 @@ async function runClaudePrompt(
prompt: fullPrompt.slice(0, 100) + '...',
success: false,
duration,
cost: partialCost,
cost: totalCost,
retryable: isRetryableError(err)
};
}
}
// PREFERRED: Production-ready Claude agent execution with full orchestration
/**
 * Build the log file path for one agent attempt.
 *
 * @returns `null` when `sessionMetadata` is missing or lacks a truthy
 *          `webUrl` or `id`; otherwise
 *          `<session log dir>/<stamp>_<agentKey>_attempt-<attemptNumber>.log`,
 *          where the directory comes from generateSessionLogPath and `<stamp>`
 *          is the first 19 characters of formatTimestamp() with `T` replaced
 *          by `_` and every `:` / `.` replaced by `-`.
 */
function buildLogFilePath(
  sessionMetadata: SessionMetadata | null,
  agentKey: string,
  attemptNumber: number
): string | null {
  const webUrl = sessionMetadata?.webUrl;
  const sessionId = sessionMetadata?.id;
  if (!webUrl || !sessionId) {
    return null;
  }

  // NOTE(review): the slice(0, 19) assumes an ISO-8601 style timestamp - confirm formatTimestamp().
  const stamp = formatTimestamp().replace(/T/, '_').replace(/[:.]/g, '-').slice(0, 19);
  const fileName = `${stamp}_${agentKey}_attempt-${attemptNumber}.log`;
  return path.join(generateSessionLogPath(webUrl, sessionId), fileName);
}
/** Summary returned by processMessageStream once the message stream ends. */
interface MessageLoopResult {
/** Number of messages of type 'assistant' seen on the stream. */
turnCount: number;
/** Result carried by the 'complete' dispatch outcome; null if the stream ended without one. */
result: string | null;
/** True when at least one 'continue' dispatch outcome flagged an API error. */
apiErrorDetected: boolean;
/** Cost carried by the 'complete' dispatch outcome; 0 if the stream ended without one. */
cost: number;
}
/**
 * Collaborators handed to processMessageStream. All five fields are forwarded
 * to dispatchMessage for every message on the stream.
 */
interface MessageLoopDeps {
/** Execution context derived from the description by detectExecutionContext. */
execContext: ReturnType<typeof detectExecutionContext>;
/** Human-readable agent description; also printed in heartbeat lines. */
description: string;
/** Chalk color function forwarded to dispatchMessage. */
colorFn: ChalkInstance;
/** Progress manager created by createProgressManager. */
progress: ReturnType<typeof createProgressManager>;
/** Audit logger created by createAuditLogger. */
auditLogger: ReturnType<typeof createAuditLogger>;
}
/**
 * Consume one SDK `query()` stream and summarize what happened.
 *
 * For each message: optionally print a heartbeat (only when the global
 * SHANNON_DISABLE_LOADER flag is set and more than 30 seconds passed since
 * the last one), count assistant messages as turns, then hand the message to
 * dispatchMessage and act on its outcome:
 * - `throw`    - rethrow the carried error,
 * - `complete` - capture result and cost, stop reading the stream,
 * - `continue` - remember whether an API error was flagged.
 *
 * @returns Turn count, final result, API-error flag and cost for the stream.
 * @throws Whatever error a `throw` outcome carries, or anything raised by
 *         `query` / `dispatchMessage` themselves.
 */
async function processMessageStream(
  fullPrompt: string,
  options: NonNullable<Parameters<typeof query>[0]['options']>,
  deps: MessageLoopDeps,
  timer: Timer
): Promise<MessageLoopResult> {
  const HEARTBEAT_INTERVAL = 30000;
  const { execContext, description, colorFn, progress, auditLogger } = deps;

  const summary: MessageLoopResult = {
    turnCount: 0,
    result: null,
    apiErrorDetected: false,
    cost: 0,
  };
  let lastHeartbeat = Date.now();

  for await (const message of query({ prompt: fullPrompt, options })) {
    // Heartbeats replace the loader animation when the loader is disabled.
    const now = Date.now();
    const heartbeatDue = now - lastHeartbeat > HEARTBEAT_INTERVAL;
    if (global.SHANNON_DISABLE_LOADER && heartbeatDue) {
      const elapsedSeconds = Math.floor((now - timer.startTime) / 1000);
      console.log(chalk.blue(` [${elapsedSeconds}s] ${description} running... (Turn ${summary.turnCount})`));
      lastHeartbeat = now;
    }

    // A turn is one assistant message; count it before dispatching so the
    // handler sees the current turn number.
    if (message.type === 'assistant') {
      summary.turnCount += 1;
    }

    const outcome = await dispatchMessage(
      message as { type: string; subtype?: string },
      summary.turnCount,
      { execContext, description, colorFn, progress, auditLogger }
    );

    if (outcome.type === 'throw') {
      throw outcome.error;
    }

    if (outcome.type === 'complete') {
      summary.result = outcome.result;
      summary.cost = outcome.cost;
      break;
    }

    if (outcome.type === 'continue' && outcome.apiErrorDetected) {
      summary.apiErrorDetected = true;
    }
  }

  return summary;
}
// Main entry point for agent execution. Handles retries, git checkpoints, and validation.
export async function runClaudePromptWithRetry(
prompt: string,
sourceDir: string,
allowedTools: string = 'Read',
_allowedTools: string = 'Read',
context: string = '',
description: string = 'Claude analysis',
agentName: string | null = null,
@@ -593,9 +393,8 @@ export async function runClaudePromptWithRetry(
let lastError: Error | undefined;
let retryContext = context;
console.log(chalk.cyan(`🚀 Starting ${description} with ${maxRetries} max attempts`));
console.log(chalk.cyan(`Starting ${description} with ${maxRetries} max attempts`));
// Initialize audit session (crash-safe logging)
let auditSession: AuditSession | null = null;
if (sessionMetadata && agentName) {
auditSession = new AuditSession(sessionMetadata);
@@ -603,29 +402,27 @@ export async function runClaudePromptWithRetry(
}
for (let attempt = 1; attempt <= maxRetries; attempt++) {
// Create checkpoint before each attempt
await createGitCheckpoint(sourceDir, description, attempt);
// Start agent tracking in audit system (saves prompt snapshot automatically)
if (auditSession && agentName) {
const fullPrompt = retryContext ? `${retryContext}\n\n${prompt}` : prompt;
await auditSession.startAgent(agentName, fullPrompt, attempt);
}
try {
const result = await runClaudePrompt(prompt, sourceDir, allowedTools, retryContext, description, agentName, colorFn, sessionMetadata, auditSession, attempt);
const result = await runClaudePrompt(
prompt, sourceDir, retryContext,
description, agentName, colorFn, sessionMetadata, auditSession, attempt
);
// Validate output after successful run
if (result.success) {
const validationPassed = await validateAgentOutput(result, agentName, sourceDir);
if (validationPassed) {
// Check if API error was detected but validation passed
if (result.apiErrorDetected) {
console.log(chalk.yellow(`📋 Validation: Ready for exploitation despite API error warnings`));
console.log(chalk.yellow(`Validation: Ready for exploitation despite API error warnings`));
}
// Record successful attempt in audit system
if (auditSession && agentName) {
const commitHash = await getGitCommitHash(sourceDir);
const endResult: {
@@ -646,15 +443,13 @@ export async function runClaudePromptWithRetry(
await auditSession.endAgent(agentName, endResult);
}
// Commit successful changes (will include the snapshot)
await commitGitSuccess(sourceDir, description);
console.log(chalk.green.bold(`🎉 ${description} completed successfully on attempt ${attempt}/${maxRetries}`));
console.log(chalk.green.bold(`${description} completed successfully on attempt ${attempt}/${maxRetries}`));
return result;
// Validation failure is retryable - agent might succeed on retry with cleaner workspace
} else {
// Agent completed but output validation failed
console.log(chalk.yellow(`⚠️ ${description} completed but output validation failed`));
console.log(chalk.yellow(`${description} completed but output validation failed`));
// Record failed validation attempt in audit system
if (auditSession && agentName) {
await auditSession.endAgent(agentName, {
attemptNumber: attempt,
@@ -666,20 +461,17 @@ export async function runClaudePromptWithRetry(
});
}
// If API error detected AND validation failed, this is a retryable error
if (result.apiErrorDetected) {
console.log(chalk.yellow(`⚠️ API Error detected with validation failure - treating as retryable`));
console.log(chalk.yellow(`API Error detected with validation failure - treating as retryable`));
lastError = new Error('API Error: terminated with validation failure');
} else {
lastError = new Error('Output validation failed');
}
if (attempt < maxRetries) {
// Rollback contaminated workspace
await rollbackGitWorkspace(sourceDir, 'validation failure');
continue;
} else {
// FAIL FAST - Don't continue with broken pipeline
throw new PentestError(
`Agent ${description} failed output validation after ${maxRetries} attempts. Required deliverable files were not created.`,
'validation',
@@ -694,7 +486,6 @@ export async function runClaudePromptWithRetry(
const err = error as Error & { duration?: number; cost?: number; partialResults?: unknown };
lastError = err;
// Record failed attempt in audit system
if (auditSession && agentName) {
await auditSession.endAgent(agentName, {
attemptNumber: attempt,
@@ -706,24 +497,21 @@ export async function runClaudePromptWithRetry(
});
}
// Check if error is retryable
if (!isRetryableError(err)) {
console.log(chalk.red(`${description} failed with non-retryable error: ${err.message}`));
console.log(chalk.red(`${description} failed with non-retryable error: ${err.message}`));
await rollbackGitWorkspace(sourceDir, 'non-retryable error cleanup');
throw err;
}
if (attempt < maxRetries) {
// Rollback for clean retry
await rollbackGitWorkspace(sourceDir, 'retryable error cleanup');
const delay = getRetryDelay(err, attempt);
const delaySeconds = (delay / 1000).toFixed(1);
console.log(chalk.yellow(`⚠️ ${description} failed (attempt ${attempt}/${maxRetries})`));
console.log(chalk.yellow(`${description} failed (attempt ${attempt}/${maxRetries})`));
console.log(chalk.gray(` Error: ${err.message}`));
console.log(chalk.gray(` Workspace rolled back, retrying in ${delaySeconds}s...`));
// Preserve any partial results for next retry
if (err.partialResults) {
retryContext = `${context}\n\nPrevious partial results: ${JSON.stringify(err.partialResults)}`;
}
@@ -731,7 +519,7 @@ export async function runClaudePromptWithRetry(
await new Promise(resolve => setTimeout(resolve, delay));
} else {
await rollbackGitWorkspace(sourceDir, 'final failure cleanup');
console.log(chalk.red(`${description} failed after ${maxRetries} attempts`));
console.log(chalk.red(`${description} failed after ${maxRetries} attempts`));
console.log(chalk.red(` Final error: ${err.message}`));
}
}
@@ -739,13 +527,3 @@ export async function runClaudePromptWithRetry(
throw lastError;
}
// Helper function to get git commit hash
async function getGitCommitHash(sourceDir: string): Promise<string | null> {
try {
const result = await $`cd ${sourceDir} && git rev-parse HEAD`;
return result.stdout.trim();
} catch {
return null;
}
}

244
src/ai/message-handlers.ts Normal file
View File

@@ -0,0 +1,244 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Pure functions for processing SDK message types
import { PentestError } from '../error-handling.js';
import { filterJsonToolCalls } from '../utils/output-formatter.js';
import { formatTimestamp } from '../utils/formatting.js';
import chalk from 'chalk';
import {
formatAssistantOutput,
formatResultOutput,
formatToolUseOutput,
formatToolResultOutput,
} from './output-formatters.js';
import { costResults } from '../utils/metrics.js';
import type { AuditLogger } from './audit-logger.js';
import type { ProgressManager } from './progress-manager.js';
import type {
AssistantMessage,
ResultMessage,
ToolUseMessage,
ToolResultMessage,
AssistantResult,
ResultData,
ToolUseData,
ToolResultData,
ApiErrorDetection,
ContentBlock,
SystemInitMessage,
ExecutionContext,
} from './types.js';
import type { ChalkInstance } from 'chalk';
/**
 * Flattens the payload of an assistant message into a single string.
 *
 * `message.message.content` is either a plain value (converted with `String`)
 * or an array of content blocks. For arrays, each block contributes its `text`
 * when that property is present, otherwise the JSON serialization of the whole
 * block; the pieces are joined with newlines.
 *
 * @param message - Assistant message as received from the SDK stream.
 * @returns The flattened textual content.
 */
export function extractMessageContent(message: AssistantMessage): string {
  const { content } = message.message;

  if (!Array.isArray(content)) {
    return String(content);
  }

  return content
    .map((block: ContentBlock) =>
      // `??` instead of `||`: an empty-string `text` is still text and must not
      // be replaced by a JSON dump of the block.
      block.text ?? JSON.stringify(block)
    )
    .join('\n');
}
/**
 * Scans assistant text for known error markers (case-insensitive).
 *
 * - "session limit reached" is reported as detected together with a
 *   `PentestError` the caller is expected to throw.
 * - "api error" or "terminated" is reported as detected without an error to throw.
 * - Anything else, including empty or non-string input, is reported as not detected.
 *
 * @param content - Raw assistant text to inspect.
 * @returns The detection outcome.
 */
export function detectApiError(content: string): ApiErrorDetection {
  const hasText = typeof content === 'string' && content.length > 0;
  if (!hasText) {
    return { detected: false };
  }

  const haystack = content.toLowerCase();

  // Fatal condition: hand the caller an error to raise.
  if (haystack.includes('session limit reached')) {
    const fatal = new PentestError('Session limit reached', 'billing', false);
    return { detected: true, shouldThrow: fatal };
  }

  // Non-fatal markers: flag them and let processing continue.
  const softMarkers = ['api error', 'terminated'];
  const flagged = softMarkers.some((marker) => haystack.includes(marker));
  return { detected: flagged };
}
/**
 * Converts an assistant message into the data needed for display and logging.
 *
 * Extracts the raw text, derives a display version via `filterJsonToolCalls`,
 * and runs error detection on the raw text.
 *
 * @param message - Assistant message from the SDK stream.
 * @param turnCount - Turn number recorded in the log data.
 * @returns Raw and cleaned content, the detection flag, log data, and
 *   `shouldThrow` only when detection produced an error.
 */
export function handleAssistantMessage(
  message: AssistantMessage,
  turnCount: number
): AssistantResult {
  const rawText = extractMessageContent(message);
  const displayText = filterJsonToolCalls(rawText);
  const { detected, shouldThrow } = detectApiError(rawText);

  return {
    content: rawText,
    cleanedContent: displayText,
    apiErrorDetected: detected,
    logData: {
      turn: turnCount,
      content: rawText,
      timestamp: formatTimestamp(),
    },
    // Omit the key entirely when absent (exactOptionalPropertyTypes compliance).
    ...(shouldThrow ? { shouldThrow } : {}),
  };
}
/**
 * Normalizes the final `result` message of a query.
 *
 * Falsy fields fall back to defaults: `null` for the result text, `0` for
 * cost, duration, and the number of permission denials. `subtype` is copied
 * only when it is truthy.
 *
 * @param message - Result message from the SDK stream.
 * @returns The normalized result data.
 */
export function handleResultMessage(message: ResultMessage): ResultData {
  const { result, total_cost_usd, duration_ms, permission_denials, subtype } = message;

  const summary: ResultData = {
    result: result || null,
    cost: total_cost_usd || 0,
    duration_ms: duration_ms || 0,
    permissionDenials: permission_denials?.length || 0,
  };

  // Leave the key out entirely when absent (exactOptionalPropertyTypes compliance).
  return subtype ? { ...summary, subtype } : summary;
}
/**
 * Extracts the tool name and parameters from a tool-use message and stamps
 * them with the current time.
 *
 * @param message - Tool-use message from the SDK stream.
 * @returns Tool name, parameters (an empty object when `input` is falsy), and a timestamp.
 */
export function handleToolUseMessage(message: ToolUseMessage): ToolUseData {
  const { name, input } = message;
  const parameters = input || {};
  return { toolName: name, parameters, timestamp: formatTimestamp() };
}
/**
 * Prepares a tool result for display and logging.
 *
 * String content is used as-is; anything else is pretty-printed as JSON.
 * The display text is cut to 500 characters, with a marker stating the total
 * length, while the original `content` is returned untouched for logging.
 *
 * @param message - Tool-result message from the SDK stream.
 * @returns The original content, the (possibly truncated) display text, and a timestamp.
 */
export function handleToolResultMessage(message: ToolResultMessage): ToolResultData {
  const maxDisplayChars = 500;
  const { content } = message;

  // `content` is optional, and JSON.stringify returns undefined (not a string)
  // for undefined, functions, and symbols. Fall back to an empty string so the
  // length check below cannot throw.
  const contentStr =
    typeof content === 'string' ? content : (JSON.stringify(content, null, 2) ?? '');

  const displayContent =
    contentStr.length > maxDisplayChars
      ? `${contentStr.slice(0, maxDisplayChars)}...\n[Result truncated - ${contentStr.length} total chars]`
      : contentStr;

  return {
    content,
    displayContent,
    timestamp: formatTimestamp(),
  };
}
/** Writes each entry of `lines` to the console, one `console.log` call per entry. */
function outputLines(lines: string[]): void {
  lines.forEach((entry) => {
    console.log(entry);
  });
}
/**
 * Outcome of dispatching a single SDK message, discriminated by `type`.
 *
 * - `continue`: keep processing; `apiErrorDetected` is set when the assistant
 *   text contained a non-fatal API error marker.
 * - `complete`: a `result` message was handled; carries its result text and cost.
 * - `throw`: the caller should throw `error`.
 */
export type MessageDispatchAction =
| { type: 'continue'; apiErrorDetected?: boolean }
| { type: 'complete'; result: string | null; cost: number }
| { type: 'throw'; error: Error };
/** Collaborators and display settings passed to `dispatchMessage`. */
export interface MessageDispatchDeps {
// Output-mode flags and the agent key used for cost bookkeeping.
execContext: ExecutionContext;
// Human-readable agent description used in console output.
description: string;
// Color function applied to assistant output lines.
colorFn: ChalkInstance;
// Progress indicator, paused while assistant output is printed.
progress: ProgressManager;
// Sink for LLM responses and tool start/end events.
auditLogger: AuditLogger;
}
/**
 * Routes one SDK message to its handler, prints the formatted output, records
 * audit events, and tells the caller how to proceed.
 *
 * @param message - Message from the SDK stream; `type` selects the handler.
 * @param turnCount - Current turn number, used for display and audit logging.
 * @param deps - Execution context, display settings, progress manager, and audit logger.
 * @returns `throw` when the assistant text signals a fatal error, `complete`
 *   for a `result` message, and `continue` otherwise.
 */
export async function dispatchMessage(
  message: { type: string; subtype?: string },
  turnCount: number,
  deps: MessageDispatchDeps
): Promise<MessageDispatchAction> {
  const { execContext, description, colorFn, progress, auditLogger } = deps;

  if (message.type === 'assistant') {
    const handled = handleAssistantMessage(message as AssistantMessage, turnCount);
    if (handled.shouldThrow) {
      return { type: 'throw', error: handled.shouldThrow };
    }

    const hasVisibleText = handled.cleanedContent.trim().length > 0;
    if (hasVisibleText) {
      // Pause the progress indicator while the assistant text is printed.
      progress.stop();
      const rendered = formatAssistantOutput(
        handled.cleanedContent,
        execContext,
        turnCount,
        description,
        colorFn
      );
      outputLines(rendered);
      progress.start();
    }

    await auditLogger.logLlmResponse(turnCount, handled.content);

    if (!handled.apiErrorDetected) {
      return { type: 'continue' };
    }
    console.log(chalk.red(` API Error detected in assistant response`));
    return { type: 'continue', apiErrorDetected: true };
  }

  if (message.type === 'system') {
    // Init details are printed only when clean output is off.
    if (message.subtype === 'init' && !execContext.useCleanOutput) {
      const init = message as SystemInitMessage;
      console.log(chalk.blue(` Model: ${init.model}, Permission: ${init.permissionMode}`));
      const servers = init.mcp_servers;
      if (servers && servers.length > 0) {
        const summary = servers.map((server) => `${server.name}(${server.status})`).join(', ');
        console.log(chalk.blue(` MCP: ${summary}`));
      }
    }
    return { type: 'continue' };
  }

  if (message.type === 'user') {
    return { type: 'continue' };
  }

  if (message.type === 'tool_use') {
    const call = handleToolUseMessage(message as unknown as ToolUseMessage);
    outputLines(formatToolUseOutput(call.toolName, call.parameters));
    await auditLogger.logToolStart(call.toolName, call.parameters);
    return { type: 'continue' };
  }

  if (message.type === 'tool_result') {
    const outcome = handleToolResultMessage(message as unknown as ToolResultMessage);
    // Truncated text goes to the console; the full content goes to the audit log.
    outputLines(formatToolResultOutput(outcome.displayContent));
    await auditLogger.logToolEnd(outcome.content);
    return { type: 'continue' };
  }

  if (message.type === 'result') {
    const summary = handleResultMessage(message as ResultMessage);
    outputLines(formatResultOutput(summary, !execContext.useCleanOutput));
    // Record the per-agent cost and add it to the running total.
    costResults.agents[execContext.agentKey] = summary.cost;
    costResults.total += summary.cost;
    return { type: 'complete', result: summary.result, cost: summary.cost };
  }

  // Unrecognized message type: dump it for diagnostics and keep going.
  console.log(chalk.gray(` ${message.type}: ${JSON.stringify(message, null, 2)}`));
  return { type: 'continue' };
}

169
src/ai/output-formatters.ts Normal file
View File

@@ -0,0 +1,169 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Pure functions for formatting console output
import chalk from 'chalk';
import { extractAgentType, formatDuration } from '../utils/formatting.js';
import { getAgentPrefix } from '../utils/output-formatter.js';
import type { ExecutionContext, ResultData } from './types.js';
/**
 * Derives output-mode flags and identifiers from an agent description.
 *
 * @param description - Human-readable agent description.
 * @returns Whether the description names a vuln/exploit agent
 *   (`isParallelExecution`), whether clean output applies (`useCleanOutput`),
 *   the extracted agent type, and a key built by lowercasing the description
 *   and replacing whitespace runs with `-`.
 */
export function detectExecutionContext(description: string): ExecutionContext {
  const mentions = (marker: string): boolean => description.includes(marker);

  const parallelMarkers = ['vuln agent', 'exploit agent'];
  const cleanOutputMarkers = [
    'Pre-recon agent',
    'Recon agent',
    'Executive Summary and Report Cleanup',
    ...parallelMarkers,
  ];

  return {
    isParallelExecution: parallelMarkers.some((marker) => mentions(marker)),
    useCleanOutput: cleanOutputMarkers.some((marker) => mentions(marker)),
    agentType: extractAgentType(description),
    agentKey: description.toLowerCase().replace(/\s+/g, '-'),
  };
}
/**
 * Builds the console lines for one assistant turn.
 *
 * @param cleanedContent - Assistant text to show; whitespace-only text yields no lines.
 * @param context - Execution context; `isParallelExecution` selects the compact form.
 * @param turnCount - Turn number shown in the sequential form.
 * @param description - Agent description, shown in the turn header or used to derive the prefix.
 * @param colorFn - Color function applied to every line (defaults to cyan).
 * @returns One prefixed line for parallel agents, or a header line plus a
 *   content line for sequential agents.
 */
export function formatAssistantOutput(
  cleanedContent: string,
  context: ExecutionContext,
  turnCount: number,
  description: string,
  colorFn: typeof chalk.cyan = chalk.cyan
): string[] {
  if (cleanedContent.trim().length === 0) {
    return [];
  }

  if (context.isParallelExecution) {
    // Compact single-line form, tagged with the agent prefix.
    const agentPrefix = getAgentPrefix(description);
    return [colorFn(`${agentPrefix} ${cleanedContent}`)];
  }

  // Sequential agents get a turn header followed by the content.
  return [
    colorFn(`\n Turn ${turnCount} (${description}):`),
    colorFn(` ${cleanedContent}`),
  ];
}
/**
 * Builds the console lines summarizing a finished query.
 *
 * Always emits a completion header and a duration/cost line. Adds a stop
 * reason for the `error_max_turns` and `error_during_execution` subtypes, a
 * permission-denial count when non-zero, and, when `showFullResult` is set,
 * the result text cut to 1000 characters.
 *
 * @param data - Normalized result data.
 * @param showFullResult - Whether to append the result text.
 * @returns The colored output lines, in display order.
 */
export function formatResultOutput(data: ResultData, showFullResult: boolean): string[] {
  const seconds = (data.duration_ms / 1000).toFixed(1);
  const out: string[] = [
    chalk.magenta(`\n COMPLETED:`),
    chalk.gray(` Duration: ${seconds}s, Cost: $${data.cost.toFixed(4)}`),
  ];

  switch (data.subtype) {
    case 'error_max_turns':
      out.push(chalk.red(` Stopped: Hit maximum turns limit`));
      break;
    case 'error_during_execution':
      out.push(chalk.red(` Stopped: Execution error`));
      break;
    default:
      break;
  }

  if (data.permissionDenials > 0) {
    out.push(chalk.yellow(` ${data.permissionDenials} permission denials`));
  }

  const text = data.result;
  if (showFullResult && text && typeof text === 'string') {
    const shown =
      text.length > 1000 ? `${text.slice(0, 1000)}... [${text.length} total chars]` : text;
    out.push(chalk.magenta(` ${shown}`));
  }

  return out;
}
/**
 * Builds the console lines describing a failed agent run.
 *
 * The headline depends on the output mode (parallel, clean, or full). It is
 * followed by the error type and message, the agent description, the working
 * directory, the retryable flag, and the error code and HTTP status when those
 * are truthy.
 *
 * @param error - The failure, optionally carrying `code` and `status`.
 * @param context - Execution context selecting the headline style.
 * @param description - Agent description.
 * @param duration - Elapsed time, passed to `formatDuration`.
 * @param sourceDir - Working directory shown in the details.
 * @param isRetryable - Whether the failure is classified as retryable.
 * @returns The colored output lines, in display order.
 */
export function formatErrorOutput(
  error: Error & { code?: string; status?: number },
  context: ExecutionContext,
  description: string,
  duration: number,
  sourceDir: string,
  isRetryable: boolean
): string[] {
  let headline: string;
  if (context.isParallelExecution) {
    const agentPrefix = getAgentPrefix(description);
    headline = `${agentPrefix} Failed (${formatDuration(duration)})`;
  } else if (context.useCleanOutput) {
    headline = `${context.agentType} failed (${formatDuration(duration)})`;
  } else {
    headline = ` Claude Code failed: ${description} (${formatDuration(duration)})`;
  }

  const out: string[] = [
    chalk.red(headline),
    chalk.red(` Error Type: ${error.constructor.name}`),
    chalk.red(` Message: ${error.message}`),
    chalk.gray(` Agent: ${description}`),
    chalk.gray(` Working Directory: ${sourceDir}`),
    chalk.gray(` Retryable: ${isRetryable ? 'Yes' : 'No'}`),
  ];

  // Optional diagnostics, present only on some error kinds.
  if (error.code) {
    out.push(chalk.gray(` Error Code: ${error.code}`));
  }
  if (error.status) {
    out.push(chalk.gray(` HTTP Status: ${error.status}`));
  }

  return out;
}
/**
 * Builds the green one-line completion message for an agent run.
 *
 * @param context - Execution context selecting the message style.
 * @param description - Agent description.
 * @param turnCount - Number of turns the run took.
 * @param duration - Elapsed time, passed to `formatDuration`.
 * @returns A prefixed message for parallel agents, a capitalized agent-type
 *   message in clean-output mode, or the full description form otherwise.
 */
export function formatCompletionMessage(
  context: ExecutionContext,
  description: string,
  turnCount: number,
  duration: number
): string {
  let text: string;

  if (context.isParallelExecution) {
    const agentPrefix = getAgentPrefix(description);
    text = `${agentPrefix} Complete (${turnCount} turns, ${formatDuration(duration)})`;
  } else if (context.useCleanOutput) {
    // Capitalize the first letter of the agent type.
    const label = context.agentType.charAt(0).toUpperCase() + context.agentType.slice(1);
    text = `${label} complete! (${turnCount} turns, ${formatDuration(duration)})`;
  } else {
    text = ` Claude Code completed: ${description} (${turnCount} turns) in ${formatDuration(duration)}`;
  }

  return chalk.green(text);
}
/**
 * Builds the console lines announcing a tool invocation.
 *
 * @param toolName - Name of the tool being invoked.
 * @param input - Tool parameters; printed as indented JSON when it has at least one key.
 * @returns The header line, plus an input line when there is input to show.
 */
export function formatToolUseOutput(
  toolName: string,
  input: Record<string, unknown> | undefined
): string[] {
  const header = chalk.yellow(`\n Using Tool: ${toolName}`);
  const hasInput = Boolean(input) && Object.keys(input as Record<string, unknown>).length > 0;

  if (!hasInput) {
    return [header];
  }
  return [header, chalk.gray(` Input: ${JSON.stringify(input, null, 2)}`)];
}
/**
 * Builds the console lines for a tool result.
 *
 * @param displayContent - Text to show; an empty string produces only the header.
 * @returns The header line, followed by the content line when there is content.
 */
export function formatToolResultOutput(displayContent: string): string[] {
  const header = chalk.green(` Tool Result:`);
  return displayContent ? [header, chalk.gray(` ${displayContent}`)] : [header];
}

View File

@@ -0,0 +1,76 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Null Object pattern for progress indicator - callers never check for null
import { ProgressIndicator } from '../progress-indicator.js';
import { extractAgentType } from '../utils/formatting.js';
/** Inputs `createProgressManager` uses to decide whether and how to show progress. */
export interface ProgressContext {
// Agent description; the agent type shown in the progress message is extracted from it.
description: string;
// A real progress indicator is created only when this is true.
useCleanOutput: boolean;
}
/**
 * Progress-indicator contract shared by the real and no-op implementations,
 * so callers never need to check for null.
 */
export interface ProgressManager {
// Starts (or resumes) the indicator.
start(): void;
// Stops the indicator.
stop(): void;
// Ends the indicator with a final message.
finish(message: string): void;
// Reports whether the indicator is currently started.
isActive(): boolean;
}
/**
 * Progress manager backed by a `ProgressIndicator`; tracks whether the
 * indicator is currently started.
 */
class RealProgressManager implements ProgressManager {
  private readonly spinner: ProgressIndicator;
  private running = false;

  /** @param message - Text handed to the underlying indicator. */
  constructor(message: string) {
    this.spinner = new ProgressIndicator(message);
  }

  /** Starts the indicator and marks it active. */
  start(): void {
    this.spinner.start();
    this.running = true;
  }

  /** Stops the indicator and marks it inactive. */
  stop(): void {
    this.spinner.stop();
    this.running = false;
  }

  /** Finishes the indicator with `message` and marks it inactive. */
  finish(message: string): void {
    this.spinner.finish(message);
    this.running = false;
  }

  /** Returns true between a `start()` and the next `stop()` or `finish()`. */
  isActive(): boolean {
    return this.running;
  }
}
/** No-op progress manager (Null Object): every call is safe and does nothing. */
class NullProgressManager implements ProgressManager {
  start(): void {
    // intentionally empty
  }

  stop(): void {
    // intentionally empty
  }

  finish(_message: string): void {
    // intentionally empty
  }

  /** Always inactive, since nothing is ever started. */
  isActive(): boolean {
    return false;
  }
}
/**
 * Creates the progress manager for an agent run.
 *
 * @param context - Description and output mode of the run.
 * @param disableLoader - When true, no indicator is shown.
 * @returns A real manager labelled `Running <agentType>...` when clean output
 *   is on and the loader is not disabled; otherwise a no-op manager.
 */
export function createProgressManager(
  context: ProgressContext,
  disableLoader: boolean
): ProgressManager {
  const loaderEnabled = context.useCleanOutput && !disableLoader;
  if (loaderEnabled) {
    const agentType = extractAgentType(context.description);
    return new RealProgressManager(`Running ${agentType}...`);
  }
  return new NullProgressManager();
}

134
src/ai/types.ts Normal file
View File

@@ -0,0 +1,134 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
// Type definitions for Claude executor message processing pipeline
/** Output-mode flags and identifiers derived from an agent description. */
export interface ExecutionContext {
// Selects compact, prefixed console output.
isParallelExecution: boolean;
// Selects the reduced ("clean") console output.
useCleanOutput: boolean;
// Agent type label used in clean-output messages.
agentType: string;
// Key under which the agent's cost is recorded.
agentKey: string;
}
/**
 * Mutable bookkeeping for a message-processing loop.
 * NOTE(review): no usage is visible in this part of the codebase; field
 * meanings below are read from the names only — confirm against the consumer.
 */
export interface ProcessingState {
turnCount: number;
result: string | null;
apiErrorDetected: boolean;
totalCost: number;
partialCost: number;
// presumably a timestamp of the last heartbeat — TODO confirm unit
lastHeartbeat: number;
}
/**
 * Summary of a finished message-processing run.
 * NOTE(review): no usage is visible in this part of the codebase — confirm against the consumer.
 */
export interface ProcessingResult {
result: string | null;
turnCount: number;
apiErrorDetected: boolean;
totalCost: number;
}
/** Processed form of an assistant message, ready for display and audit logging. */
export interface AssistantResult {
// Raw extracted message text.
content: string;
// Text after JSON tool calls have been filtered out, used for display.
cleanedContent: string;
// True when the raw text contained an API error marker.
apiErrorDetected: boolean;
// Present only when the error is fatal and the caller should throw it.
shouldThrow?: Error;
// Payload for the audit log entry of this turn.
logData: {
turn: number;
content: string;
timestamp: string;
};
}
/** Normalized data taken from the final `result` message of a query. */
export interface ResultData {
// Result text, or null when the message carried none.
result: string | null;
// Cost taken from the message's `total_cost_usd`; 0 when absent.
cost: number;
// Duration in milliseconds; 0 when absent.
duration_ms: number;
// Result subtype (e.g. 'error_max_turns'); omitted when the message had none.
subtype?: string;
// Number of permission denials reported by the message.
permissionDenials: number;
}
/** Processed form of a tool-use message. */
export interface ToolUseData {
// Name of the invoked tool.
toolName: string;
// Tool input; an empty object when the message carried none.
parameters: Record<string, unknown>;
// Time at which the message was processed.
timestamp: string;
}
/** Processed form of a tool-result message. */
export interface ToolResultData {
// Original, untruncated content (used for logging).
content: unknown;
// String form of the content, truncated for console display.
displayContent: string;
// Time at which the message was processed.
timestamp: string;
}
/** One element of an assistant message's array-form content. */
export interface ContentBlock {
type?: string;
// Text payload; blocks without it are serialized as JSON when content is flattened.
text?: string;
}
/** SDK stream message carrying assistant output. */
export interface AssistantMessage {
type: 'assistant';
message: {
// Either a list of content blocks or a plain string.
content: ContentBlock[] | string;
};
}
/** Final SDK stream message of a query, carrying the result plus cost and duration. */
export interface ResultMessage {
type: 'result';
result?: string;
total_cost_usd?: number;
duration_ms?: number;
// e.g. 'error_max_turns' or 'error_during_execution'.
subtype?: string;
// Only the number of entries is used by the result handler.
permission_denials?: unknown[];
}
/** SDK stream message announcing a tool invocation. */
export interface ToolUseMessage {
type: 'tool_use';
// Name of the tool being invoked.
name: string;
// Tool parameters; may be absent.
input?: Record<string, unknown>;
}
/** SDK stream message carrying the outcome of a tool invocation. */
export interface ToolResultMessage {
type: 'tool_result';
// Arbitrary result payload; may be a string, a structured value, or absent.
content?: unknown;
}
/** Outcome of scanning assistant text for API error markers. */
export interface ApiErrorDetection {
// True when any known error marker was found.
detected: boolean;
// Set only for fatal errors that the caller should throw.
shouldThrow?: Error;
}
/** Union of the SDK stream message shapes modelled in this file, discriminated by `type`. */
export type SdkMessage =
| AssistantMessage
| ResultMessage
| ToolUseMessage
| ToolResultMessage
| SystemInitMessage
| UserMessage;
/** System message with subtype `init`, describing the session configuration. */
export interface SystemInitMessage {
type: 'system';
subtype: 'init';
model?: string;
permissionMode?: string;
// MCP servers with their reported status.
mcp_servers?: Array<{ name: string; status: string }>;
}
/** SDK stream message of type `user`; no further fields are modelled here. */
export interface UserMessage {
type: 'user';
}
/**
 * Dispatch outcome for message processing, discriminated by `action`.
 * NOTE(review): message-handlers.ts defines its own result type discriminated
 * by `type` ('continue' | 'complete' | 'throw'); confirm whether this
 * `action`-based variant is still used anywhere.
 */
export type MessageDispatchResult =
| { action: 'continue' }
| { action: 'break'; result: string | null; cost: number }
| { action: 'throw'; error: Error };
/**
 * Context bundle for dispatching a message.
 * NOTE(review): message-handlers.ts passes its own dependency object to the
 * dispatcher; confirm whether this interface is still used anywhere.
 */
export interface MessageDispatchContext {
turnCount: number;
execContext: ExecutionContext;
description: string;
// Color function applied to output text.
colorFn: (text: string) => string;
useCleanOutput: boolean;
}

View File

@@ -13,7 +13,8 @@
import { AgentLogger } from './logger.js';
import { MetricsTracker } from './metrics-tracker.js';
import { initializeAuditStructure, formatTimestamp, type SessionMetadata } from './utils.js';
import { initializeAuditStructure, type SessionMetadata } from './utils.js';
import { formatTimestamp } from '../utils/formatting.js';
import { SessionMutex } from '../utils/concurrency.js';
// Global mutex instance
@@ -145,7 +146,7 @@ export class AuditSession {
// Mutex-protected update to session.json
const unlock = await sessionMutex.lock(this.sessionId);
try {
// Reload metrics (in case of parallel updates)
// Reload inside mutex to prevent lost updates during parallel exploitation phase
await this.metricsTracker.reload();
// Update metrics

View File

@@ -15,10 +15,10 @@ import fs from 'fs';
import {
generateLogPath,
generatePromptPath,
atomicWrite,
formatTimestamp,
type SessionMetadata,
} from './utils.js';
import { atomicWrite } from '../utils/file-io.js';
import { formatTimestamp } from '../utils/formatting.js';
interface LogEvent {
type: string;
@@ -96,22 +96,13 @@ export class AgentLogger {
return;
}
// Write and flush immediately (crash-safe)
const needsDrain = !this.stream.write(text, 'utf8', (error) => {
if (error) {
reject(error);
}
if (error) reject(error);
});
if (needsDrain) {
// Buffer is full, wait for drain
const drainHandler = (): void => {
this.stream!.removeListener('drain', drainHandler);
resolve();
};
this.stream.once('drain', drainHandler);
this.stream.once('drain', resolve);
} else {
// Buffer has space, resolve immediately
resolve();
}
});

View File

@@ -13,13 +13,12 @@
import {
generateSessionJsonPath,
atomicWrite,
readJson,
fileExists,
formatTimestamp,
calculatePercentage,
type SessionMetadata,
} from './utils.js';
import { atomicWrite, readJson, fileExists } from '../utils/file-io.js';
import { formatTimestamp, calculatePercentage } from '../utils/formatting.js';
import { AGENT_PHASE_MAP, type PhaseName } from '../session-manager.js';
import type { AgentName } from '../types/index.js';
interface AttemptData {
attempt_number: number;
@@ -152,16 +151,14 @@ export class MetricsTracker {
}
// Initialize agent metrics if not exists
if (!this.data.metrics.agents[agentName]) {
this.data.metrics.agents[agentName] = {
status: 'in-progress',
attempts: [],
final_duration_ms: 0,
total_cost_usd: 0,
};
}
const agent = this.data.metrics.agents[agentName]!;
const existingAgent = this.data.metrics.agents[agentName];
const agent = existingAgent ?? {
status: 'in-progress' as const,
attempts: [],
final_duration_ms: 0,
total_cost_usd: 0,
};
this.data.metrics.agents[agentName] = agent;
// Add attempt to array
const attempt: AttemptData = {
@@ -255,36 +252,19 @@ export class MetricsTracker {
private calculatePhaseMetrics(
successfulAgents: Array<[string, AgentMetrics]>
): Record<string, PhaseMetrics> {
const phases: Record<string, AgentMetrics[]> = {
const phases: Record<PhaseName, AgentMetrics[]> = {
'pre-recon': [],
recon: [],
'recon': [],
'vulnerability-analysis': [],
exploitation: [],
reporting: [],
'exploitation': [],
'reporting': [],
};
// Map agents to phases
const agentPhaseMap: Record<string, string> = {
'pre-recon': 'pre-recon',
recon: 'recon',
'injection-vuln': 'vulnerability-analysis',
'xss-vuln': 'vulnerability-analysis',
'auth-vuln': 'vulnerability-analysis',
'authz-vuln': 'vulnerability-analysis',
'ssrf-vuln': 'vulnerability-analysis',
'injection-exploit': 'exploitation',
'xss-exploit': 'exploitation',
'auth-exploit': 'exploitation',
'authz-exploit': 'exploitation',
'ssrf-exploit': 'exploitation',
report: 'reporting',
};
// Group agents by phase
// Group agents by phase using imported AGENT_PHASE_MAP
for (const [agentName, agentData] of successfulAgents) {
const phase = agentPhaseMap[agentName];
if (phase && phases[phase]) {
phases[phase]!.push(agentData);
const phase = AGENT_PHASE_MAP[agentName as AgentName];
if (phase) {
phases[phase].push(agentData);
}
}
@@ -296,7 +276,6 @@ export class MetricsTracker {
if (agentList.length === 0) continue;
const phaseDuration = agentList.reduce((sum, agent) => sum + agent.final_duration_ms, 0);
const phaseCost = agentList.reduce((sum, agent) => sum + agent.total_cost_usd, 0);
phaseMetrics[phaseName] = {

View File

@@ -37,11 +37,11 @@ export class PentestError extends Error {
}
// Centralized error logging function
export const logError = async (
export async function logError(
error: Error & { type?: PentestErrorType; retryable?: boolean; context?: PentestErrorContext },
contextMsg: string,
sourceDir: string | null = null
): Promise<LogEntry> => {
): Promise<LogEntry> {
const timestamp = new Date().toISOString();
const logEntry: LogEntry = {
timestamp,
@@ -80,13 +80,13 @@ export const logError = async (
}
return logEntry;
};
}
// Handle tool execution errors
export const handleToolError = (
export function handleToolError(
toolName: string,
error: Error & { code?: string }
): ToolErrorResult => {
): ToolErrorResult {
const isRetryable =
error.code === 'ECONNRESET' ||
error.code === 'ETIMEDOUT' ||
@@ -105,13 +105,13 @@ export const handleToolError = (
{ toolName, originalError: error.message, errorCode: error.code }
),
};
};
}
// Handle prompt loading errors
export const handlePromptError = (
export function handlePromptError(
promptName: string,
error: Error
): PromptErrorResult => {
): PromptErrorResult {
return {
success: false,
error: new PentestError(
@@ -121,78 +121,63 @@ export const handlePromptError = (
{ promptName, originalError: error.message }
),
};
};
}
// Check if an error should trigger a retry for Claude agents
export const isRetryableError = (error: Error): boolean => {
// Patterns that indicate retryable errors
const RETRYABLE_PATTERNS = [
// Network and connection errors
'network',
'connection',
'timeout',
'econnreset',
'enotfound',
'econnrefused',
// Rate limiting
'rate limit',
'429',
'too many requests',
// Server errors
'server error',
'5xx',
'internal server error',
'service unavailable',
'bad gateway',
// Claude API errors
'mcp server',
'model unavailable',
'service temporarily unavailable',
'api error',
'terminated',
// Max turns
'max turns',
'maximum turns',
];
// Patterns that indicate non-retryable errors (checked before default)
const NON_RETRYABLE_PATTERNS = [
'authentication',
'invalid prompt',
'out of memory',
'permission denied',
'session limit reached',
'invalid api key',
];
// Conservative retry classification - unknown errors don't retry (fail-safe default)
export function isRetryableError(error: Error): boolean {
const message = error.message.toLowerCase();
// Network and connection errors - always retryable
if (
message.includes('network') ||
message.includes('connection') ||
message.includes('timeout') ||
message.includes('econnreset') ||
message.includes('enotfound') ||
message.includes('econnrefused')
) {
return true;
}
// Rate limiting - retryable with longer backoff
if (
message.includes('rate limit') ||
message.includes('429') ||
message.includes('too many requests')
) {
return true;
}
// Server errors - retryable
if (
message.includes('server error') ||
message.includes('5xx') ||
message.includes('internal server error') ||
message.includes('service unavailable') ||
message.includes('bad gateway')
) {
return true;
}
// Claude API specific errors - retryable
if (
message.includes('mcp server') ||
message.includes('model unavailable') ||
message.includes('service temporarily unavailable') ||
message.includes('api error') ||
message.includes('terminated')
) {
return true;
}
// Max turns without completion - retryable once
if (message.includes('max turns') || message.includes('maximum turns')) {
return true;
}
// Non-retryable errors
if (
message.includes('authentication') ||
message.includes('invalid prompt') ||
message.includes('out of memory') ||
message.includes('permission denied') ||
message.includes('session limit reached') ||
message.includes('invalid api key')
) {
// Check for explicit non-retryable patterns first
if (NON_RETRYABLE_PATTERNS.some((pattern) => message.includes(pattern))) {
return false;
}
// Default to non-retryable for unknown errors
return false;
};
// Check for retryable patterns
return RETRYABLE_PATTERNS.some((pattern) => message.includes(pattern));
}
// Get retry delay based on error type and attempt number
export const getRetryDelay = (error: Error, attempt: number): number => {
// Rate limit errors get longer base delay (30s) vs standard exponential backoff (2s)
export function getRetryDelay(error: Error, attempt: number): number {
const message = error.message.toLowerCase();
// Rate limiting gets longer delays
@@ -204,4 +189,4 @@ export const getRetryDelay = (error: Error, attempt: number): number => {
const baseDelay = Math.pow(2, attempt) * 1000; // 2s, 4s, 8s
const jitter = Math.random() * 1000; // 0-1s random
return Math.min(baseDelay + jitter, 30000); // Max 30s
};
}

View File

@@ -7,7 +7,7 @@
import { $, fs, path } from 'zx';
import chalk from 'chalk';
import { Timer } from '../utils/metrics.js';
import { formatDuration } from '../audit/utils.js';
import { formatDuration } from '../utils/formatting.js';
import { handleToolError, PentestError } from '../error-handling.js';
import { AGENTS } from '../session-manager.js';
import { runClaudePromptWithRetry } from '../ai/claude-executor.js';
@@ -40,11 +40,17 @@ interface PromptVariables {
repoPath: string;
}
// Discriminated union for Wave1 tool results - clearer than loose union types
type Wave1ToolResult =
| { kind: 'scan'; result: TerminalScanResult }
| { kind: 'skipped'; message: string }
| { kind: 'agent'; result: AgentResult };
interface Wave1Results {
nmap: TerminalScanResult | string | AgentResult;
subfinder: TerminalScanResult | string | AgentResult;
whatweb: TerminalScanResult | string | AgentResult;
naabu?: TerminalScanResult | string | AgentResult;
nmap: Wave1ToolResult;
subfinder: Wave1ToolResult;
whatweb: Wave1ToolResult;
naabu?: Wave1ToolResult;
codeAnalysis: AgentResult;
}
@@ -57,7 +63,7 @@ interface PreReconResult {
report: string;
}
// Pure function: Run terminal scanning tools
// Runs external security tools (nmap, whatweb, etc). Schemathesis requires schemas from code analysis.
async function runTerminalScan(tool: ToolName, target: string, sourceDir: string | null = null): Promise<TerminalScanResult> {
const timer = new Timer(`command-${tool}`);
try {
@@ -89,7 +95,7 @@ async function runTerminalScan(tool: ToolName, target: string, sourceDir: string
return { tool: 'whatweb', output: result.stdout, status: 'success', duration: whatwebDuration };
}
case 'schemathesis': {
// Only run if API schemas found
// Schemathesis depends on code analysis output - skip if no schemas found
const schemasDir = path.join(sourceDir || '.', 'outputs', 'schemas');
if (await fs.pathExists(schemasDir)) {
const schemaFiles = await fs.readdir(schemasDir) as string[];
@@ -146,6 +152,8 @@ async function runPreReconWave1(
const operations: Promise<TerminalScanResult | AgentResult>[] = [];
const skippedResult = (message: string): Wave1ToolResult => ({ kind: 'skipped', message });
// Skip external commands in pipeline testing mode
if (pipelineTestingMode) {
console.log(chalk.gray(' ⏭️ Skipping external tools (pipeline testing mode)'));
@@ -163,9 +171,9 @@ async function runPreReconWave1(
);
const [codeAnalysis] = await Promise.all(operations);
return {
nmap: 'Skipped (pipeline testing mode)',
subfinder: 'Skipped (pipeline testing mode)',
whatweb: 'Skipped (pipeline testing mode)',
nmap: skippedResult('Skipped (pipeline testing mode)'),
subfinder: skippedResult('Skipped (pipeline testing mode)'),
whatweb: skippedResult('Skipped (pipeline testing mode)'),
codeAnalysis: codeAnalysis as AgentResult
};
} else {
@@ -192,9 +200,9 @@ async function runPreReconWave1(
const [nmap, subfinder, whatweb, codeAnalysis] = await Promise.all(operations);
return {
nmap: nmap as TerminalScanResult,
subfinder: subfinder as TerminalScanResult,
whatweb: whatweb as TerminalScanResult,
nmap: { kind: 'scan', result: nmap as TerminalScanResult },
subfinder: { kind: 'scan', result: subfinder as TerminalScanResult },
whatweb: { kind: 'scan', result: whatweb as TerminalScanResult },
codeAnalysis: codeAnalysis as AgentResult
};
}
@@ -250,17 +258,21 @@ async function runPreReconWave2(
return response;
}
// Helper type for stitching results
interface StitchableResult {
status?: string;
output?: string;
tool?: string;
/**
 * Reduces a Wave1 tool result to the status/output pair used in the report.
 *
 * @param r - Tagged tool result, or undefined when the tool produced nothing.
 * @returns The status label and output text to print for that tool.
 */
function extractResult(r: Wave1ToolResult | undefined): { status: string; output: string } {
  // Missing entry: report the tool as skipped with placeholder output.
  if (!r) {
    return { status: 'Skipped', output: 'No output' };
  }

  if (r.kind === 'scan') {
    // Empty or absent fields fall back to the same placeholders as a missing entry.
    const scan = r.result;
    return { status: scan.status || 'Skipped', output: scan.output || 'No output' };
  }

  if (r.kind === 'skipped') {
    // The skip reason doubles as the output text.
    return { status: 'Skipped', output: r.message };
  }

  // Remaining variant is 'agent': the report only carries a pointer to the agent's own output.
  const agentStatus = r.result.success ? 'success' : 'error';
  return { status: agentStatus, output: 'See agent output' };
}
// Pure function: Stitch together pre-recon outputs and save to file
async function stitchPreReconOutputs(outputs: (StitchableResult | string | undefined)[], sourceDir: string): Promise<string> {
const [nmap, subfinder, whatweb, naabu, codeAnalysis, ...additionalScans] = outputs;
// Combines tool outputs into single deliverable. Falls back to reference if file missing.
async function stitchPreReconOutputs(wave1: Wave1Results, additionalScans: TerminalScanResult[], sourceDir: string): Promise<string> {
// Try to read the code analysis deliverable file
let codeAnalysisContent = 'No analysis available';
try {
@@ -269,62 +281,45 @@ async function stitchPreReconOutputs(outputs: (StitchableResult | string | undef
} catch (error) {
const err = error as Error;
console.log(chalk.yellow(`⚠️ Could not read code analysis deliverable: ${err.message}`));
// Fallback message if file doesn't exist
codeAnalysisContent = 'Analysis located in deliverables/code_analysis_deliverable.md';
}
// Build additional scans section
let additionalSection = '';
if (additionalScans && additionalScans.length > 0) {
if (additionalScans.length > 0) {
additionalSection = '\n## Authenticated Scans\n';
additionalScans.forEach(scan => {
const s = scan as StitchableResult;
if (s && s.tool) {
additionalSection += `
### ${s.tool.toUpperCase()}
Status: ${s.status}
${s.output}
for (const scan of additionalScans) {
additionalSection += `
### ${scan.tool.toUpperCase()}
Status: ${scan.status}
${scan.output}
`;
}
});
}
}
const nmapResult = nmap as StitchableResult | string | undefined;
const subfinderResult = subfinder as StitchableResult | string | undefined;
const whatwebResult = whatweb as StitchableResult | string | undefined;
const naabuResult = naabu as StitchableResult | string | undefined;
const getStatus = (r: StitchableResult | string | undefined): string => {
if (!r) return 'Skipped';
if (typeof r === 'string') return 'Skipped';
return r.status || 'Skipped';
};
const getOutput = (r: StitchableResult | string | undefined): string => {
if (!r) return 'No output';
if (typeof r === 'string') return r;
return r.output || 'No output';
};
const nmap = extractResult(wave1.nmap);
const subfinder = extractResult(wave1.subfinder);
const whatweb = extractResult(wave1.whatweb);
const naabu = extractResult(wave1.naabu);
const report = `
# Pre-Reconnaissance Report
## Port Discovery (naabu)
Status: ${getStatus(naabuResult)}
${getOutput(naabuResult)}
Status: ${naabu.status}
${naabu.output}
## Network Scanning (nmap)
Status: ${getStatus(nmapResult)}
${getOutput(nmapResult)}
Status: ${nmap.status}
${nmap.output}
## Subdomain Discovery (subfinder)
Status: ${getStatus(subfinderResult)}
${getOutput(subfinderResult)}
Status: ${subfinder.status}
${subfinder.output}
## Technology Detection (whatweb)
Status: ${getStatus(whatwebResult)}
${getOutput(whatwebResult)}
Status: ${whatweb.status}
${whatweb.output}
## Code Analysis
${codeAnalysisContent}
${additionalSection}
@@ -375,16 +370,8 @@ export async function executePreReconPhase(
console.log(chalk.green(' ✅ Wave 2 operations completed'));
console.log(chalk.blue('📝 Stitching pre-recon outputs...'));
// Combine wave 1 and wave 2 results for stitching
const allResults: (StitchableResult | string | undefined)[] = [
wave1Results.nmap as StitchableResult | string,
wave1Results.subfinder as StitchableResult | string,
wave1Results.whatweb as StitchableResult | string,
wave1Results.naabu as StitchableResult | string | undefined,
wave1Results.codeAnalysis as unknown as StitchableResult,
...(wave2Results.schemathesis ? [wave2Results.schemathesis as StitchableResult] : [])
];
const preReconReport = await stitchPreReconOutputs(allResults, sourceDir);
const additionalScans = wave2Results.schemathesis ? [wave2Results.schemathesis] : [];
const preReconReport = await stitchPreReconOutputs(wave1Results, additionalScans, sourceDir);
const duration = timer.stop();
console.log(chalk.green(`✅ Pre-reconnaissance complete in ${formatDuration(duration)}`));

View File

@@ -6,6 +6,7 @@
import { fs, path } from 'zx';
import { PentestError } from './error-handling.js';
import { asyncPipe } from './utils/functional.js';
export type VulnType = 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz';
@@ -16,9 +17,11 @@ interface VulnTypeConfigItem {
type VulnTypeConfig = Record<VulnType, VulnTypeConfigItem>;
type ErrorMessageResolver = string | ((existence: FileExistence) => string);
interface ValidationRule {
predicate: (existence: FileExistence) => boolean;
errorMessage: string;
errorMessage: ErrorMessageResolver;
retryable: boolean;
}
@@ -94,40 +97,36 @@ const VULN_TYPE_CONFIG: VulnTypeConfig = Object.freeze({
}),
}) as VulnTypeConfig;
// Functional composition utilities - async pipe for promise chain
type PipeFunction = (x: any) => any | Promise<any>;
const pipe =
(...fns: PipeFunction[]) =>
(x: any): Promise<any> =>
fns.reduce(async (v, f) => f(await v), Promise.resolve(x));
// Pure function to create validation rule
const createValidationRule = (
function createValidationRule(
predicate: (existence: FileExistence) => boolean,
errorMessage: string,
errorMessage: ErrorMessageResolver,
retryable: boolean = true
): ValidationRule => Object.freeze({ predicate, errorMessage, retryable });
): ValidationRule {
return Object.freeze({ predicate, errorMessage, retryable });
}
// Validation rules for file existence (following QUEUE_VALIDATION_FLOW.md)
// Symmetric deliverable rules: queue and deliverable must exist together (prevents partial analysis from triggering exploitation)
const fileExistenceRules: readonly ValidationRule[] = Object.freeze([
// Rule 1: Neither deliverable nor queue exists
createValidationRule(
({ deliverableExists, queueExists }) => deliverableExists || queueExists,
'Analysis failed: Neither deliverable nor queue file exists. Analysis agent must create both files.'
),
// Rule 2: Queue doesn't exist but deliverable exists
createValidationRule(
({ deliverableExists, queueExists }) => !(!queueExists && deliverableExists),
'Analysis incomplete: Deliverable exists but queue file missing. Analysis agent must create both files.'
),
// Rule 3: Queue exists but deliverable doesn't exist
createValidationRule(
({ deliverableExists, queueExists }) => !(queueExists && !deliverableExists),
'Analysis incomplete: Queue exists but deliverable file missing. Analysis agent must create both files.'
({ deliverableExists, queueExists }) => deliverableExists && queueExists,
getExistenceErrorMessage
),
]);
/**
 * Picks the validation error message matching which of the two files is missing.
 *
 * @param existence - Flags saying whether the deliverable and the queue file exist.
 * @returns The message for "neither exists", "queue missing" or "deliverable missing".
 */
function getExistenceErrorMessage(existence: FileExistence): string {
  // A present queue leaves only one way to fail: the deliverable is the missing file.
  if (existence.queueExists) {
    return 'Analysis incomplete: Queue exists but deliverable file missing. Analysis agent must create both files.';
  }

  // Queue is missing; the wording depends on whether the deliverable is missing as well.
  return existence.deliverableExists
    ? 'Analysis incomplete: Deliverable exists but queue file missing. Analysis agent must create both files.'
    : 'Analysis failed: Neither deliverable nor queue file exists. Analysis agent must create both files.';
}
// Pure function to create file paths
const createPaths = (
vulnType: VulnType,
@@ -170,7 +169,7 @@ const checkFileExistence = async (
});
};
// Pure function to validate existence rules
// Validates deliverable/queue symmetry - both files must exist; a missing file (either or both) is a validation error
const validateExistenceRules = (
pathsWithExistence: PathsWithExistence | PathsWithError
): PathsWithExistence | PathsWithError => {
@@ -182,9 +181,14 @@ const validateExistenceRules = (
const failedRule = fileExistenceRules.find((rule) => !rule.predicate(existence));
if (failedRule) {
const message =
typeof failedRule.errorMessage === 'function'
? failedRule.errorMessage(existence)
: failedRule.errorMessage;
return {
error: new PentestError(
`${failedRule.errorMessage} (${vulnType})`,
`${message} (${vulnType})`,
'validation',
failedRule.retryable,
{
@@ -224,7 +228,7 @@ const validateQueueStructure = (content: string): QueueValidationResult => {
}
};
// Pure function to read and validate queue content
// Queue parse failures are retryable - agent can fix malformed JSON on retry
const validateQueueContent = async (
pathsWithExistence: PathsWithExistence | PathsWithError
): Promise<PathsWithQueue | PathsWithError> => {
@@ -273,7 +277,7 @@ const validateQueueContent = async (
}
};
// Pure function to determine exploitation decision
// Final decision: skip if queue says no vulns, proceed if vulns found, error otherwise
const determineExploitationDecision = (
validatedData: PathsWithQueue | PathsWithError
): ExploitationDecision => {
@@ -294,17 +298,18 @@ const determineExploitationDecision = (
};
// Main functional validation pipeline
export const validateQueueAndDeliverable = async (
export async function validateQueueAndDeliverable(
vulnType: VulnType,
sourceDir: string
): Promise<ExploitationDecision> =>
(await pipe(
() => createPaths(vulnType, sourceDir),
): Promise<ExploitationDecision> {
return asyncPipe<ExploitationDecision>(
createPaths(vulnType, sourceDir),
checkFileExistence,
validateExistenceRules,
validateQueueContent,
determineExploitationDecision
)(() => createPaths(vulnType, sourceDir))) as ExploitationDecision;
);
}
// Pure function to safely validate (returns result instead of throwing)
export const safeValidateQueueAndDeliverable = async (

View File

@@ -106,6 +106,26 @@ export const getParallelGroups = (): Readonly<{ vuln: AgentName[]; exploit: Agen
exploit: ['injection-exploit', 'xss-exploit', 'auth-exploit', 'ssrf-exploit', 'authz-exploit']
});
// Phase names for metrics aggregation
export type PhaseName = 'pre-recon' | 'recon' | 'vulnerability-analysis' | 'exploitation' | 'reporting';
// Map agents to their corresponding phases (single source of truth).
// Typed as Record<AgentName, PhaseName>, so every agent name needs an entry here.
// The object is frozen, so entries cannot be added or reassigned at runtime.
export const AGENT_PHASE_MAP: Readonly<Record<AgentName, PhaseName>> = Object.freeze({
'pre-recon': 'pre-recon',
'recon': 'recon',
// All five *-vuln agents roll up into the vulnerability-analysis phase
'injection-vuln': 'vulnerability-analysis',
'xss-vuln': 'vulnerability-analysis',
'auth-vuln': 'vulnerability-analysis',
'authz-vuln': 'vulnerability-analysis',
'ssrf-vuln': 'vulnerability-analysis',
// All five *-exploit agents roll up into the exploitation phase
'injection-exploit': 'exploitation',
'xss-exploit': 'exploitation',
'auth-exploit': 'exploitation',
'authz-exploit': 'exploitation',
'ssrf-exploit': 'exploitation',
'report': 'reporting',
});
// Generate a session-based log folder path (used by claude-executor.ts)
export const generateSessionLogPath = (webUrl: string, sessionId: string): string => {
const hostname = new URL(webUrl).hostname.replace(/[^a-zA-Z0-9-]/g, '-');

View File

@@ -17,6 +17,7 @@ import { checkToolAvailability, handleMissingTools } from './tool-checker.js';
// Session
import { AGENTS, getParallelGroups } from './session-manager.js';
import { getPromptNameForAgent } from './types/agents.js';
import type { AgentName, PromptName } from './types/index.js';
// Setup and Deliverables
@@ -32,7 +33,8 @@ import { assembleFinalReport } from './phases/reporting.js';
// Utils
import { timingResults, displayTimingSummary, Timer } from './utils/metrics.js';
import { formatDuration, generateAuditPath } from './audit/utils.js';
import { formatDuration } from './utils/formatting.js';
import { generateAuditPath } from './audit/utils.js';
import type { SessionMetadata } from './audit/utils.js';
import { AuditSession } from './audit/audit-session.js';
@@ -86,6 +88,7 @@ async function saveSessions(store: SessionStore): Promise<void> {
await fs.writeJson(STORE_PATH, store, { spaces: 2 });
}
// Session prevents concurrent runs on same repo - different repos can run in parallel
async function createSession(webUrl: string, repoPath: string): Promise<Session> {
const store = await loadSessions();
@@ -155,32 +158,26 @@ interface ParallelAgentResult {
error?: string | undefined;
}
// Vulnerability categories; values are obtained in this file by stripping the
// '-vuln' suffix from an agent name before calling safeValidateQueueAndDeliverable.
// NOTE(review): looks like a duplicate of the VulnType exported by the queue
// validation module - confirm and consider importing that type instead.
type VulnType = 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz';
// Per-phase settings for runParallelAgents.
interface ParallelAgentConfig {
// Key into getParallelGroups(); 'exploit' also enables the eligibility filter,
// 'vuln' enables the post-run validation logging callback
phaseType: 'vuln' | 'exploit';
// Title printed above the results table
headerText: string;
// Noun phrase used in the "Starting N <label> in parallel..." log line
specialistLabel: string;
}
// Bundle of the arguments that executeAgentWithRetry forwards unchanged to runAgent,
// so parallel runners can pass one object instead of five positional parameters.
interface AgentExecutionContext {
// Also passed to safeValidateQueueAndDeliverable for queue/deliverable validation
sourceDir: string;
variables: PromptVariables;
// null is an accepted value (the type is DistributedConfig | null)
distributedConfig: DistributedConfig | null;
pipelineTestingMode: boolean;
sessionMetadata: SessionMetadata;
}
// Configure zx to disable timeouts (let tools run as long as needed)
$.timeout = 0;
// Helper function to get prompt name from agent name
const getPromptName = (agentName: AgentName): PromptName => {
const mappings: Record<AgentName, PromptName> = {
'pre-recon': 'pre-recon-code',
'recon': 'recon',
'injection-vuln': 'vuln-injection',
'xss-vuln': 'vuln-xss',
'auth-vuln': 'vuln-auth',
'ssrf-vuln': 'vuln-ssrf',
'authz-vuln': 'vuln-authz',
'injection-exploit': 'exploit-injection',
'xss-exploit': 'exploit-xss',
'auth-exploit': 'exploit-auth',
'ssrf-exploit': 'exploit-ssrf',
'authz-exploit': 'exploit-authz',
'report': 'report-executive'
};
return mappings[agentName] || agentName as PromptName;
};
// Get color function for agent
const getAgentColor = (agentName: AgentName): ChalkInstance => {
function getAgentColor(agentName: AgentName): ChalkInstance {
const colorMap: Partial<Record<AgentName, ChalkInstance>> = {
'injection-vuln': chalk.red,
'injection-exploit': chalk.red,
@@ -194,11 +191,9 @@ const getAgentColor = (agentName: AgentName): ChalkInstance => {
'authz-exploit': chalk.green
};
return colorMap[agentName] || chalk.cyan;
};
}
/**
* Consolidate deliverables from target repo into the session folder
*/
// Non-fatal copy - failure logs warning but doesn't halt pipeline
async function consolidateOutputs(sourceDir: string, sessionPath: string): Promise<void> {
const srcDeliverables = path.join(sourceDir, 'deliverables');
const destDeliverables = path.join(sessionPath, 'deliverables');
@@ -228,7 +223,7 @@ async function runAgent(
sessionMetadata: SessionMetadata
): Promise<AgentResult> {
const agent = AGENTS[agentName];
const promptName = getPromptName(agentName);
const promptName = getPromptNameForAgent(agentName);
const prompt = await loadPrompt(promptName, variables, distributedConfig, pipelineTestingMode);
return await runClaudePromptWithRetry(
@@ -244,85 +239,68 @@ async function runAgent(
}
/**
* Run vulnerability agents in parallel
* Execute a single agent with retry logic
*/
async function runParallelVuln(
sourceDir: string,
variables: PromptVariables,
distributedConfig: DistributedConfig | null,
pipelineTestingMode: boolean,
sessionMetadata: SessionMetadata
): Promise<ParallelAgentResult[]> {
const { vuln: vulnAgents } = getParallelGroups();
async function executeAgentWithRetry(
agentName: AgentName,
context: AgentExecutionContext,
onSuccess?: (agentName: AgentName) => Promise<void>
): Promise<ParallelAgentResult> {
const { sourceDir, variables, distributedConfig, pipelineTestingMode, sessionMetadata } = context;
const maxAttempts = 3;
let lastError: Error | undefined;
let attempts = 0;
console.log(chalk.cyan(`\nStarting ${vulnAgents.length} vulnerability analysis specialists in parallel...`));
console.log(chalk.gray(' Specialists: ' + vulnAgents.join(', ')));
console.log();
while (attempts < maxAttempts) {
attempts++;
try {
const result = await runAgent(
agentName,
sourceDir,
variables,
distributedConfig,
pipelineTestingMode,
sessionMetadata
);
const startTime = Date.now();
const results = await Promise.allSettled(
vulnAgents.map(async (agentName, index) => {
// Add 2-second stagger to prevent API overwhelm
await new Promise(resolve => setTimeout(resolve, index * 2000));
let lastError: Error | undefined;
let attempts = 0;
const maxAttempts = 3;
while (attempts < maxAttempts) {
attempts++;
try {
const result = await runAgent(
agentName,
sourceDir,
variables,
distributedConfig,
pipelineTestingMode,
sessionMetadata
);
// Validate vulnerability analysis results
const vulnType = agentName.replace('-vuln', '');
try {
const validation = await safeValidateQueueAndDeliverable(vulnType as 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz', sourceDir);
if (validation.success && validation.data) {
console.log(chalk.blue(`${agentName}: ${validation.data.shouldExploit ? `Ready for exploitation (${validation.data.vulnerabilityCount} vulnerabilities)` : 'No vulnerabilities found'}`));
}
} catch {
// Validation failure is non-critical
}
return {
agentName,
success: result.success,
timing: result.duration,
cost: result.cost,
attempts
};
} catch (error) {
lastError = error as Error;
if (attempts < maxAttempts) {
console.log(chalk.yellow(`Warning: ${agentName} failed attempt ${attempts}/${maxAttempts}, retrying...`));
await new Promise(resolve => setTimeout(resolve, 5000));
}
}
if (onSuccess) {
await onSuccess(agentName);
}
return {
agentName,
success: false,
attempts,
error: lastError?.message || 'Unknown error'
success: result.success,
timing: result.duration,
cost: result.cost,
attempts
};
})
);
} catch (error) {
lastError = error as Error;
if (attempts < maxAttempts) {
console.log(chalk.yellow(`Warning: ${agentName} failed attempt ${attempts}/${maxAttempts}, retrying...`));
await new Promise(resolve => setTimeout(resolve, 5000));
}
}
}
const totalDuration = Date.now() - startTime;
return {
agentName,
success: false,
attempts,
error: lastError?.message || 'Unknown error'
};
}
// Process and display results
console.log(chalk.cyan('\nVulnerability Analysis Results'));
/**
* Display results table for parallel agent execution
*/
function displayParallelResults(
results: PromiseSettledResult<ParallelAgentResult>[],
agents: AgentName[],
headerText: string,
totalDuration: number
): ParallelAgentResult[] {
console.log(chalk.cyan(`\n${headerText}`));
console.log(chalk.gray('-'.repeat(80)));
console.log(chalk.bold('Agent Status Attempt Duration Cost'));
console.log(chalk.gray('-'.repeat(80)));
@@ -330,7 +308,7 @@ async function runParallelVuln(
const processedResults: ParallelAgentResult[] = [];
results.forEach((result, index) => {
const agentName = vulnAgents[index]!;
const agentName = agents[index]!;
const agentDisplay = agentName.padEnd(22);
if (result.status === 'fulfilled') {
@@ -371,159 +349,90 @@ async function runParallelVuln(
console.log(chalk.gray('-'.repeat(80)));
const successCount = processedResults.filter(r => r.success).length;
console.log(chalk.cyan(`Summary: ${successCount}/${vulnAgents.length} succeeded in ${formatDuration(totalDuration)}`));
console.log(chalk.cyan(`Summary: ${successCount}/${agents.length} succeeded in ${formatDuration(totalDuration)}`));
return processedResults;
}
/**
* Run exploitation agents in parallel
* Run agents in parallel with retry logic and result display
*/
async function runParallelExploit(
sourceDir: string,
variables: PromptVariables,
distributedConfig: DistributedConfig | null,
pipelineTestingMode: boolean,
sessionMetadata: SessionMetadata
async function runParallelAgents(
context: AgentExecutionContext,
config: ParallelAgentConfig
): Promise<ParallelAgentResult[]> {
const { exploit: exploitAgents, vuln: vulnAgents } = getParallelGroups();
const { sourceDir } = context;
const { phaseType, headerText, specialistLabel } = config;
const parallelGroups = getParallelGroups();
const allAgents = parallelGroups[phaseType];
// Load validation module
const { safeValidateQueueAndDeliverable } = await import('./queue-validation.js');
// For exploit phase, filter to only eligible agents
let agents: AgentName[];
if (phaseType === 'exploit') {
const eligibilityChecks = await Promise.all(
allAgents.map(async (agentName) => {
const vulnAgentName = agentName.replace('-exploit', '-vuln') as AgentName;
const vulnType = vulnAgentName.replace('-vuln', '') as VulnType;
// Check eligibility
const eligibilityChecks = await Promise.all(
exploitAgents.map(async (agentName) => {
const vulnAgentName = agentName.replace('-exploit', '-vuln') as AgentName;
const vulnType = vulnAgentName.replace('-vuln', '') as 'injection' | 'xss' | 'auth' | 'ssrf' | 'authz';
const validation = await safeValidateQueueAndDeliverable(vulnType, sourceDir);
const validation = await safeValidateQueueAndDeliverable(vulnType, sourceDir);
if (!validation.success || !validation.data?.shouldExploit) {
console.log(chalk.gray(`Skipping ${agentName} (no vulnerabilities found in ${vulnAgentName})`));
return { agentName, eligible: false };
}
if (!validation.success || !validation.data?.shouldExploit) {
console.log(chalk.gray(`Skipping ${agentName} (no vulnerabilities found in ${vulnAgentName})`));
return { agentName, eligible: false };
}
console.log(chalk.blue(`${agentName} eligible (${validation.data.vulnerabilityCount} vulnerabilities from ${vulnAgentName})`));
return { agentName, eligible: true };
})
);
console.log(chalk.blue(`${agentName} eligible (${validation.data.vulnerabilityCount} vulnerabilities from ${vulnAgentName})`));
return { agentName, eligible: true };
})
);
agents = eligibilityChecks
.filter(check => check.eligible)
.map(check => check.agentName);
const eligibleAgents = eligibilityChecks
.filter(check => check.eligible)
.map(check => check.agentName);
if (eligibleAgents.length === 0) {
console.log(chalk.gray('No exploitation agents eligible (no vulnerabilities found)'));
return [];
if (agents.length === 0) {
console.log(chalk.gray('No exploitation agents eligible (no vulnerabilities found)'));
return [];
}
} else {
agents = allAgents;
}
console.log(chalk.cyan(`\nStarting ${eligibleAgents.length} exploitation specialists in parallel...`));
console.log(chalk.gray(' Specialists: ' + eligibleAgents.join(', ')));
console.log(chalk.cyan(`\nStarting ${agents.length} ${specialistLabel} in parallel...`));
console.log(chalk.gray(' Specialists: ' + agents.join(', ')));
console.log();
const startTime = Date.now();
const results = await Promise.allSettled(
eligibleAgents.map(async (agentName, index) => {
await new Promise(resolve => setTimeout(resolve, index * 2000));
let lastError: Error | undefined;
let attempts = 0;
const maxAttempts = 3;
while (attempts < maxAttempts) {
attempts++;
// Build onSuccess callback for vuln phase (validation logging)
const onSuccess = phaseType === 'vuln'
? async (agentName: AgentName): Promise<void> => {
const vulnType = agentName.replace('-vuln', '') as VulnType;
try {
const result = await runAgent(
agentName,
sourceDir,
variables,
distributedConfig,
pipelineTestingMode,
sessionMetadata
);
return {
agentName,
success: result.success,
timing: result.duration,
cost: result.cost,
attempts
};
} catch (error) {
lastError = error as Error;
if (attempts < maxAttempts) {
console.log(chalk.yellow(`Warning: ${agentName} failed attempt ${attempts}/${maxAttempts}, retrying...`));
await new Promise(resolve => setTimeout(resolve, 5000));
const validation = await safeValidateQueueAndDeliverable(vulnType, sourceDir);
if (validation.success && validation.data) {
const message = validation.data.shouldExploit
? `Ready for exploitation (${validation.data.vulnerabilityCount} vulnerabilities)`
: 'No vulnerabilities found';
console.log(chalk.blue(`${agentName}: ${message}`));
}
} catch {
// Validation failure is non-critical
}
}
: undefined;
return {
agentName,
success: false,
attempts,
error: lastError?.message || 'Unknown error'
};
const results = await Promise.allSettled(
agents.map(async (agentName, index) => {
// Add 2-second stagger to prevent API overwhelm
await new Promise(resolve => setTimeout(resolve, index * 2000));
return executeAgentWithRetry(agentName, context, onSuccess);
})
);
const totalDuration = Date.now() - startTime;
// Process and display results
console.log(chalk.cyan('\nExploitation Results'));
console.log(chalk.gray('-'.repeat(80)));
console.log(chalk.bold('Agent Status Attempt Duration Cost'));
console.log(chalk.gray('-'.repeat(80)));
const processedResults: ParallelAgentResult[] = [];
results.forEach((result, index) => {
const agentName = eligibleAgents[index]!;
const agentDisplay = agentName.padEnd(22);
if (result.status === 'fulfilled') {
const data = result.value;
processedResults.push(data);
if (data.success) {
const duration = formatDuration(data.timing || 0);
const cost = `$${(data.cost || 0).toFixed(4)}`;
console.log(
`${chalk.green(agentDisplay)} ${chalk.green('Success')} ` +
`${data.attempts}/3 ${duration.padEnd(11)} ${cost}`
);
} else {
console.log(
`${chalk.red(agentDisplay)} ${chalk.red('Failed ')} ` +
`${data.attempts}/3 - -`
);
if (data.error) {
console.log(chalk.gray(` Error: ${data.error.substring(0, 60)}...`));
}
}
} else {
processedResults.push({
agentName,
success: false,
attempts: 3,
error: String(result.reason)
});
console.log(
`${chalk.red(agentDisplay)} ${chalk.red('Failed ')} ` +
`3/3 - -`
);
}
});
console.log(chalk.gray('-'.repeat(80)));
const successCount = processedResults.filter(r => r.success).length;
console.log(chalk.cyan(`Summary: ${successCount}/${eligibleAgents.length} succeeded in ${formatDuration(totalDuration)}`));
return processedResults;
return displayParallelResults(results, agents, headerText, totalDuration);
}
// Setup graceful cleanup on process signals
@@ -677,13 +586,19 @@ async function main(
const vulnTimer = new Timer('phase-3-vulnerability-analysis');
console.log(chalk.red.bold('\n🚨 PHASE 3: VULNERABILITY ANALYSIS'));
const vulnResults = await runParallelVuln(
const executionContext: AgentExecutionContext = {
sourceDir,
variables,
distributedConfig,
pipelineTestingMode,
sessionMetadata
);
};
const vulnResults = await runParallelAgents(executionContext, {
phaseType: 'vuln',
headerText: 'Vulnerability Analysis Results',
specialistLabel: 'vulnerability analysis specialists'
});
const vulnDuration = vulnTimer.stop();
console.log(chalk.green(`✅ Vulnerability analysis phase complete in ${formatDuration(vulnDuration)}`));
@@ -692,13 +607,11 @@ async function main(
const exploitTimer = new Timer('phase-4-exploitation');
console.log(chalk.red.bold('\n💥 PHASE 4: EXPLOITATION'));
const exploitResults = await runParallelExploit(
sourceDir,
variables,
distributedConfig,
pipelineTestingMode,
sessionMetadata
);
const exploitResults = await runParallelAgents(executionContext, {
phaseType: 'exploit',
headerText: 'Exploitation Results',
specialistLabel: 'exploitation specialists'
});
const exploitDuration = exploitTimer.stop();
console.log(chalk.green(`✅ Exploitation phase complete in ${formatDuration(exploitDuration)}`));

View File

@@ -47,10 +47,6 @@ export type PlaywrightAgent =
export type AgentValidator = (sourceDir: string) => Promise<boolean>;
export type AgentValidatorMap = Record<AgentName, AgentValidator>;
export type McpAgentMapping = Record<PromptName, PlaywrightAgent>;
export type AgentStatus =
| 'pending'
| 'in_progress'
@@ -63,3 +59,26 @@ export interface AgentDefinition {
displayName: string;
prerequisites: AgentName[];
}
/**
 * Prompt file name for each agent. Built once at module load and frozen;
 * the Record type requires an entry for every AgentName.
 */
const PROMPT_NAME_BY_AGENT: Readonly<Record<AgentName, PromptName>> = Object.freeze({
  'pre-recon': 'pre-recon-code',
  'recon': 'recon',
  'injection-vuln': 'vuln-injection',
  'xss-vuln': 'vuln-xss',
  'auth-vuln': 'vuln-auth',
  'ssrf-vuln': 'vuln-ssrf',
  'authz-vuln': 'vuln-authz',
  'injection-exploit': 'exploit-injection',
  'xss-exploit': 'exploit-xss',
  'auth-exploit': 'exploit-auth',
  'ssrf-exploit': 'exploit-ssrf',
  'authz-exploit': 'exploit-authz',
  'report': 'report-executive',
});

/**
 * Resolves the prompt file name that belongs to an agent.
 *
 * @param agentName - Agent whose prompt should be loaded.
 * @returns The prompt name registered for that agent.
 */
export function getPromptNameForAgent(agentName: AgentName): PromptName {
  return PROMPT_NAME_BY_AGENT[agentName];
}

View File

@@ -31,13 +31,12 @@ type UnlockFunction = () => void;
* }
* ```
*/
// Promise-based mutex with queue semantics - safe for parallel agents on same session
export class SessionMutex {
// Map of sessionId -> Promise (represents active lock)
private locks: Map<string, Promise<void>> = new Map();
/**
* Acquire lock for a session
*/
// Wait for existing lock, then acquire. Queue ensures FIFO ordering.
async lock(sessionId: string): Promise<UnlockFunction> {
if (this.locks.has(sessionId)) {
// Wait for existing lock to be released

73
src/utils/file-io.ts Normal file
View File

@@ -0,0 +1,73 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
/**
* File I/O Utilities
*
* Core utility functions for file operations including atomic writes,
* directory creation, and JSON file handling.
*/
import fs from 'fs/promises';
/**
 * Ensure a directory exists at `dirPath`, creating missing parents as needed.
 *
 * Safe to call repeatedly and from concurrent callers: an already existing
 * directory is not an error.
 *
 * @param dirPath - Directory to create.
 * @throws The original mkdir error when the directory cannot be created,
 *         including EEXIST when something other than a directory occupies the path.
 */
export async function ensureDirectory(dirPath: string): Promise<void> {
  try {
    await fs.mkdir(dirPath, { recursive: true });
  } catch (error) {
    if ((error as NodeJS.ErrnoException).code !== 'EEXIST') {
      throw error;
    }

    // EEXIST is only harmless if the existing entry really is a directory.
    // With `recursive: true` it typically means a regular file is in the way,
    // and swallowing that would let callers fail later with a confusing error.
    const stats = await fs.stat(dirPath).catch(() => null);
    if (!stats?.isDirectory()) {
      throw error;
    }
  }
}
/**
 * Write `data` to `filePath` via a sibling `<filePath>.tmp` file followed by a rename,
 * so readers never observe a half-written target file.
 *
 * @param filePath - Final destination of the content.
 * @param data - Text written as-is, or an object serialized as 2-space indented JSON.
 * @throws The write or rename error; the temp file is removed on a best-effort basis first.
 */
export async function atomicWrite(filePath: string, data: object | string): Promise<void> {
  const stagingPath = `${filePath}.tmp`;

  let serialized: string;
  if (typeof data === 'string') {
    serialized = data;
  } else {
    serialized = JSON.stringify(data, null, 2);
  }

  try {
    // Stage the full content first, then swap it into place with a single rename
    await fs.writeFile(stagingPath, serialized, 'utf8');
    await fs.rename(stagingPath, filePath);
  } catch (failure) {
    // Best-effort cleanup of the staging file; a cleanup failure must not mask the real error
    await fs.unlink(stagingPath).catch(() => undefined);
    throw failure;
  }
}
/**
 * Load a UTF-8 file and parse it as JSON.
 *
 * @param filePath - File to read.
 * @returns The parsed value, cast to `T` without runtime validation.
 * @throws The read error, or a SyntaxError when the content is not valid JSON.
 */
export async function readJson<T = unknown>(filePath: string): Promise<T> {
  return fs.readFile(filePath, 'utf8').then((text) => JSON.parse(text) as T);
}
/**
 * Report whether `filePath` can be accessed.
 *
 * @param filePath - Path to probe.
 * @returns `true` when `fs.access` succeeds, `false` on any access error.
 */
export async function fileExists(filePath: string): Promise<boolean> {
  return fs.access(filePath).then(
    () => true,
    () => false
  );
}

60
src/utils/formatting.ts Normal file
View File

@@ -0,0 +1,60 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
/**
* Formatting Utilities
*
* Generic formatting functions for durations, timestamps, and percentages.
*/
/**
 * Format a duration in milliseconds as a short human-readable string.
 *
 * - below 1 s: `"<ms>ms"`
 * - below one minute: seconds with one decimal, e.g. `"1.5s"`
 * - otherwise: whole minutes and seconds, e.g. `"2m 5s"`
 *
 * @param ms - Duration in milliseconds.
 * @returns The formatted duration.
 */
export function formatDuration(ms: number): string {
  if (ms < 1000) {
    return `${ms}ms`;
  }

  const seconds = ms / 1000;
  // Stop the one-decimal form at 59.95 s: anything above that rounds up and
  // would print the out-of-range value "60.0s" instead of switching to minutes.
  if (ms < 59950) {
    return `${seconds.toFixed(1)}s`;
  }

  // Durations just under a minute are promoted to exactly "1m 0s";
  // everything else keeps truncating to whole seconds.
  const wholeSeconds = Math.max(60, Math.floor(seconds));
  const minutes = Math.floor(wholeSeconds / 60);
  const remainingSeconds = wholeSeconds % 60;
  return `${minutes}m ${remainingSeconds}s`;
}
/**
 * Render an epoch-millisecond timestamp as an ISO 8601 string (UTC).
 *
 * @param timestamp - Milliseconds since the Unix epoch; defaults to the current time.
 * @returns The `Date.prototype.toISOString()` representation.
 */
export function formatTimestamp(timestamp: number = Date.now()): string {
  const instant = new Date(timestamp);
  return instant.toISOString();
}
/**
 * Express `part` as a percentage of `total`.
 *
 * @returns `(part / total) * 100`, or `0` when `total` is exactly zero
 *          (avoids a division by zero).
 */
export function calculatePercentage(part: number, total: number): number {
  return total === 0 ? 0 : (part / total) * 100;
}
/**
 * Map an agent description to a display label for its agent type.
 *
 * Matching is a case-sensitive substring test. Rules are tried in order and
 * the first match wins; descriptions matching no rule are labelled 'analysis'.
 */
export function extractAgentType(description: string): string {
  const rules: ReadonlyArray<readonly [marker: string, label: string]> = [
    ['Pre-recon', 'pre-reconnaissance'],
    ['Recon', 'reconnaissance'],
    ['Report', 'report generation'],
  ];

  for (const [marker, label] of rules) {
    if (description.includes(marker)) {
      return label;
    }
  }
  return 'analysis';
}

29
src/utils/functional.ts Normal file
View File

@@ -0,0 +1,29 @@
// Copyright (C) 2025 Keygraph, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License version 3
// as published by the Free Software Foundation.
/**
* Functional Programming Utilities
*
* Generic functional composition patterns for async operations.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type PipelineFunction = (x: any) => any | Promise<any>;

/**
 * Thread a value through a sequence of sync or async functions.
 *
 * Each function receives the awaited result of the previous one, starting
 * with `initial`. Functions run strictly one after another, in argument
 * order. With no functions, `initial` is returned unchanged.
 *
 * The final value is cast to `TResult`; no runtime check is performed.
 */
export async function asyncPipe<TResult>(
  initial: unknown,
  ...fns: PipelineFunction[]
): Promise<TResult> {
  let carried: unknown = initial;

  for (const stage of fns) {
    const output = await stage(carried);
    carried = output;
  }

  return carried as TResult;
}

View File

@@ -13,7 +13,57 @@ interface GitOperationResult {
error?: Error;
}
// Global git operations semaphore to prevent index.lock conflicts during parallel execution
/**
 * Get the list of changed entries from `git status --porcelain`.
 *
 * Each returned element is one non-empty porcelain line, kept verbatim. Only
 * trailing whitespace of the whole output is removed before splitting: the
 * first character of a porcelain line is part of its two-character status
 * code and may be a space (e.g. " M file"), so trimming the start would
 * corrupt the first entry.
 *
 * @param sourceDir - Directory in which the git command is run.
 * @param operationDescription - Label passed through to the git command runner.
 * @returns One string per changed entry; empty when there are no changes.
 */
async function getChangedFiles(
  sourceDir: string,
  operationDescription: string
): Promise<string[]> {
  const status = await executeGitCommandWithRetry(
    ['git', 'status', '--porcelain'],
    sourceDir,
    operationDescription
  );

  return status.stdout
    .trimEnd()
    .split('\n')
    .filter((line) => line.length > 0);
}
/**
 * Print a summary of changed files to the console.
 *
 * With no changes, prints `messageWithoutChanges`. Otherwise prints
 * `messageWithChanges` (its first `{count}` placeholder replaced by the number
 * of changes), followed by up to `maxToShow` entries in gray and, if any were
 * left out, a line stating how many more there are.
 *
 * @param color - Chalk style applied to the headline message.
 * @param maxToShow - Maximum number of entries listed individually.
 */
function logChangeSummary(
  changes: string[],
  messageWithChanges: string,
  messageWithoutChanges: string,
  color: typeof chalk.green,
  maxToShow: number = 5
): void {
  const total = changes.length;

  if (!(total > 0)) {
    console.log(color(messageWithoutChanges));
    return;
  }

  const headline = messageWithChanges.replace('{count}', String(total));
  console.log(color(headline));

  for (const entry of changes.slice(0, maxToShow)) {
    console.log(chalk.gray(` ${entry}`));
  }

  if (total > maxToShow) {
    console.log(chalk.gray(` ... and ${total - maxToShow} more files`));
  }
}
/**
 * Wrap a caught value in a failed GitOperationResult.
 *
 * Error instances are passed through untouched; any other value is
 * stringified and wrapped in a new Error.
 */
function toErrorResult(error: unknown): GitOperationResult {
  if (error instanceof Error) {
    return { success: false, error };
  }

  const wrapped = new Error(String(error));
  return { success: false, error: wrapped };
}
// Serializes git operations to prevent index.lock conflicts during parallel agent execution
class GitSemaphore {
private queue: Array<() => void> = [];
private running: boolean = false;
@@ -41,33 +91,38 @@ class GitSemaphore {
const gitSemaphore = new GitSemaphore();
// Execute git commands with retry logic for index.lock conflicts
export const executeGitCommandWithRetry = async (
// Substrings of git error output that are treated as a lock conflict
const GIT_LOCK_ERROR_PATTERNS = [
  'index.lock',
  'unable to lock',
  'Another git process',
  'fatal: Unable to create',
  'fatal: index file',
];

/**
 * Report whether an error message contains any known git lock-conflict
 * pattern. Matching is a case-sensitive substring test.
 */
function isGitLockError(errorMessage: string): boolean {
  for (const pattern of GIT_LOCK_ERROR_PATTERNS) {
    if (errorMessage.includes(pattern)) {
      return true;
    }
  }
  return false;
}
// Retries git commands on lock conflicts with exponential backoff
export async function executeGitCommandWithRetry(
commandArgs: string[],
sourceDir: string,
description: string,
maxRetries: number = 5
): Promise<{ stdout: string; stderr: string }> => {
): Promise<{ stdout: string; stderr: string }> {
await gitSemaphore.acquire();
try {
for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
// For arrays like ['git', 'status', '--porcelain'], execute parts separately
const [cmd, ...args] = commandArgs;
const result = await $`cd ${sourceDir} && ${cmd} ${args}`;
return result;
} catch (error) {
const errMsg = error instanceof Error ? error.message : String(error);
const isLockError =
errMsg.includes('index.lock') ||
errMsg.includes('unable to lock') ||
errMsg.includes('Another git process') ||
errMsg.includes('fatal: Unable to create') ||
errMsg.includes('fatal: index file');
if (isLockError && attempt < maxRetries) {
const delay = Math.pow(2, attempt - 1) * 1000; // Exponential backoff: 1s, 2s, 4s, 8s, 16s
if (isGitLockError(errMsg) && attempt < maxRetries) {
const delay = Math.pow(2, attempt - 1) * 1000;
console.log(
chalk.yellow(
` ⚠️ Git lock conflict during ${description} (attempt ${attempt}/${maxRetries}). Retrying in ${delay}ms...`
@@ -80,84 +135,69 @@ export const executeGitCommandWithRetry = async (
throw error;
}
}
// Should never reach here but TypeScript needs a return
throw new Error(`Git command failed after ${maxRetries} retries`);
} finally {
gitSemaphore.release();
}
};
}
// Pure functions for Git workspace management
const cleanWorkspace = async (
// Two-phase reset: hard reset (tracked files) + clean (untracked files)
export async function rollbackGitWorkspace(
sourceDir: string,
reason: string = 'clean start'
): Promise<GitOperationResult> => {
console.log(chalk.blue(` 🧹 Cleaning workspace for ${reason}`));
reason: string = 'retry preparation'
): Promise<GitOperationResult> {
console.log(chalk.yellow(` 🔄 Rolling back workspace for ${reason}`));
try {
// Check for uncommitted changes
const status = await $`cd ${sourceDir} && git status --porcelain`;
const hasChanges = status.stdout.trim().length > 0;
const changes = await getChangedFiles(sourceDir, 'status check for rollback');
if (hasChanges) {
// Show what we're about to remove
const changes = status.stdout
.trim()
.split('\n')
.filter((line) => line.length > 0);
console.log(chalk.yellow(` 🔄 Rolling back workspace for ${reason}`));
await executeGitCommandWithRetry(
['git', 'reset', '--hard', 'HEAD'],
sourceDir,
'hard reset for rollback'
);
await executeGitCommandWithRetry(
['git', 'clean', '-fd'],
sourceDir,
'cleaning untracked files for rollback'
);
await $`cd ${sourceDir} && git reset --hard HEAD`;
await $`cd ${sourceDir} && git clean -fd`;
console.log(
chalk.yellow(` ✅ Rollback completed - removed ${changes.length} contaminated changes:`)
);
changes.slice(0, 3).forEach((change) => console.log(chalk.gray(` ${change}`)));
if (changes.length > 3) {
console.log(chalk.gray(` ... and ${changes.length - 3} more files`));
}
} else {
console.log(chalk.blue(` ✅ Workspace already clean (no changes to remove)`));
}
return { success: true, hadChanges: hasChanges };
logChangeSummary(
changes,
' ✅ Rollback completed - removed {count} contaminated changes:',
' ✅ Rollback completed - no changes to remove',
chalk.yellow,
3
);
return { success: true };
} catch (error) {
const errMsg = error instanceof Error ? error.message : String(error);
console.log(chalk.yellow(` ⚠️ Workspace cleanup failed: ${errMsg}`));
return { success: false, error: error instanceof Error ? error : new Error(errMsg) };
const result = toErrorResult(error);
console.log(chalk.red(` ❌ Rollback failed after retries: ${result.error?.message}`));
return result;
}
};
}
export const createGitCheckpoint = async (
// Creates checkpoint before each attempt. First attempt preserves workspace; retries clean it.
export async function createGitCheckpoint(
sourceDir: string,
description: string,
attempt: number
): Promise<GitOperationResult> => {
): Promise<GitOperationResult> {
console.log(chalk.blue(` 📍 Creating checkpoint for ${description} (attempt ${attempt})`));
try {
// Only clean workspace on retry attempts (attempt > 1), not on first attempts
// This preserves deliverables between agents while still cleaning on actual retries
// First attempt: preserve existing deliverables. Retries: clean workspace to prevent pollution
if (attempt > 1) {
const cleanResult = await cleanWorkspace(sourceDir, `${description} (retry cleanup)`);
const cleanResult = await rollbackGitWorkspace(sourceDir, `${description} (retry cleanup)`);
if (!cleanResult.success) {
const errMsg = cleanResult.error?.message || 'Unknown error';
console.log(
chalk.yellow(` ⚠️ Workspace cleanup failed, continuing anyway: ${errMsg}`)
chalk.yellow(` ⚠️ Workspace cleanup failed, continuing anyway: ${cleanResult.error?.message}`)
);
}
}
// Check for uncommitted changes with retry logic
const status = await executeGitCommandWithRetry(
['git', 'status', '--porcelain'],
sourceDir,
'status check'
);
const hasChanges = status.stdout.trim().length > 0;
const changes = await getChangedFiles(sourceDir, 'status check');
const hasChanges = changes.length > 0;
// Stage changes with retry logic
await executeGitCommandWithRetry(['git', 'add', '-A'], sourceDir, 'staging changes');
// Create commit with retry logic
await executeGitCommandWithRetry(
['git', 'commit', '-m', `📍 Checkpoint: ${description} (attempt ${attempt})`, '--allow-empty'],
sourceDir,
@@ -171,106 +211,54 @@ export const createGitCheckpoint = async (
}
return { success: true };
} catch (error) {
const errMsg = error instanceof Error ? error.message : String(error);
console.log(chalk.yellow(` ⚠️ Checkpoint creation failed after retries: ${errMsg}`));
return { success: false, error: error instanceof Error ? error : new Error(errMsg) };
const result = toErrorResult(error);
console.log(chalk.yellow(` ⚠️ Checkpoint creation failed after retries: ${result.error?.message}`));
return result;
}
};
}
export const commitGitSuccess = async (
export async function commitGitSuccess(
sourceDir: string,
description: string
): Promise<GitOperationResult> => {
): Promise<GitOperationResult> {
console.log(chalk.green(` 💾 Committing successful results for ${description}`));
try {
// Check what we're about to commit with retry logic
const status = await executeGitCommandWithRetry(
['git', 'status', '--porcelain'],
sourceDir,
'status check for success commit'
);
const changes = status.stdout
.trim()
.split('\n')
.filter((line) => line.length > 0);
const changes = await getChangedFiles(sourceDir, 'status check for success commit');
// Stage changes with retry logic
await executeGitCommandWithRetry(
['git', 'add', '-A'],
sourceDir,
'staging changes for success commit'
);
// Create success commit with retry logic
await executeGitCommandWithRetry(
['git', 'commit', '-m', `${description}: completed successfully`, '--allow-empty'],
sourceDir,
'creating success commit'
);
if (changes.length > 0) {
console.log(chalk.green(` ✅ Success commit created with ${changes.length} file changes:`));
changes.slice(0, 5).forEach((change) => console.log(chalk.gray(` ${change}`)));
if (changes.length > 5) {
console.log(chalk.gray(` ... and ${changes.length - 5} more files`));
}
} else {
console.log(chalk.green(` ✅ Empty success commit created (agent made no file changes)`));
}
logChangeSummary(
changes,
' ✅ Success commit created with {count} file changes:',
' ✅ Empty success commit created (agent made no file changes)',
chalk.green,
5
);
return { success: true };
} catch (error) {
const errMsg = error instanceof Error ? error.message : String(error);
console.log(chalk.yellow(` ⚠️ Success commit failed after retries: ${errMsg}`));
return { success: false, error: error instanceof Error ? error : new Error(errMsg) };
const result = toErrorResult(error);
console.log(chalk.yellow(` ⚠️ Success commit failed after retries: ${result.error?.message}`));
return result;
}
};
}
export const rollbackGitWorkspace = async (
sourceDir: string,
reason: string = 'retry preparation'
): Promise<GitOperationResult> => {
console.log(chalk.yellow(` 🔄 Rolling back workspace for ${reason}`));
/**
* Get current git commit hash
*/
export async function getGitCommitHash(sourceDir: string): Promise<string | null> {
try {
// Show what we're about to remove with retry logic
const status = await executeGitCommandWithRetry(
['git', 'status', '--porcelain'],
sourceDir,
'status check for rollback'
);
const changes = status.stdout
.trim()
.split('\n')
.filter((line) => line.length > 0);
// Reset to HEAD with retry logic
await executeGitCommandWithRetry(
['git', 'reset', '--hard', 'HEAD'],
sourceDir,
'hard reset for rollback'
);
// Clean untracked files with retry logic
await executeGitCommandWithRetry(
['git', 'clean', '-fd'],
sourceDir,
'cleaning untracked files for rollback'
);
if (changes.length > 0) {
console.log(
chalk.yellow(` ✅ Rollback completed - removed ${changes.length} contaminated changes:`)
);
changes.slice(0, 3).forEach((change) => console.log(chalk.gray(` ${change}`)));
if (changes.length > 3) {
console.log(chalk.gray(` ... and ${changes.length - 3} more files`));
}
} else {
console.log(chalk.yellow(` ✅ Rollback completed - no changes to remove`));
}
return { success: true };
} catch (error) {
const errMsg = error instanceof Error ? error.message : String(error);
console.log(chalk.red(` ❌ Rollback failed after retries: ${errMsg}`));
return { success: false, error: error instanceof Error ? error : new Error(errMsg) };
const result = await $`cd ${sourceDir} && git rev-parse HEAD`;
return result.stdout.trim();
} catch {
return null;
}
};
}

View File

@@ -5,7 +5,7 @@
// as published by the Free Software Foundation.
import chalk from 'chalk';
import { formatDuration } from '../audit/utils.js';
import { formatDuration } from './formatting.js';
// Timing utilities