From c12eca046cc0b113037dae0b475233dce11015c5 Mon Sep 17 00:00:00 2001 From: ajmallesh Date: Tue, 13 Jan 2026 10:52:26 -0800 Subject: [PATCH] fix: resolve parallel workflow race conditions and retry logic bugs - Fix save_deliverable race condition using closure pattern instead of global variable - Fix error classification order so OutputValidationError matches before generic validation - Fix ApplicationFailure re-classification bug by checking instanceof before re-throwing - Add per-error-type retry limits (3 for output validation, 50 for billing) - Add fast retry intervals for pipeline testing mode (10s vs 5min) - Increase worker concurrent activities to 25 for parallel workflows --- CLAUDE.md | 7 +- docker-compose.yml | 3 +- mcp-server/src/index.ts | 22 +++-- mcp-server/src/tools/save-deliverable.ts | 109 +++++++++++++---------- mcp-server/src/utils/file-operations.ts | 12 ++- shannon | 72 +++++++++++---- src/error-handling.ts | 12 ++- src/temporal/activities.ts | 20 +++++ src/temporal/worker.ts | 4 +- src/temporal/workflows.ts | 83 ++++++++++------- 10 files changed, 226 insertions(+), 118 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index c113beb..04bce8c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -46,9 +46,10 @@ Examples: ### Options ```bash -CONFIG= YAML configuration file for authentication and testing parameters -OUTPUT= Custom output directory for session folder (default: ./audit-logs/) -PIPELINE_TESTING=true Use minimal prompts for fast pipeline testing +CONFIG= YAML configuration file for authentication and testing parameters +OUTPUT= Custom output directory for session folder (default: ./audit-logs/) +PIPELINE_TESTING=true Use minimal prompts and fast retry intervals (10s instead of 5min) +REBUILD=true Force Docker rebuild with --no-cache (use when code changes aren't picked up) ``` ### Generate TOTP for Authentication diff --git a/docker-compose.yml b/docker-compose.yml index 8558219..7d509e2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,7 +27,8 @@ services: condition: service_healthy volumes: - ./prompts:/app/prompts - - ${TARGET_REPO:-/tmp/target-repo}:/target-repo + - ${TARGET_REPO:-.}:/target-repo + - ${BENCHMARKS_BASE:-.}:/benchmarks shm_size: 2gb ipc: host security_opt: diff --git a/mcp-server/src/index.ts b/mcp-server/src/index.ts index 934f61b..0844e96 100644 --- a/mcp-server/src/index.ts +++ b/mcp-server/src/index.ts @@ -11,22 +11,25 @@ * for Shannon penetration testing agents. * * Replaces bash script invocations with native tool access. + * + * Uses factory pattern to create tools with targetDir captured in closure, + * ensuring thread-safety when multiple workflows run in parallel. */ import { createSdkMcpServer } from '@anthropic-ai/claude-agent-sdk'; -import { saveDeliverableTool } from './tools/save-deliverable.js'; +import { createSaveDeliverableTool } from './tools/save-deliverable.js'; import { generateTotpTool } from './tools/generate-totp.js'; -declare global { - var __SHANNON_TARGET_DIR: string | undefined; -} - /** * Create Shannon Helper MCP Server with target directory context + * + * Each workflow should create its own MCP server instance with its targetDir. + * The save_deliverable tool captures targetDir in a closure, preventing race + * conditions when multiple workflows run in parallel. */ export function createShannonHelperServer(targetDir: string): ReturnType { - // Store target directory for tool access - global.__SHANNON_TARGET_DIR = targetDir; + // Create save_deliverable tool with targetDir in closure (no global variable) + const saveDeliverableTool = createSaveDeliverableTool(targetDir); return createSdkMcpServer({ name: 'shannon-helper', @@ -35,8 +38,9 @@ export function createShannonHelperServer(targetDir: string): ReturnType; /** - * save_deliverable tool implementation + * Create save_deliverable handler with targetDir captured in closure + * + * This factory pattern ensures each MCP server instance has its own targetDir, + * preventing race conditions when multiple workflows run in parallel. */ -export async function saveDeliverable(args: SaveDeliverableInput): Promise { - try { - const { deliverable_type, content } = args; +function createSaveDeliverableHandler(targetDir: string) { + return async function saveDeliverable(args: SaveDeliverableInput): Promise { + try { + const { deliverable_type, content } = args; - // Validate queue JSON if applicable - if (isQueueType(deliverable_type)) { - const queueValidation = validateQueueJson(content); - if (!queueValidation.valid) { - const errorResponse = createValidationError( - queueValidation.message ?? 'Invalid queue JSON', - true, - { - deliverableType: deliverable_type, - expectedFormat: '{"vulnerabilities": [...]}', - } - ); - return createToolResult(errorResponse); + // Validate queue JSON if applicable + if (isQueueType(deliverable_type)) { + const queueValidation = validateQueueJson(content); + if (!queueValidation.valid) { + const errorResponse = createValidationError( + queueValidation.message ?? 'Invalid queue JSON', + true, + { + deliverableType: deliverable_type, + expectedFormat: '{"vulnerabilities": [...]}', + } + ); + return createToolResult(errorResponse); + } } + + // Get filename and save file (targetDir captured from closure) + const filename = DELIVERABLE_FILENAMES[deliverable_type]; + const filepath = saveDeliverableFile(targetDir, filename, content); + + // Success response + const successResponse: SaveDeliverableResponse = { + status: 'success', + message: `Deliverable saved successfully: ${filename}`, + filepath, + deliverableType: deliverable_type, + validated: isQueueType(deliverable_type), + }; + + return createToolResult(successResponse); + } catch (error) { + const errorResponse = createGenericError( + error, + false, + { deliverableType: args.deliverable_type } + ); + + return createToolResult(errorResponse); } - - // Get filename and save file - const filename = DELIVERABLE_FILENAMES[deliverable_type]; - const filepath = saveDeliverableFile(filename, content); - - // Success response - const successResponse: SaveDeliverableResponse = { - status: 'success', - message: `Deliverable saved successfully: ${filename}`, - filepath, - deliverableType: deliverable_type, - validated: isQueueType(deliverable_type), - }; - - return createToolResult(successResponse); - } catch (error) { - const errorResponse = createGenericError( - error, - false, - { deliverableType: args.deliverable_type } - ); - - return createToolResult(errorResponse); - } + }; } /** - * Tool definition for MCP server - created using SDK's tool() function + * Factory function to create save_deliverable tool with targetDir in closure + * + * Each MCP server instance should call this with its own targetDir to ensure + * deliverables are saved to the correct workflow's directory. */ -export const saveDeliverableTool = tool( - 'save_deliverable', - 'Saves deliverable files with automatic validation. Queue files must have {"vulnerabilities": [...]} structure.', - SaveDeliverableInputSchema.shape, - saveDeliverable -); +export function createSaveDeliverableTool(targetDir: string) { + return tool( + 'save_deliverable', + 'Saves deliverable files with automatic validation. Queue files must have {"vulnerabilities": [...]} structure.', + SaveDeliverableInputSchema.shape, + createSaveDeliverableHandler(targetDir) + ); +} diff --git a/mcp-server/src/utils/file-operations.ts b/mcp-server/src/utils/file-operations.ts index a10e438..8f718b1 100644 --- a/mcp-server/src/utils/file-operations.ts +++ b/mcp-server/src/utils/file-operations.ts @@ -14,16 +14,14 @@ import { writeFileSync, mkdirSync } from 'fs'; import { join } from 'path'; -declare global { - var __SHANNON_TARGET_DIR: string | undefined; -} - /** * Save deliverable file to deliverables/ directory + * + * @param targetDir - Target directory for deliverables (passed explicitly to avoid race conditions) + * @param filename - Name of the deliverable file + * @param content - File content to save */ -export function saveDeliverableFile(filename: string, content: string): string { - // Use target directory from global context (set by createShannonHelperServer) - const targetDir = global.__SHANNON_TARGET_DIR || process.cwd(); +export function saveDeliverableFile(targetDir: string, filename: string, content: string): string { const deliverablesDir = join(targetDir, 'deliverables'); const filepath = join(deliverablesDir, filename); diff --git a/shannon b/shannon index 651d94e..920ee09 100755 --- a/shannon +++ b/shannon @@ -52,10 +52,48 @@ parse_args() { ID=*) ID="${arg#ID=}" ;; CLEAN=*) CLEAN="${arg#CLEAN=}" ;; PIPELINE_TESTING=*) PIPELINE_TESTING="${arg#PIPELINE_TESTING=}" ;; + REBUILD=*) REBUILD="${arg#REBUILD=}" ;; esac done } +# Check if Temporal is running and healthy +is_temporal_ready() { + docker compose -f "$COMPOSE_FILE" exec -T temporal \ + temporal operator cluster health --address localhost:7233 2>/dev/null | grep -q "SERVING" +} + +# Ensure containers are running +ensure_containers() { + # Quick check: if Temporal is already healthy, we're good + if is_temporal_ready; then + return 0 + fi + + # Need to start containers + echo "Starting Shannon containers..." + if [ "$REBUILD" = "true" ]; then + # Force rebuild without cache (use when code changes aren't being picked up) + echo "Rebuilding with --no-cache..." + docker compose -f "$COMPOSE_FILE" build --no-cache worker + fi + docker compose -f "$COMPOSE_FILE" up -d --build + + # Wait for Temporal to be ready + echo "Waiting for Temporal to be ready..." + for i in $(seq 1 30); do + if is_temporal_ready; then + echo "Temporal is ready!" + return 0 + fi + if [ "$i" -eq 30 ]; then + echo "Timeout waiting for Temporal" + exit 1 + fi + sleep 2 + done +} + cmd_start() { parse_args "$@" @@ -72,22 +110,22 @@ cmd_start() { exit 1 fi - # Start containers - TARGET_REPO="$REPO" docker compose -f "$COMPOSE_FILE" up -d --build + # Determine container path for REPO + # - If REPO is already a container path (/benchmarks/*, /target-repo), use as-is + # - Otherwise, it's a host path - mount to /target-repo and use that + case "$REPO" in + /benchmarks/*|/target-repo|/target-repo/*) + CONTAINER_REPO="$REPO" + ;; + *) + # Host path - export for docker-compose mount + export TARGET_REPO="$REPO" + CONTAINER_REPO="/target-repo" + ;; + esac - # Wait for Temporal to be ready - echo "Waiting for Temporal to be ready..." - for i in $(seq 1 30); do - if docker compose -f "$COMPOSE_FILE" exec -T temporal \ - temporal operator cluster health --address localhost:7233 2>/dev/null | grep -q "SERVING"; then - break - fi - if [ "$i" -eq 30 ]; then - echo "Timeout waiting for Temporal" - exit 1 - fi - sleep 2 - done + # Ensure containers are running (starts them if needed) + ensure_containers # Build optional args ARGS="" @@ -95,9 +133,9 @@ cmd_start() { [ -n "$OUTPUT" ] && ARGS="$ARGS --output $OUTPUT" [ "$PIPELINE_TESTING" = "true" ] && ARGS="$ARGS --pipeline-testing" - # Run the client + # Run the client to submit workflow docker compose -f "$COMPOSE_FILE" exec -T worker \ - node dist/temporal/client.js "$URL" "/target-repo" $ARGS + node dist/temporal/client.js "$URL" "$CONTAINER_REPO" $ARGS } cmd_logs() { diff --git a/src/error-handling.ts b/src/error-handling.ts index 2b837ba..f43d482 100644 --- a/src/error-handling.ts +++ b/src/error-handling.ts @@ -247,8 +247,18 @@ export function classifyErrorForTemporal(error: unknown): TemporalErrorClassific return { type: 'PermissionError', retryable: false }; } + // === OUTPUT VALIDATION ERRORS (Retryable) === + // Agent didn't produce expected deliverables - retry may succeed + // IMPORTANT: Must come BEFORE generic 'validation' check below + if ( + message.includes('failed output validation') || + message.includes('output validation failed') + ) { + return { type: 'OutputValidationError', retryable: true }; + } + // Invalid Request (400) - malformed request is permanent - // Note: Checked AFTER billing since Anthropic billing is 400 + // Note: Checked AFTER billing and AFTER output validation if ( message.includes('invalid_request_error') || message.includes('malformed') || diff --git a/src/temporal/activities.ts b/src/temporal/activities.ts index c2b2dc7..40f9d3b 100644 --- a/src/temporal/activities.ts +++ b/src/temporal/activities.ts @@ -25,6 +25,10 @@ import chalk from 'chalk'; const MAX_ERROR_MESSAGE_LENGTH = 2000; const MAX_STACK_TRACE_LENGTH = 1000; +// Max retries for output validation errors (agent didn't save deliverables) +// Lower than default 50 since this is unlikely to self-heal +const MAX_OUTPUT_VALIDATION_RETRIES = 3; + /** * Truncate error message to prevent buffer overflow in Temporal serialization. */ @@ -193,6 +197,16 @@ async function runAgentActivity( success: false, error: 'Output validation failed', }); + + // Limit output validation retries (unlikely to self-heal) + if (attemptNumber >= MAX_OUTPUT_VALIDATION_RETRIES) { + throw ApplicationFailure.nonRetryable( + `Agent ${agentName} failed output validation after ${attemptNumber} attempts`, + 'OutputValidationError', + [{ agentName, attemptNumber, elapsed: Date.now() - startTime }] + ); + } + // Let Temporal retry (will be classified as OutputValidationError) throw new Error(`Agent ${agentName} failed output validation`); } @@ -224,6 +238,12 @@ async function runAgentActivity( console.error(`Failed to rollback git workspace for ${agentName}:`, rollbackErr); } + // If error is already an ApplicationFailure (e.g., from our retry limit logic), + // re-throw it directly without re-classifying + if (error instanceof ApplicationFailure) { + throw error; + } + // Classify error for Temporal retry behavior const classified = classifyErrorForTemporal(error); // Truncate message to prevent protobuf buffer overflow diff --git a/src/temporal/worker.ts b/src/temporal/worker.ts index 7346257..81c7f7e 100644 --- a/src/temporal/worker.ts +++ b/src/temporal/worker.ts @@ -9,7 +9,7 @@ * Temporal worker for Shannon pentest pipeline. * * Polls the 'shannon-pipeline' task queue and executes activities. - * Handles up to 5 concurrent activities to support parallel agent execution. + * Handles up to 25 concurrent activities to support multiple parallel workflows. * * Usage: * npm run temporal:worker @@ -49,7 +49,7 @@ async function runWorker(): Promise { workflowBundle, activities, taskQueue: 'shannon-pipeline', - maxConcurrentActivityTaskExecutions: 5, // Match parallel agent count + maxConcurrentActivityTaskExecutions: 25, // Support multiple parallel workflows (5 agents × ~5 workflows) }); // Graceful shutdown handling diff --git a/src/temporal/workflows.ts b/src/temporal/workflows.ts index 700dc57..078b5d1 100644 --- a/src/temporal/workflows.ts +++ b/src/temporal/workflows.ts @@ -35,25 +35,44 @@ import { type PipelineProgress, } from './shared.js'; -// Activity proxy with retry configuration +// Retry configuration for production (long intervals for billing recovery) +const PRODUCTION_RETRY = { + initialInterval: '5 minutes', + maximumInterval: '30 minutes', + backoffCoefficient: 2, + maximumAttempts: 50, + nonRetryableErrorTypes: [ + 'AuthenticationError', + 'PermissionError', + 'InvalidRequestError', + 'RequestTooLargeError', + 'ConfigurationError', + 'InvalidTargetError', + 'ExecutionLimitError', + ], +}; + +// Retry configuration for pipeline testing (fast iteration) +const TESTING_RETRY = { + initialInterval: '10 seconds', + maximumInterval: '30 seconds', + backoffCoefficient: 2, + maximumAttempts: 5, + nonRetryableErrorTypes: PRODUCTION_RETRY.nonRetryableErrorTypes, +}; + +// Activity proxy with production retry configuration (default) const acts = proxyActivities({ startToCloseTimeout: '2 hours', heartbeatTimeout: '30 seconds', - retry: { - initialInterval: '5 minutes', - maximumInterval: '30 minutes', - backoffCoefficient: 2, - maximumAttempts: 50, - nonRetryableErrorTypes: [ - 'AuthenticationError', - 'PermissionError', - 'InvalidRequestError', - 'RequestTooLargeError', - 'ConfigurationError', - 'InvalidTargetError', - 'ExecutionLimitError', - ], - }, + retry: PRODUCTION_RETRY, +}); + +// Activity proxy with testing retry configuration (fast) +const testActs = proxyActivities({ + startToCloseTimeout: '10 minutes', + heartbeatTimeout: '30 seconds', + retry: TESTING_RETRY, }); export async function pentestPipelineWorkflow( @@ -61,6 +80,10 @@ export async function pentestPipelineWorkflow( ): Promise { const { workflowId } = workflowInfo(); + // Select activity proxy based on testing mode + // Pipeline testing uses fast retry intervals (10s) for quick iteration + const a = input.pipelineTestingMode ? testActs : acts; + // Workflow state (queryable) const state: PipelineState = { status: 'running', @@ -99,13 +122,13 @@ export async function pentestPipelineWorkflow( state.currentPhase = 'pre-recon'; state.currentAgent = 'pre-recon'; state.agentMetrics['pre-recon'] = - await acts.runPreReconAgent(activityInput); + await a.runPreReconAgent(activityInput); state.completedAgents.push('pre-recon'); // === Phase 2: Reconnaissance === state.currentPhase = 'recon'; state.currentAgent = 'recon'; - state.agentMetrics['recon'] = await acts.runReconAgent(activityInput); + state.agentMetrics['recon'] = await a.runReconAgent(activityInput); state.completedAgents.push('recon'); // === Phase 3: Vulnerability Analysis (Parallel) === @@ -113,11 +136,11 @@ export async function pentestPipelineWorkflow( state.currentAgent = 'vuln-agents'; const vulnResults = await Promise.all([ - acts.runInjectionVulnAgent(activityInput), - acts.runXssVulnAgent(activityInput), - acts.runAuthVulnAgent(activityInput), - acts.runSsrfVulnAgent(activityInput), - acts.runAuthzVulnAgent(activityInput), + a.runInjectionVulnAgent(activityInput), + a.runXssVulnAgent(activityInput), + a.runAuthVulnAgent(activityInput), + a.runSsrfVulnAgent(activityInput), + a.runAuthzVulnAgent(activityInput), ]); const vulnAgents = [ @@ -141,11 +164,11 @@ export async function pentestPipelineWorkflow( state.currentAgent = 'exploit-agents'; const exploitResults = await Promise.all([ - acts.runInjectionExploitAgent(activityInput), - acts.runXssExploitAgent(activityInput), - acts.runAuthExploitAgent(activityInput), - acts.runSsrfExploitAgent(activityInput), - acts.runAuthzExploitAgent(activityInput), + a.runInjectionExploitAgent(activityInput), + a.runXssExploitAgent(activityInput), + a.runAuthExploitAgent(activityInput), + a.runSsrfExploitAgent(activityInput), + a.runAuthzExploitAgent(activityInput), ]); const exploitAgents = [ @@ -169,10 +192,10 @@ export async function pentestPipelineWorkflow( state.currentAgent = 'report'; // First, assemble the concatenated report from exploitation evidence files - await acts.assembleReportActivity(activityInput); + await a.assembleReportActivity(activityInput); // Then run the report agent to add executive summary and clean up - state.agentMetrics['report'] = await acts.runReportAgent(activityInput); + state.agentMetrics['report'] = await a.runReportAgent(activityInput); state.completedAgents.push('report'); // === Complete ===