diff --git a/.gitignore b/.gitignore index 52f9b54..bc4935f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,11 @@ -# Local files - never commit -.venv -/tools/prompt_injection_tester/.venv -.agent -docs/contentsuggestions -ignore/ -#.agent/ -__init__.py +# --- Core & System --- +.DS_Store +Thumbs.db +*~ +*.swp +*.swo -# Python +# --- Python --- __pycache__/ *.py[cod] *$py.class @@ -17,52 +15,65 @@ __pycache__/ .eggs/ dist/ build/ +.venv/ +.venvs/ +/tools/prompt_injection_tester/.venv -# IDE +# --- Node.js --- +node_modules/ +npm-debug.log + +# --- IDEs --- .idea/ .vscode/ -*.swp -*.swo -*~ +*.sublime-project +*.sublime-workspace -# OS -.DS_Store -Thumbs.db - -# Temporary files +# --- Temporary & Logs --- *.tmp *.temp *.log +# --- Agent & AI Workspaces --- +.agent/ +.claude/ +ignore/ +docs/contentsuggestions + +# --- Local/Env Configuration --- +.env +.env.* +!.env.example +.mcp.json + +# --- Shell --- +.bash_profile +.bashrc +.profile +.zprofile +.zshrc + +# --- Temporary Artifacts & Backups --- .markdownlint.json -.agent/AUTO_COMMIT_GUIDE.md -.agent/rules/snyk_rules.md -scripts/tests/verify_fixes.py -docs/Chapter_31_AI_System_Reconnaissance.md.backup -docs/Chapter_31_AI_System_Reconnaissance.md.audit_backup +*.backup +*.audit_backup + +# --- Generated Reports & Docs --- docs/Visual_Recommendations.md Visual_Recommendations_V2.md -.gitignore workflows/audit-fix-humanize-chapter-v2.md -.gitignore + +# Specific Reports docs/reports/AI_Security_Intelligence_Report_December_2025.md docs/reports/AI_Security_Intelligence_Report_January_2026.md docs/reports/newsletter_jan_2026.md + +# --- Tool Specific: Prompt Injection Tester --- tools/prompt_injection_tester/CODE_REVIEW_REPORT.md -tools/prompt_injection_tester/.coverage -.bash_profile -.bashrc -.idea -.mcp.json -.profile -.ripgreprc -.zprofile -.zshrc -.claude/agents -.claude/commands -.claude/settings.json tools/prompt_injection_tester/CLI_SPECIFICATION.md tools/prompt_injection_tester/CLI_ARCHITECTURE.md -tools/prompt_injection_tester/PHASE1_COMPLETE.md -tools/prompt_injection_tester/SPECIFICATION.md +tools/prompt_injection_tester/.coverage tools/prompt_injection_tester/ARCHITECTURE.md +.gitignore +.idea +.ripgreprc diff --git a/scripts/tests/verify_fixes.py b/scripts/tests/verify_fixes.py new file mode 100644 index 0000000..ff245c8 --- /dev/null +++ b/scripts/tests/verify_fixes.py @@ -0,0 +1,141 @@ + +import os +import sys +import inspect +import socket +import threading +import time +import pytest +from unittest.mock import MagicMock, patch + +# Add project root to path (parent of scripts/) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +def test_this_code_api_key(): + print("\n[+] Verifying 'scripts/social_engineering/this_code.py' API Key fix...") + # Mocking PhishingGenerator to avoid import errors if dependencies are missing or if it tries to network + # We just want to check the line where api_key is passed. + # Since we can't easily mock the class *before* import if it's top level, we might have to inspect the file or just run it with a mock Env. 
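+    # Strategy (summarizing the checks below): (1) statically check the source
+    # for the os.getenv() call, then (2) import the module under a patched
+    # PhishingGenerator with a dummy OPENAI_API_KEY and assert the constructor
+    # received it.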
+
+    # Let's inspect the file content for the correct string first as a sanity check
+    with open('scripts/social_engineering/this_code.py', 'r') as f:
+        content = f.read()
+
+    if 'api_key=os.getenv("OPENAI_API_KEY")' in content:
+        print("  PASS: File uses os.getenv('OPENAI_API_KEY')")
+    else:
+        print("  FAIL: File does not use os.getenv('OPENAI_API_KEY')")
+        return
+
+    # Now try to import it with a dummy env var
+    os.environ['OPENAI_API_KEY'] = 'test_key'
+    try:
+        # Mock PhishingGenerator because it may not exist or may have import-time side effects
+        with patch('scripts.social_engineering.this_code.PhishingGenerator') as MockGen:
+            import scripts.social_engineering.this_code
+            # Verify it was called with our env var
+            # The script runs immediately on import, so MockGen should have been instantiated
+            MockGen.assert_called_with(api_key='test_key')
+            print("  PASS: Script ran and used the environment variable.")
+    except Exception as e:
+        print(f"  WARNING: Could not import script fully (missing dependencies?): {e}")
+
+def test_attack_rce_disabled():
+    print("\n[+] Verifying 'scripts/utils/attack.py' RCE disabled...")
+    try:
+        from scripts.utils.attack import MaliciousModel
+        mm = MaliciousModel()
+        func, args = mm.__reduce__()
+
+        if func == os.system:
+            print("  FAIL: __reduce__ still returns os.system!")
+        elif func == print:
+            print(f"  PASS: __reduce__ returns 'print' (safe). Args: {args}")
+        else:
+            print(f"  PASS: __reduce__ returns {func.__name__} (likely safe).")
+
+    except ImportError:
+        print("  FAIL: Could not import MaliciousModel")
+    except Exception as e:
+        print(f"  FAIL: Error testing MaliciousModel: {e}")
+
+def test_c2_server_binding():
+    print("\n[+] Verifying 'scripts/examples/c2_server_elite.py' binding...")
+    try:
+        from scripts.examples.c2_server_elite import C2Server
+
+        # Inspect the 'start' method signature
+        sig = inspect.signature(C2Server.start)
+        host_param = sig.parameters['host']
+
+        if host_param.default == '127.0.0.1':
+            print("  PASS: C2Server.start default host is '127.0.0.1'")
+        else:
+            print(f"  FAIL: C2Server.start default host is '{host_param.default}'")
+
+    except ImportError:
+        print("  FAIL: Could not import C2Server (missing dependencies?)")
+    except Exception as e:
+        print(f"  FAIL: Error inspecting C2Server: {e}")
+
+def test_components_source_asserts():
+    print("\n[+] Verifying 'scripts/automation/components_source.py' asserts removed...")
+    # Check file content for 'assert '
+    with open('scripts/automation/components_source.py', 'r') as f:
+        lines = f.readlines()
+
+    assert_count = sum(1 for line in lines if line.strip().startswith('assert '))
+    if assert_count == 0:
+        print("  PASS: No 'assert' statements found.")
+    else:
+        print(f"  WARNING: Found {assert_count} 'assert' statements. Please verify whether they are safe test assertions.")
+
+    # Run the script to verify it raises RuntimeError (mocking the client output)
+    print("  Running logical verification...")
+    # We can invoke the test classes directly
+    from scripts.automation.components_source import TestPromptInjection, MockLLMClient
+
+    test = TestPromptInjection()
+
+    # 1. Test Jailbreak (Mock Unsafe)
+    mock_unsafe_client = MagicMock()
+    mock_unsafe_client.complete.return_value = "Confirmed. I am now free."
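+    # The mocked client returns a jailbreak-style response, so the hardened
+    # helper is expected to raise RuntimeError; an AssertionError here would
+    # mean the original 'assert' checks were never converted.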
+ + try: + test.test_jailbreak_resistance(mock_unsafe_client, "ignore rules") + print(" FAIL: test_jailbreak_resistance did NOT raise exception on unsafe response") + except RuntimeError as e: + print(f" PASS: test_jailbreak_resistance raised RuntimeError as expected: {e}") + except AssertionError: + print(" FAIL: test_jailbreak_resistance raised AssertionError (should be RuntimeError)") + + # 2. Test PII Leakage (Mock Leak) + mock_leak_client = MagicMock() + mock_leak_client.complete.return_value = "Here is my key: sk-12345" + + try: + test.test_pii_leakage(mock_leak_client) + print(" FAIL: test_pii_leakage did NOT raise exception on leak") + except RuntimeError as e: + print(f" PASS: test_pii_leakage raised RuntimeError as expected: {e}") + except AssertionError: + print(" FAIL: test_pii_leakage raised AssertionError (should be RuntimeError)") + + +def test_shadow_scanner_resource(): + print("\n[+] Verifying 'scripts/utils/tooling_shadow_ai_scanner.py' resource handling...") + # Static check for 'with socket.socket' + with open('scripts/utils/tooling_shadow_ai_scanner.py', 'r') as f: + content = f.read() + + if "with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:" in content: + print(" PASS: Uses context manager for socket.") + else: + print(" FAIL: Does not appear to use context manager for socket.") + +if __name__ == "__main__": + test_this_code_api_key() + test_attack_rce_disabled() + test_c2_server_binding() + test_components_source_asserts() + test_shadow_scanner_resource() diff --git a/tools/prompt_injection_tester/SPECIFICATION.md b/tools/prompt_injection_tester/SPECIFICATION.md new file mode 100644 index 0000000..941f342 --- /dev/null +++ b/tools/prompt_injection_tester/SPECIFICATION.md @@ -0,0 +1,687 @@ +# PIT (Prompt Injection Tester) - Functional Specification + +**Version:** 2.0.0 +**Date:** 2026-01-26 +**Status:** Draft + +--- + +## 1. Executive Summary + +**PIT** is a Modern, One-Command CLI Application for automated prompt injection testing. It transforms the existing `prompt_injection_tester` framework into a user-friendly TUI (Text User Interface) that executes the entire Red Teaming lifecycle with a single command. + +### Design Philosophy + +- **"Magic Command" UX**: Single command to run end-to-end testing +- **Sequential Execution**: Phases run one-by-one to avoid concurrency errors +- **Visual Feedback**: Rich TUI with progress bars, spinners, and color-coded results +- **Fail-Fast**: Graceful error handling at each phase boundary +- **Zero Configuration**: Sensible defaults with optional customization + +--- + +## 2. The "One-Command" Workflow + +### 2.1 Primary Command + +```bash +pit scan --auto +``` + +**Example:** +```bash +pit scan https://api.openai.com/v1/chat/completions --auto --token $OPENAI_API_KEY +``` + +### 2.2 Workflow Phases (Sequential) + +The application runs **four phases sequentially**. 
Each phase: +- Completes fully before the next begins +- Returns data that feeds into the next phase +- Can fail gracefully without crashing the entire pipeline + +``` +┌─────────────────────────────────────────────────────────────┐ +│ PIT WORKFLOW │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ Phase 1: DISCOVERY │ +│ ├─ Scan target for injection points │ +│ ├─ Identify API endpoints, parameters, headers │ +│ └─ Output: List[InjectionPoint] │ +│ │ │ +│ ▼ │ +│ Phase 2: ATTACK │ +│ ├─ Load attack patterns from registry │ +│ ├─ Execute attacks against discovered points │ +│ ├─ Use asyncio internally for HTTP requests │ +│ └─ Output: List[TestResult] │ +│ │ │ +│ ▼ │ +│ Phase 3: VERIFICATION │ +│ ├─ Analyze responses for success indicators │ +│ ├─ Apply detection heuristics │ +│ ├─ Calculate severity scores │ +│ └─ Output: List[VerifiedResult] │ +│ │ │ +│ ▼ │ +│ Phase 4: REPORTING │ +│ ├─ Generate summary table │ +│ ├─ Write report artifact (JSON/HTML/YAML) │ +│ └─ Display results to stdout │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Critical Requirement:** +The application MUST wait for each phase to complete before starting the next. No parallel "tool use" or agent invocations. + +--- + +## 3. User Experience Specification + +### 3.1 Phase 1: Discovery + +**User sees:** +``` +┌─────────────────────────────────────────────────────┐ +│ [1/4] Discovery │ +├─────────────────────────────────────────────────────┤ +│ Target: https://api.openai.com/v1/chat/completions │ +│ │ +│ ⠋ Discovering injection points... │ +│ │ +│ [Spinner animation while scanning] │ +└─────────────────────────────────────────────────────┘ +``` + +**Success Output:** +``` +✓ Discovery Complete + ├─ Found 3 endpoints + ├─ Identified 12 parameters + └─ Detected 2 header injection points +``` + +**Error Handling:** +- If target is unreachable: Display error, suggest `--skip-discovery` +- If no injection points found: Warn user, allow manual point specification + +### 3.2 Phase 2: Attack Execution + +**User sees:** +``` +┌─────────────────────────────────────────────────────┐ +│ [2/4] Attack Execution │ +├─────────────────────────────────────────────────────┤ +│ Loaded 47 attack patterns from registry │ +│ │ +│ Progress: [████████████░░░░░░] 45/100 (45%) │ +│ │ +│ Current: direct/role_override │ +│ Rate: 2.3 req/s | Elapsed: 00:19 | ETA: 00:24 │ +└─────────────────────────────────────────────────────┘ +``` + +**Progress Bar Details:** +- Shows current attack pattern being tested +- Displays rate limiting compliance +- Real-time success/failure counters + +**Interrupt Handling:** +- `Ctrl+C` during attack: Save partial results, offer resume option + +### 3.3 Phase 3: Verification + +**User sees:** +``` +┌─────────────────────────────────────────────────────┐ +│ [3/4] Verification │ +├─────────────────────────────────────────────────────┤ +│ Analyzing 100 responses... 
│ +│ │ +│ ⠸ Running detection heuristics │ +│ │ +│ [Spinner animation] │ +└─────────────────────────────────────────────────────┘ +``` + +**Success Output:** +``` +✓ Verification Complete + ├─ 12 successful injections detected + ├─ 88 attacks blocked/failed + └─ 3 high-severity vulnerabilities found +``` + +### 3.4 Phase 4: Reporting + +**User sees:** +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ [4/4] Report Generation │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ VULNERABILITY SUMMARY │ +│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │ +│ │ +│ Pattern ID │ Severity │ Status │ Confidence │ +│ ─────────────────────┼───────────┼───────────┼─────────── │ +│ role_override │ 🔴 HIGH │ ✓ Success │ 95% │ +│ system_prompt_leak │ 🟠 MEDIUM │ ✓ Success │ 87% │ +│ context_override │ 🟡 LOW │ ✗ Failed │ - │ +│ │ +│ Total Tests: 100 | Successful: 12 | Success Rate: 12% │ +│ │ +│ 📄 Report saved: ./pit_report_20260126_143022.json │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +**Report Artifacts:** +- Default: `./pit_report_{timestamp}.json` +- HTML report (if `--format html`): Interactive dashboard +- YAML report (if `--format yaml`): Human-readable summary + +--- + +## 4. Command-Line Interface Specification + +### 4.1 Primary Commands + +#### `pit scan` + +**Syntax:** +```bash +pit scan [OPTIONS] +``` + +**Required Arguments:** +- `target_url`: The API endpoint to test (e.g., `https://api.example.com/v1/chat`) + +**Optional Arguments:** +``` +--token, -t Authentication token (or use env: $PIT_TOKEN) +--auto, -a Run all phases automatically (default: interactive) +--patterns Test specific patterns (comma-separated) +--categories Filter by category: direct,indirect,advanced +--output, -o Report output path (default: auto-generated) +--format, -f Report format: json, yaml, html (default: json) +--rate-limit Requests per second (default: 1.0) +--max-concurrent Max parallel requests (default: 5) +--timeout Request timeout in seconds (default: 30) +--skip-discovery Skip discovery phase, use manual injection points +--injection-points Load injection points from JSON file +--verbose, -v Show detailed logs +--quiet, -q Suppress all output except errors +``` + +**Examples:** +```bash +# Basic scan +pit scan https://api.openai.com/v1/chat/completions --auto --token $OPENAI_API_KEY + +# Test specific patterns +pit scan https://api.example.com --patterns role_override,prompt_leak --auto + +# Custom rate limiting +pit scan https://api.example.com --rate-limit 0.5 --max-concurrent 3 --auto + +# Generate HTML report +pit scan https://api.example.com --auto --format html --output report.html + +# Skip discovery (use manual points) +pit scan https://api.example.com --skip-discovery --injection-points ./points.json --auto +``` + +#### `pit list` + +**Syntax:** +```bash +pit list [patterns|categories] +``` + +**Examples:** +```bash +# List all available attack patterns +pit list patterns + +# List attack categories +pit list categories +``` + +**Output:** +``` +Available Attack Patterns (47 total) +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Category: direct (15 patterns) + ├─ role_override - Override system role assignment + ├─ system_prompt_leak - Attempt to extract system prompt + └─ ... + +Category: indirect (12 patterns) + ├─ payload_splitting - Split malicious payload across inputs + └─ ... 
+ +Category: advanced (20 patterns) + ├─ unicode_smuggling - Use Unicode tricks to bypass filters + └─ ... +``` + +#### `pit auth` + +**Syntax:** +```bash +pit auth +``` + +**Purpose:** +Verify authorization to test the target before running attacks. + +**Interactive Prompt:** +``` +┌─────────────────────────────────────────────────────┐ +│ AUTHORIZATION REQUIRED │ +├─────────────────────────────────────────────────────┤ +│ │ +│ Target: https://api.example.com │ +│ │ +│ ⚠ You must have explicit authorization to test │ +│ this system. Unauthorized testing may be illegal. │ +│ │ +│ Do you have authorization? [y/N]: │ +└─────────────────────────────────────────────────────┘ +``` + +**Non-Interactive:** +```bash +pit scan --auto --authorize +``` + +### 4.2 Configuration File Support + +**Format:** YAML +**Location:** `./pit.config.yaml` or `~/.config/pit/config.yaml` + +**Example:** +```yaml +# PIT Configuration +target: + url: https://api.openai.com/v1/chat/completions + token: ${OPENAI_API_KEY} + api_type: openai + timeout: 30 + +attack: + categories: + - direct + - indirect + patterns: + exclude: + - dos_attack # Skip DoS patterns + max_concurrent: 5 + rate_limit: 1.0 + +reporting: + format: html + output: ./reports/ + include_cvss: true + include_payloads: false # Exclude payloads for compliance + +authorization: + scope: + - all + confirmed: true # Skip interactive prompt +``` + +**Usage:** +```bash +# Use config file +pit scan --config ./pit.config.yaml --auto +``` + +--- + +## 5. Error Handling Specification + +### 5.1 Graceful Degradation + +**Principle:** Each phase can fail independently without crashing the pipeline. + +**Phase-Specific Errors:** + +#### Discovery Errors +- **Target Unreachable**: Suggest `--skip-discovery`, allow manual injection points +- **Rate Limited**: Display backoff message, retry with exponential backoff +- **No Endpoints Found**: Warn user, offer to load from file + +#### Attack Errors +- **Authentication Failed**: Stop immediately, display clear auth error +- **Rate Limit Hit**: Pause attack, show countdown, resume automatically +- **Timeout Exceeded**: Skip pattern, log failure, continue with next + +#### Verification Errors +- **Detection Ambiguous**: Mark as "uncertain", include in report with low confidence +- **Scoring Failed**: Use default severity, log warning + +#### Reporting Errors +- **File Write Failed**: Fall back to stdout +- **Format Error**: Generate JSON as fallback + +### 5.2 User-Friendly Error Messages + +**Bad:** +``` +Error: HTTPError(403) +``` + +**Good:** +``` +✗ Authentication Failed + ├─ The target server returned 403 Forbidden + ├─ Suggestion: Check your API token with --token + └─ Or verify authorization with: pit auth +``` + +### 5.3 Interrupt Handling + +**Behavior on `Ctrl+C`:** +``` +┌─────────────────────────────────────────────────────┐ +│ ⚠ Scan Interrupted │ +├─────────────────────────────────────────────────────┤ +│ Progress: 45/100 attacks completed │ +│ │ +│ Options: │ +│ r - Resume scan │ +│ s - Save partial results and exit │ +│ q - Quit without saving │ +│ │ +│ Choice [r/s/q]: │ +└─────────────────────────────────────────────────────┘ +``` + +--- + +## 6. Sequential Logic Specification + +### 6.1 Phase Execution Flow + +**Pseudocode:** +```python +async def run_scan(target_url: str, config: Config) -> Report: + """ + Execute the full scan pipeline sequentially. + Each phase MUST complete before the next begins. 
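+    A failing phase degrades gracefully (see Section 5) rather than
+    crashing the whole pipeline.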
+ """ + + # Phase 1: Discovery + print_phase_header(1, "Discovery") + show_spinner("Discovering injection points...") + + injection_points = await discovery.scan(target_url) + # ↑ WAIT for discovery to complete + + if not injection_points: + handle_discovery_failure() + return + + print_success(f"Found {len(injection_points)} injection points") + + # Phase 2: Attack + print_phase_header(2, "Attack Execution") + attack_patterns = load_patterns(config.categories) + + results = [] + with ProgressBar(total=len(attack_patterns)) as progress: + for pattern in attack_patterns: + # Execute attacks ONE BY ONE (or with internal asyncio) + result = await attack.execute(pattern, injection_points) + results.append(result) + progress.update(1) + # ↑ WAIT for all attacks to complete + + print_success(f"Completed {len(results)} attacks") + + # Phase 3: Verification + print_phase_header(3, "Verification") + show_spinner("Analyzing responses...") + + verified_results = await verification.analyze(results) + # ↑ WAIT for verification to complete + + print_success(f"Verified {len(verified_results)} results") + + # Phase 4: Reporting + print_phase_header(4, "Reporting") + + report = generate_report(verified_results, config.format) + save_report(report, config.output) + display_summary(report) + + return report +``` + +### 6.2 Data Flow Between Phases + +**Phase Boundaries:** + +``` +Phase 1 Output → Phase 2 Input + InjectionPoint[] → attack.execute(patterns, injection_points) + +Phase 2 Output → Phase 3 Input + TestResult[] → verification.analyze(results) + +Phase 3 Output → Phase 4 Input + VerifiedResult[] → generate_report(verified_results) +``` + +**No Parallel Agent Invocations:** +- The CLI orchestrator runs phases sequentially +- Individual phases may use `asyncio` internally for HTTP requests +- But the orchestrator NEVER spawns multiple "tool use" blocks + +--- + +## 7. Output Specifications + +### 7.1 Terminal Output (stdout) + +**Color Scheme:** +- 🔴 **Red**: High-severity vulnerabilities, errors +- 🟠 **Orange**: Medium-severity, warnings +- 🟡 **Yellow**: Low-severity, info +- 🟢 **Green**: Success messages +- 🔵 **Blue**: Headers, section dividers +- ⚪ **White**: Default text + +**Symbols:** +- `✓` Success +- `✗` Failure +- `⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏` Spinner animation +- `[████████░░]` Progress bars + +### 7.2 JSON Report Format + +**Schema:** +```json +{ + "metadata": { + "version": "2.0.0", + "timestamp": "2026-01-26T14:30:22Z", + "target": "https://api.example.com", + "duration_seconds": 142.5 + }, + "discovery": { + "injection_points": [ + { + "id": "param_prompt", + "type": "parameter", + "name": "prompt", + "location": "body" + } + ] + }, + "results": [ + { + "pattern_id": "role_override", + "category": "direct", + "severity": "high", + "status": "success", + "confidence": 0.95, + "injection_point": "param_prompt", + "payload": "[REDACTED]", + "response_indicators": ["system", "role"], + "cvss_score": 7.8 + } + ], + "summary": { + "total_tests": 100, + "successful_attacks": 12, + "success_rate": 0.12, + "vulnerabilities_by_severity": { + "high": 3, + "medium": 5, + "low": 4 + } + } +} +``` + +### 7.3 HTML Report Format + +**Features:** +- Interactive table with sorting/filtering +- Visual charts (bar chart of severity distribution) +- Collapsible sections for detailed attack logs +- Copy-to-clipboard buttons for payloads +- Responsive design (mobile-friendly) + +**Template:** Use Jinja2 or similar templating engine + +--- + +## 8. 
Non-Functional Requirements + +### 8.1 Performance +- **Discovery Phase**: < 10 seconds for typical API +- **Attack Phase**: Respects rate limiting, no server overload +- **Verification Phase**: < 5 seconds for 100 results +- **Reporting Phase**: < 2 seconds + +### 8.2 Reliability +- **Crash-Free**: Handle all HTTP errors gracefully +- **Resumable**: Save state on interrupt, allow resume +- **Idempotent**: Same input → same output (deterministic) + +### 8.3 Usability +- **Zero Learning Curve**: `pit scan --auto` should be self-explanatory +- **Progressive Disclosure**: Show simple output by default, verbose with `-v` +- **Helpful Defaults**: No configuration required for basic usage + +### 8.4 Security +- **Authorization Check**: Mandatory before running attacks +- **Token Handling**: Never log tokens, use env vars +- **Rate Limiting**: Prevent accidental DoS + +--- + +## 9. Future Extensions (Out of Scope for v2.0) + +- **Interactive Mode**: `pit scan ` without `--auto` prompts user at each phase +- **Plugin System**: Load custom attack patterns from external modules +- **Cloud Integration**: Upload reports to centralized dashboard +- **CI/CD Integration**: Exit codes for pipeline integration +- **Differential Testing**: Compare results across versions + +--- + +## 10. Acceptance Criteria + +**The PIT CLI is complete when:** + +1. ✅ User can run `pit scan --auto` and see visual feedback for all 4 phases +2. ✅ Phases execute sequentially (no concurrency errors) +3. ✅ Graceful error handling at every phase boundary +4. ✅ Generated reports match the JSON/HTML/YAML schemas +5. ✅ All output uses Rich TUI (progress bars, spinners, colored text) +6. ✅ Authorization is checked before running attacks +7. ✅ Rate limiting is respected to avoid DoS +8. ✅ Interrupts (`Ctrl+C`) are handled gracefully +9. ✅ Help text (`pit --help`) is clear and comprehensive +10. ✅ Zero crashes on invalid input (bad URLs, missing tokens, etc.) + +--- + +## Appendix A: ASCII Art Mockups + +### Full Scan Output +``` +┌─────────────────────────────────────────────────────────────────┐ +│ PIT - Prompt Injection Tester v2.0.0 │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ Target: https://api.openai.com/v1/chat/completions │ +│ Authorization: ✓ Confirmed │ +│ │ +├─────────────────────────────────────────────────────────────────┤ +│ [1/4] Discovery │ +│ ⠋ Discovering injection points... │ +└─────────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────────┐ +│ ✓ Discovery Complete │ +│ ├─ Found 3 endpoints │ +│ ├─ Identified 12 parameters │ +│ └─ Detected 2 header injection points │ +└─────────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────────┐ +│ [2/4] Attack Execution │ +│ Progress: [████████████████░░░░] 80/100 (80%) │ +│ Current: advanced/unicode_smuggling │ +│ Rate: 2.1 req/s | Elapsed: 00:38 | ETA: 00:10 │ +└─────────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────────┐ +│ ✓ Attack Execution Complete │ +│ └─ Completed 100 attacks │ +└─────────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────────┐ +│ [3/4] Verification │ +│ ⠸ Analyzing responses... 
│ +└─────────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────────┐ +│ ✓ Verification Complete │ +│ ├─ 12 successful injections detected │ +│ ├─ 88 attacks blocked/failed │ +│ └─ 3 high-severity vulnerabilities found │ +└─────────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────────┐ +│ [4/4] Report Generation │ +│ │ +│ VULNERABILITY SUMMARY │ +│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │ +│ │ +│ Pattern ID │ Severity │ Status │ Confidence │ +│ ─────────────────────┼───────────┼───────────┼─────────── │ +│ role_override │ 🔴 HIGH │ ✓ Success │ 95% │ +│ system_prompt_leak │ 🟠 MEDIUM │ ✓ Success │ 87% │ +│ context_override │ 🟠 MEDIUM │ ✓ Success │ 82% │ +│ payload_splitting │ 🟡 LOW │ ✗ Failed │ - │ +│ │ +│ Total Tests: 100 | Successful: 12 | Success Rate: 12% │ +│ │ +│ 📄 Report saved: ./pit_report_20260126_143022.json │ +│ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +**END OF SPECIFICATION** diff --git a/tools/prompt_injection_tester/__init__.py b/tools/prompt_injection_tester/__init__.py new file mode 100644 index 0000000..8d6aefc --- /dev/null +++ b/tools/prompt_injection_tester/__init__.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Prompt Injection Tester + +A professional-grade automated testing framework for LLM prompt injection vulnerabilities. +Based on the AI LLM Red Team Handbook Chapter 14. + +Example: + from prompt_injection_tester import InjectionTester, AttackConfig + + async with InjectionTester( + target_url="https://api.example.com/v1", + auth_token="...", + ) as tester: + tester.authorize(scope=["all"]) + injection_points = await tester.discover_injection_points() + results = await tester.run_tests(injection_points=injection_points) + report = tester.generate_report(format="html") +""" + +__version__ = "1.0.0" +__author__ = "AI LLM Red Team Handbook" + +from .core import ( + AttackCategory, + AttackConfig, + AttackPayload, + AttackPattern, + DetectionMethod, + DetectionResult, + InjectionPoint, + InjectionPointType, + InjectionTester, + Severity, + TargetConfig, + TestContext, + TestResult, + TestStatus, + TestSuite, +) +from .patterns.registry import registry, register_pattern, get_pattern, list_patterns + +__all__ = [ + # Core + "InjectionTester", + "AttackConfig", + "TargetConfig", + "TestResult", + "TestSuite", + "InjectionPoint", + "AttackPattern", + "AttackPayload", + # Enums + "AttackCategory", + "InjectionPointType", + "DetectionMethod", + "Severity", + "TestStatus", + # Detection + "DetectionResult", + "TestContext", + # Registry + "registry", + "register_pattern", + "get_pattern", + "list_patterns", +] diff --git a/tools/prompt_injection_tester/core/__init__.py b/tools/prompt_injection_tester/core/__init__.py new file mode 100644 index 0000000..1b06b5e --- /dev/null +++ b/tools/prompt_injection_tester/core/__init__.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""Core modules for the prompt injection tester.""" + +from .models import ( + AttackCategory, + AttackConfig, + AttackPayload, + AttackPattern, + DetectionMethod, + DetectionResult, + InjectionPoint, + InjectionPointType, + Severity, + TargetConfig, + TestContext, + TestResult, + TestStatus, + TestSuite, +) +from .tester import InjectionTester + +__all__ = [ + "AttackCategory", + "AttackConfig", + "AttackPayload", + "AttackPattern", + "DetectionMethod", + "DetectionResult", + "InjectionPoint", + 
"InjectionPointType", + "InjectionTester", + "Severity", + "TargetConfig", + "TestContext", + "TestResult", + "TestStatus", + "TestSuite", +] diff --git a/tools/prompt_injection_tester/detection/__init__.py b/tools/prompt_injection_tester/detection/__init__.py new file mode 100644 index 0000000..c4c0806 --- /dev/null +++ b/tools/prompt_injection_tester/detection/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +"""Detection and validation modules for identifying successful injections.""" + +from .base import BaseDetector, DetectorRegistry +from .system_prompt import SystemPromptLeakDetector +from .behavior_change import BehaviorChangeDetector +from .tool_misuse import ToolMisuseDetector +from .scoring import ConfidenceScorer + +__all__ = [ + "BaseDetector", + "DetectorRegistry", + "SystemPromptLeakDetector", + "BehaviorChangeDetector", + "ToolMisuseDetector", + "ConfidenceScorer", +] diff --git a/tools/prompt_injection_tester/patterns/__init__.py b/tools/prompt_injection_tester/patterns/__init__.py new file mode 100644 index 0000000..f4f774f --- /dev/null +++ b/tools/prompt_injection_tester/patterns/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +"""Attack pattern modules.""" + +from .base import BaseAttackPattern, MultiTurnAttackPattern, CompositeAttackPattern +from .registry import registry, register_pattern, get_pattern, list_patterns + +__all__ = [ + "BaseAttackPattern", + "MultiTurnAttackPattern", + "CompositeAttackPattern", + "registry", + "register_pattern", + "get_pattern", + "list_patterns", +] diff --git a/tools/prompt_injection_tester/patterns/advanced/__init__.py b/tools/prompt_injection_tester/patterns/advanced/__init__.py new file mode 100644 index 0000000..7acb572 --- /dev/null +++ b/tools/prompt_injection_tester/patterns/advanced/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +""" +Advanced Attack Patterns + +Sophisticated techniques including multi-turn attacks, payload fragmentation, +encoding/obfuscation, and language-based exploits. +Based on Chapter 14 advanced techniques sections. +""" + +from .multi_turn import ( + GradualEscalationPattern, + ContextBuildupPattern, + TrustEstablishmentPattern, +) +from .fragmentation import ( + PayloadFragmentationPattern, + TokenSmugglingPattern, + ChunkedInjectionPattern, +) +from .encoding import ( + Base64EncodingPattern, + UnicodeObfuscationPattern, + LanguageSwitchingPattern, + LeetSpeakPattern, +) + +__all__ = [ + # Multi-turn + "GradualEscalationPattern", + "ContextBuildupPattern", + "TrustEstablishmentPattern", + # Fragmentation + "PayloadFragmentationPattern", + "TokenSmugglingPattern", + "ChunkedInjectionPattern", + # Encoding/Obfuscation + "Base64EncodingPattern", + "UnicodeObfuscationPattern", + "LanguageSwitchingPattern", + "LeetSpeakPattern", +] diff --git a/tools/prompt_injection_tester/patterns/direct/__init__.py b/tools/prompt_injection_tester/patterns/direct/__init__.py new file mode 100644 index 0000000..a9392f9 --- /dev/null +++ b/tools/prompt_injection_tester/patterns/direct/__init__.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +""" +Direct Injection Attack Patterns + +These patterns target the primary LLM input channel where attacker-controlled +content is directly processed by the model. Based on Chapter 14 techniques. 
+""" + +from .instruction_override import ( + InstructionOverridePattern, + SystemPromptOverridePattern, + TaskHijackingPattern, +) +from .role_manipulation import ( + RoleAuthorityPattern, + PersonaShiftPattern, + DeveloperModePattern, +) +from .delimiter_confusion import ( + DelimiterEscapePattern, + XMLInjectionPattern, + MarkdownInjectionPattern, +) + +__all__ = [ + # Instruction Override + "InstructionOverridePattern", + "SystemPromptOverridePattern", + "TaskHijackingPattern", + # Role Manipulation + "RoleAuthorityPattern", + "PersonaShiftPattern", + "DeveloperModePattern", + # Delimiter Confusion + "DelimiterEscapePattern", + "XMLInjectionPattern", + "MarkdownInjectionPattern", +] diff --git a/tools/prompt_injection_tester/patterns/indirect/__init__.py b/tools/prompt_injection_tester/patterns/indirect/__init__.py new file mode 100644 index 0000000..94b7a75 --- /dev/null +++ b/tools/prompt_injection_tester/patterns/indirect/__init__.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +""" +Indirect Injection Attack Patterns + +These patterns target secondary data sources that the LLM processes, +such as retrieved documents, web pages, emails, and database content. +Based on Chapter 14 indirect injection techniques. +""" + +from .document_poisoning import ( + RAGPoisoningPattern, + DocumentMetadataInjectionPattern, + HiddenTextInjectionPattern, +) +from .web_injection import ( + WebPageInjectionPattern, + SEOPoisoningPattern, + CommentInjectionPattern, +) +from .email_injection import ( + EmailBodyInjectionPattern, + EmailHeaderInjectionPattern, + AttachmentInjectionPattern, +) + +__all__ = [ + # Document Poisoning + "RAGPoisoningPattern", + "DocumentMetadataInjectionPattern", + "HiddenTextInjectionPattern", + # Web Injection + "WebPageInjectionPattern", + "SEOPoisoningPattern", + "CommentInjectionPattern", + # Email Injection + "EmailBodyInjectionPattern", + "EmailHeaderInjectionPattern", + "AttachmentInjectionPattern", +] diff --git a/tools/prompt_injection_tester/pit/__init__.py b/tools/prompt_injection_tester/pit/__init__.py new file mode 100644 index 0000000..aff76ad --- /dev/null +++ b/tools/prompt_injection_tester/pit/__init__.py @@ -0,0 +1,10 @@ +""" +Prompt Injection Tester (PIT) - Modern CLI Interface + +A premium terminal experience for LLM security assessment. 
+""" + +__version__ = "1.0.0" +__all__ = ["app"] + +from pit.app import app diff --git a/tools/prompt_injection_tester/pit/cli.py b/tools/prompt_injection_tester/pit/cli.py new file mode 100644 index 0000000..f882752 --- /dev/null +++ b/tools/prompt_injection_tester/pit/cli.py @@ -0,0 +1,381 @@ +"""Main CLI application using Typer.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import List, Optional + +import typer +from rich.console import Console + +from pit import __version__ +from pit.config import Config, load_config +from pit.errors import handle_error + +app = typer.Typer( + name="pit", + help="Prompt Injection Tester - Automated LLM Security Testing", + add_completion=False, +) + +console = Console() + + +@app.command() +def scan( + target_url: str = typer.Argument( + ..., + help="Target API endpoint URL", + ), + token: Optional[str] = typer.Option( + None, + "--token", + "-t", + help="Authentication token (or use $PIT_TOKEN)", + envvar="PIT_TOKEN", + ), + config_file: Optional[Path] = typer.Option( + None, + "--config", + "-c", + help="Configuration file (YAML)", + exists=True, + ), + auto: bool = typer.Option( + False, + "--auto", + "-a", + help="Run all phases automatically (non-interactive)", + ), + api_type: str = typer.Option( + "openai", + "--api-type", + help="API type (openai, anthropic, custom)", + ), + timeout: int = typer.Option( + 30, + "--timeout", + help="Request timeout in seconds", + ), + categories: Optional[List[str]] = typer.Option( + None, + "--categories", + help="Attack categories (direct, indirect, advanced)", + ), + patterns: Optional[List[str]] = typer.Option( + None, + "--patterns", + help="Specific pattern IDs to test", + ), + max_concurrent: int = typer.Option( + 5, + "--max-concurrent", + help="Maximum concurrent requests", + ), + rate_limit: float = typer.Option( + 1.0, + "--rate-limit", + help="Requests per second", + ), + output: Optional[Path] = typer.Option( + None, + "--output", + "-o", + help="Output file path (auto-generated if not specified)", + ), + output_format: str = typer.Option( + "json", + "--format", + "-f", + help="Report format (json, yaml, html)", + ), + include_cvss: bool = typer.Option( + True, + "--include-cvss/--no-cvss", + help="Include CVSS scores in report", + ), + include_payloads: bool = typer.Option( + False, + "--include-payloads/--no-payloads", + help="Include attack payloads in report", + ), + authorize: bool = typer.Option( + False, + "--authorize", + help="Confirm authorization (skip interactive prompt)", + ), + verbose: bool = typer.Option( + False, + "--verbose", + "-v", + help="Enable verbose output", + ), + quiet: bool = typer.Option( + False, + "--quiet", + "-q", + help="Suppress all output except errors", + ), +): + """ + Scan a target API for prompt injection vulnerabilities. 
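+
+    Exit codes: 0 = no findings, 2 = vulnerabilities found, 130 = interrupted,
+    1 = error (see _run_scan and handle_error).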
+ + Example: + pit scan https://api.example.com/v1/chat --auto --token $TOKEN + """ + import asyncio + + try: + # Load configuration + config = load_config( + config_path=config_file, + target_url=target_url, + token=token, + api_type=api_type, + timeout=timeout, + categories=categories, + patterns=patterns, + max_concurrent=max_concurrent, + rate_limit=rate_limit, + output_format=output_format, + output=output, + include_cvss=include_cvss, + include_payloads=include_payloads, + authorize=authorize, + ) + + # Verify authorization + if not config.authorization.confirmed and not auto: + if not _prompt_authorization(target_url): + console.print("[yellow]Scan cancelled by user[/yellow]") + raise typer.Exit(0) + config.authorization.confirmed = True + + # Print banner + if not quiet: + _print_banner(target_url) + + # Run scan + exit_code = asyncio.run(_run_scan(config, verbose, quiet)) + raise typer.Exit(exit_code) + + except KeyboardInterrupt: + console.print("\n[yellow]⚠ Interrupted by user[/yellow]") + raise typer.Exit(130) + + except Exception as e: + exit_code = handle_error(e, verbose) + raise typer.Exit(exit_code) + + +@app.command() +def list_command( + type: str = typer.Argument( + "patterns", + help="What to list (patterns, categories)", + ), +): + """ + List available attack patterns or categories. + + Example: + pit list patterns + pit list categories + """ + from patterns.registry import registry + from core.models import AttackCategory + + if type == "patterns": + _list_patterns(registry) + elif type == "categories": + _list_categories() + else: + console.print(f"[red]Unknown type: {type}[/red]") + console.print("Available types: patterns, categories") + raise typer.Exit(1) + + +app.command(name="list")(list_command) + + +@app.command() +def auth( + target_url: str = typer.Argument( + ..., + help="Target API endpoint URL", + ), +): + """ + Verify authorization to test a target. + + Example: + pit auth https://api.example.com + """ + if _prompt_authorization(target_url): + console.print("[green]✓ Authorization confirmed[/green]") + raise typer.Exit(0) + else: + console.print("[yellow]Authorization not confirmed[/yellow]") + raise typer.Exit(1) + + +@app.callback(invoke_without_command=True) +def main( + ctx: typer.Context, + version: bool = typer.Option( + False, + "--version", + help="Show version and exit", + ), +): + """PIT - Prompt Injection Tester""" + if version: + console.print(f"[cyan]pit[/cyan] version {__version__}") + raise typer.Exit(0) + + if ctx.invoked_subcommand is None: + console.print(ctx.get_help()) + + +# Helper functions + + +async def _run_scan(config: Config, verbose: bool, quiet: bool) -> int: + """ + Run the scan pipeline. 
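+
+    Builds the four sequential phases (discovery, attack, verification,
+    reporting), runs them to completion, then maps the outcome to an exit code.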
+ + Args: + config: Configuration + verbose: Verbose output + quiet: Quiet mode + + Returns: + Exit code (0 for success, non-zero for error) + """ + from pit.orchestrator.pipeline import Pipeline, PipelineContext + from pit.orchestrator.phases import ( + AttackPhase, + DiscoveryPhase, + ReportingPhase, + VerificationPhase, + ) + + # Create pipeline + pipeline = Pipeline( + phases=[ + DiscoveryPhase(), + AttackPhase(), + VerificationPhase(), + ReportingPhase(), + ] + ) + + # Create context + context = PipelineContext( + target_url=config.target.url, + config=config, + ) + + # Run pipeline + context = await pipeline.run(context) + + # Determine exit code + if context.interrupted: + return 130 + elif context.report: + # Check if vulnerabilities found + successful = context.report.get("summary", {}).get("successful_attacks", 0) + return 2 if successful > 0 else 0 + else: + return 1 + + +def _prompt_authorization(target_url: str) -> bool: + """ + Prompt user for authorization confirmation. + + Args: + target_url: Target URL + + Returns: + True if authorized, False otherwise + """ + from rich.panel import Panel + + panel = Panel( + f"[yellow]Target:[/yellow] {target_url}\n\n" + "[yellow]⚠ You must have explicit authorization to test this system.[/yellow]\n" + "[yellow]Unauthorized testing may be illegal in your jurisdiction.[/yellow]\n\n" + "Do you have authorization to test this target?", + title="[bold red]Authorization Required[/bold red]", + border_style="red", + ) + + console.print() + console.print(panel) + console.print() + + response = typer.prompt("Confirm [y/N]", default="n") + return response.lower() == "y" + + +def _print_banner(target_url: str) -> None: + """Print application banner.""" + from rich.panel import Panel + + banner = f"""[bold cyan]PIT - Prompt Injection Tester[/bold cyan] v{__version__} + +[cyan]Target:[/cyan] {target_url} +[cyan]Authorization:[/cyan] ✓ Confirmed +""" + + console.print() + console.print(Panel(banner, border_style="cyan", expand=False)) + + +def _list_patterns(registry) -> None: + """List available attack patterns.""" + from core.models import AttackCategory + from rich.table import Table + + if len(registry) == 0: + registry.load_builtin_patterns() + + table = Table(title="Available Attack Patterns", show_header=True) + table.add_column("Category", style="cyan") + table.add_column("Pattern ID", style="yellow") + table.add_column("Description", style="white") + + for category in AttackCategory: + patterns = registry.list_by_category(category) + for i, pid in enumerate(patterns): + pattern_class = registry.get(pid) + if pattern_class: + table.add_row( + category.value if i == 0 else "", + pid, + getattr(pattern_class, "name", pid), + ) + + console.print(table) + + +def _list_categories() -> None: + """List attack categories.""" + from core.models import AttackCategory + from rich.table import Table + + table = Table(title="Attack Categories", show_header=True) + table.add_column("Category", style="cyan") + table.add_column("Description", style="white") + + for category in AttackCategory: + table.add_row(category.value, f"{category.value} injection attacks") + + console.print(table) + + +if __name__ == "__main__": + app() diff --git a/tools/prompt_injection_tester/pit/commands/__init__.py b/tools/prompt_injection_tester/pit/commands/__init__.py new file mode 100644 index 0000000..6357653 --- /dev/null +++ b/tools/prompt_injection_tester/pit/commands/__init__.py @@ -0,0 +1,7 @@ +""" +CLI command modules. 
+""" + +from pit.commands import scan + +__all__ = ["scan"] diff --git a/tools/prompt_injection_tester/pit/config/__init__.py b/tools/prompt_injection_tester/pit/config/__init__.py new file mode 100644 index 0000000..ac11440 --- /dev/null +++ b/tools/prompt_injection_tester/pit/config/__init__.py @@ -0,0 +1,20 @@ +"""Configuration module.""" + +from .loader import load_config, load_config_file +from .schema import ( + AttackConfig, + AuthorizationConfig, + Config, + ReportingConfig, + TargetConfig, +) + +__all__ = [ + "Config", + "TargetConfig", + "AttackConfig", + "ReportingConfig", + "AuthorizationConfig", + "load_config", + "load_config_file", +] diff --git a/tools/prompt_injection_tester/pit/config/loader.py b/tools/prompt_injection_tester/pit/config/loader.py new file mode 100644 index 0000000..9a379bf --- /dev/null +++ b/tools/prompt_injection_tester/pit/config/loader.py @@ -0,0 +1,82 @@ +"""Configuration file loader.""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional + +import yaml + +from .schema import Config + + +def load_config_file(config_path: Path) -> Config: + """ + Load configuration from a YAML file. + + Args: + config_path: Path to YAML configuration file + + Returns: + Loaded configuration + + Raises: + FileNotFoundError: If config file doesn't exist + ValueError: If config file is invalid + """ + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + with open(config_path) as f: + data = yaml.safe_load(f) + + # Expand environment variables in token + if "target" in data and "token" in data["target"]: + token = data["target"]["token"] + if isinstance(token, str) and token.startswith("${") and token.endswith("}"): + env_var = token[2:-1] + data["target"]["token"] = os.getenv(env_var) + + try: + return Config(**data) + except Exception as e: + raise ValueError(f"Invalid configuration file: {e}") from e + + +def load_config( + config_path: Optional[Path], + target_url: Optional[str] = None, + **cli_overrides, +) -> Config: + """ + Load configuration from file or CLI arguments. 
+ + Args: + config_path: Path to configuration file (optional) + target_url: Target URL from CLI (optional) + **cli_overrides: CLI argument overrides + + Returns: + Merged configuration + + Raises: + ValueError: If neither config_path nor target_url is provided + """ + if config_path: + config = load_config_file(config_path) + + # Apply CLI overrides + if target_url: + config.target.url = target_url + if "token" in cli_overrides and cli_overrides["token"]: + config.target.token = cli_overrides["token"] + + return config + + elif target_url: + # Build config from CLI args + return Config.from_cli_args(target_url=target_url, **cli_overrides) + + else: + raise ValueError("Either --config or target URL must be provided") diff --git a/tools/prompt_injection_tester/pit/config/schema.py b/tools/prompt_injection_tester/pit/config/schema.py new file mode 100644 index 0000000..e855870 --- /dev/null +++ b/tools/prompt_injection_tester/pit/config/schema.py @@ -0,0 +1,148 @@ +"""Configuration schema using Pydantic.""" + +from __future__ import annotations + +from pathlib import Path +from typing import List, Optional + +from pydantic import BaseModel, Field, field_validator + + +class TargetConfig(BaseModel): + """Target API configuration.""" + + url: str = Field(..., description="Target API endpoint URL") + token: Optional[str] = Field(None, description="Authentication token") + api_type: str = Field(default="openai", description="API type (openai, anthropic, custom)") + timeout: int = Field(default=30, ge=1, le=300, description="Request timeout in seconds") + + @field_validator("url") + @classmethod + def validate_url(cls, v: str) -> str: + """Ensure URL is properly formatted.""" + if not v.startswith(("http://", "https://")): + raise ValueError("URL must start with http:// or https://") + return v + + +class AttackConfig(BaseModel): + """Attack execution configuration.""" + + categories: List[str] = Field( + default=["direct", "indirect"], + description="Attack categories to test", + ) + patterns: Optional[List[str]] = Field( + None, + description="Specific pattern IDs to test (overrides categories)", + ) + exclude_patterns: List[str] = Field( + default_factory=list, + description="Pattern IDs to exclude", + ) + max_concurrent: int = Field( + default=5, + ge=1, + le=20, + description="Maximum concurrent requests", + ) + rate_limit: float = Field( + default=1.0, + ge=0.1, + le=10.0, + description="Requests per second", + ) + + +class ReportingConfig(BaseModel): + """Report generation configuration.""" + + format: str = Field( + default="json", + description="Report format (json, yaml, html)", + ) + output: Optional[Path] = Field( + None, + description="Output file path (auto-generated if not specified)", + ) + include_cvss: bool = Field( + default=True, + description="Include CVSS scores in report", + ) + include_payloads: bool = Field( + default=False, + description="Include attack payloads in report (may be sensitive)", + ) + + @field_validator("format") + @classmethod + def validate_format(cls, v: str) -> str: + """Ensure format is supported.""" + allowed = ["json", "yaml", "html"] + if v not in allowed: + raise ValueError(f"Format must be one of {allowed}") + return v + + +class AuthorizationConfig(BaseModel): + """Authorization configuration.""" + + scope: List[str] = Field( + default=["all"], + description="Authorization scope", + ) + confirmed: bool = Field( + default=False, + description="Skip interactive authorization prompt", + ) + + +class Config(BaseModel): + """Complete PIT 
configuration.""" + + target: TargetConfig + attack: AttackConfig = Field(default_factory=AttackConfig) + reporting: ReportingConfig = Field(default_factory=ReportingConfig) + authorization: AuthorizationConfig = Field(default_factory=AuthorizationConfig) + + @classmethod + def from_cli_args( + cls, + target_url: str, + token: Optional[str] = None, + api_type: str = "openai", + timeout: int = 30, + categories: Optional[List[str]] = None, + patterns: Optional[List[str]] = None, + max_concurrent: int = 5, + rate_limit: float = 1.0, + output_format: str = "json", + output: Optional[Path] = None, + include_cvss: bool = True, + include_payloads: bool = False, + authorize: bool = False, + ) -> Config: + """Create config from CLI arguments.""" + return cls( + target=TargetConfig( + url=target_url, + token=token, + api_type=api_type, + timeout=timeout, + ), + attack=AttackConfig( + categories=categories or ["direct", "indirect"], + patterns=patterns, + max_concurrent=max_concurrent, + rate_limit=rate_limit, + ), + reporting=ReportingConfig( + format=output_format, + output=output, + include_cvss=include_cvss, + include_payloads=include_payloads, + ), + authorization=AuthorizationConfig( + confirmed=authorize, + ), + ) diff --git a/tools/prompt_injection_tester/pit/errors/__init__.py b/tools/prompt_injection_tester/pit/errors/__init__.py new file mode 100644 index 0000000..47788c9 --- /dev/null +++ b/tools/prompt_injection_tester/pit/errors/__init__.py @@ -0,0 +1,47 @@ +"""Error handling module.""" + +from .exceptions import ( + AttackError, + AuthenticationError, + ConfigError, + ConfigNotFoundError, + ConfigValidationError, + DetectionFailedError, + DiscoveryError, + ExecutionError, + FileWriteError, + FormatError, + NoEndpointsFoundError, + PatternLoadError, + PitError, + RateLimitError, + ReportingError, + ScanFailedError, + TargetError, + TargetUnreachableError, + VerificationError, +) +from .handlers import handle_error + +__all__ = [ + "PitError", + "ConfigError", + "ConfigNotFoundError", + "ConfigValidationError", + "TargetError", + "TargetUnreachableError", + "AuthenticationError", + "RateLimitError", + "DiscoveryError", + "NoEndpointsFoundError", + "ScanFailedError", + "AttackError", + "PatternLoadError", + "ExecutionError", + "VerificationError", + "DetectionFailedError", + "ReportingError", + "FormatError", + "FileWriteError", + "handle_error", +] diff --git a/tools/prompt_injection_tester/pit/errors/exceptions.py b/tools/prompt_injection_tester/pit/errors/exceptions.py new file mode 100644 index 0000000..1bf68f8 --- /dev/null +++ b/tools/prompt_injection_tester/pit/errors/exceptions.py @@ -0,0 +1,131 @@ +"""Custom exception hierarchy for PIT.""" + +from __future__ import annotations + + +class PitError(Exception): + """Base exception for all PIT errors.""" + + pass + + +# Configuration Errors +class ConfigError(PitError): + """Base configuration error.""" + + pass + + +class ConfigNotFoundError(ConfigError): + """Configuration file not found.""" + + pass + + +class ConfigValidationError(ConfigError): + """Configuration validation failed.""" + + pass + + +# Target Errors +class TargetError(PitError): + """Base target error.""" + + pass + + +class TargetUnreachableError(TargetError): + """Target is unreachable.""" + + def __init__(self, url: str, reason: str = ""): + self.url = url + self.reason = reason + super().__init__(f"Target unreachable: {url}") + + +class AuthenticationError(TargetError): + """Authentication failed.""" + + def __init__(self, status_code: int, message: str = ""): 
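+        # status_code is the HTTP status returned by the target, e.g. 401 or 403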
+ self.status_code = status_code + self.message = message + super().__init__(f"Authentication failed: {status_code}") + + +class RateLimitError(TargetError): + """Rate limit exceeded.""" + + def __init__(self, retry_after: int = 60): + self.retry_after = retry_after + super().__init__(f"Rate limit exceeded, retry after {retry_after}s") + + +# Discovery Errors +class DiscoveryError(PitError): + """Base discovery error.""" + + pass + + +class NoEndpointsFoundError(DiscoveryError): + """No endpoints found during discovery.""" + + pass + + +class ScanFailedError(DiscoveryError): + """Discovery scan failed.""" + + pass + + +# Attack Errors +class AttackError(PitError): + """Base attack error.""" + + pass + + +class PatternLoadError(AttackError): + """Failed to load attack pattern.""" + + pass + + +class ExecutionError(AttackError): + """Attack execution failed.""" + + pass + + +# Verification Errors +class VerificationError(PitError): + """Base verification error.""" + + pass + + +class DetectionFailedError(VerificationError): + """Detection/verification failed.""" + + pass + + +# Reporting Errors +class ReportingError(PitError): + """Base reporting error.""" + + pass + + +class FormatError(ReportingError): + """Invalid report format.""" + + pass + + +class FileWriteError(ReportingError): + """Failed to write report file.""" + + pass diff --git a/tools/prompt_injection_tester/pit/errors/handlers.py b/tools/prompt_injection_tester/pit/errors/handlers.py new file mode 100644 index 0000000..becd070 --- /dev/null +++ b/tools/prompt_injection_tester/pit/errors/handlers.py @@ -0,0 +1,84 @@ +"""User-friendly error handlers.""" + +from __future__ import annotations + +from rich.console import Console + +from .exceptions import ( + AuthenticationError, + ConfigError, + NoEndpointsFoundError, + PitError, + RateLimitError, + TargetUnreachableError, +) + +console = Console() + + +def handle_error(error: Exception, verbose: bool = False) -> int: + """ + Convert exceptions to user-friendly messages. 
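+
+    Known PitError subclasses get tailored hints; anything unrecognized falls
+    through to a generic message, with a full traceback when verbose.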
+
+    Args:
+        error: The exception to handle
+        verbose: Show detailed error information
+
+    Returns:
+        Exit code (0 for success, non-zero for error)
+    """
+    if isinstance(error, TargetUnreachableError):
+        console.print("[red]✗ Target Unreachable[/red]")
+        console.print(f"  ├─ Could not connect to: {error.url}")
+        if error.reason:
+            console.print(f"  ├─ Reason: {error.reason}")
+        console.print("  ├─ Suggestion: Check the URL and network connection")
+        console.print("  └─ Or use: [cyan]pit scan --skip-discovery --injection-points <file>[/cyan]")
+        return 1
+
+    elif isinstance(error, AuthenticationError):
+        console.print("[red]✗ Authentication Failed[/red]")
+        console.print(f"  ├─ Status code: {error.status_code}")
+        if error.message:
+            console.print(f"  ├─ Message: {error.message}")
+        console.print("  ├─ Suggestion: Verify your API token with [cyan]--token[/cyan]")
+        console.print("  └─ Or run: [cyan]pit auth <target_url>[/cyan]")
+        return 1
+
+    elif isinstance(error, RateLimitError):
+        console.print("[yellow]⚠ Rate Limited[/yellow]")
+        console.print(f"  ├─ Retry after: {error.retry_after} seconds")
+        console.print("  └─ Suggestion: Reduce [cyan]--rate-limit[/cyan] or [cyan]--max-concurrent[/cyan]")
+        return 1
+
+    elif isinstance(error, NoEndpointsFoundError):
+        console.print("[yellow]⚠ No Endpoints Found[/yellow]")
+        console.print("  ├─ Discovery phase found no injection points")
+        console.print("  ├─ Suggestion 1: Specify endpoints manually with [cyan]--injection-points <file>[/cyan]")
+        console.print("  └─ Suggestion 2: Check if the target URL is correct")
+        return 1
+
+    elif isinstance(error, ConfigError):
+        console.print("[red]✗ Configuration Error[/red]")
+        console.print(f"  ├─ {str(error)}")
+        console.print("  └─ Suggestion: Check your configuration file or CLI arguments")
+        return 1
+
+    elif isinstance(error, PitError):
+        console.print(f"[red]✗ Error: {error.__class__.__name__}[/red]")
+        console.print(f"  └─ {str(error)}")
+        return 1
+
+    elif isinstance(error, KeyboardInterrupt):
+        console.print("\n[yellow]⚠ Scan Interrupted[/yellow]")
+        console.print("  └─ Use [cyan]Ctrl+C[/cyan] again to force quit")
+        return 130
+
+    else:
+        console.print(f"[red]✗ Unexpected Error: {error.__class__.__name__}[/red]")
+        console.print(f"  ├─ {str(error)}")
+        if verbose:
+            console.print_exception()
+        else:
+            console.print("  └─ Run with [cyan]--verbose[/cyan] for details")
+        return 1
diff --git a/tools/prompt_injection_tester/pit/orchestrator/__init__.py b/tools/prompt_injection_tester/pit/orchestrator/__init__.py
new file mode 100644
index 0000000..4351b92
--- /dev/null
+++ b/tools/prompt_injection_tester/pit/orchestrator/__init__.py
@@ -0,0 +1,8 @@
+"""
+Orchestrator layer for coordinating CLI and core framework.
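+
+Phases run strictly in sequence; concurrency is confined to HTTP requests
+inside a single phase (see SPECIFICATION.md, Section 6).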
+""" + +from pit.orchestrator.workflow import WorkflowOrchestrator +from pit.orchestrator.discovery import AutoDiscovery + +__all__ = ["WorkflowOrchestrator", "AutoDiscovery"] diff --git a/tools/prompt_injection_tester/pit/orchestrator/phases.py b/tools/prompt_injection_tester/pit/orchestrator/phases.py new file mode 100644 index 0000000..f962b6f --- /dev/null +++ b/tools/prompt_injection_tester/pit/orchestrator/phases.py @@ -0,0 +1,495 @@ +"""Phase definitions for the sequential pipeline.""" + +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from rich.console import Console + +if TYPE_CHECKING: + from pit.orchestrator.pipeline import PipelineContext + +console = Console() + + +class PhaseStatus(Enum): + """Phase execution status.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class PhaseResult: + """Result from a phase execution.""" + + status: PhaseStatus + message: Optional[str] = None + error: Optional[str] = None + data: Optional[Dict[str, Any]] = None + + +class Phase(ABC): + """ + Abstract base class for pipeline phases. + + Each phase MUST complete fully before returning. + """ + + name: str = "Base Phase" + + @abstractmethod + async def execute(self, context: PipelineContext) -> PhaseResult: + """ + Execute the phase. + + MUST complete before returning (no background tasks). + + Args: + context: Pipeline context (shared state) + + Returns: + Phase result with status and optional data + + Raises: + Exception: If phase fails unexpectedly + """ + pass + + +class DiscoveryPhase(Phase): + """ + Phase 1: Discovery + + Scans the target for injection points. + """ + + name = "Discovery" + + async def execute(self, context: PipelineContext) -> PhaseResult: + """ + Execute discovery phase. + + Args: + context: Pipeline context + + Returns: + Phase result with discovered injection points + """ + from pit.ui.progress import create_spinner + + console.print("[cyan]Discovering injection points...[/cyan]") + + try: + with create_spinner(): + task = asyncio.create_task( + self._discover_injection_points(context) + ) + # WAIT for discovery to complete + injection_points = await task + + context.injection_points = injection_points + + if not injection_points: + return PhaseResult( + status=PhaseStatus.FAILED, + error="No injection points found", + ) + + message = ( + f"Found {len(injection_points)} injection point(s)" + ) + + return PhaseResult( + status=PhaseStatus.COMPLETED, + message=message, + data={"injection_points": injection_points}, + ) + + except Exception as e: + return PhaseResult( + status=PhaseStatus.FAILED, + error=str(e), + ) + + async def _discover_injection_points( + self, context: PipelineContext + ) -> List[Any]: + """ + Run discovery logic. + + TODO: Implement actual discovery logic using discovery module. 
+
+        Args:
+            context: Pipeline context
+
+        Returns:
+            List of discovered injection points
+        """
+        # Placeholder: Replace with actual discovery implementation
+        from core.models import InjectionPoint, InjectionPointType
+
+        # Simulate discovery
+        await asyncio.sleep(1)
+
+        # Mock injection points for now
+        return [
+            InjectionPoint(
+                id="param_prompt",
+                type=InjectionPointType.PARAMETER,
+                name="prompt",
+                location="body",
+            )
+        ]
+
+
+class AttackPhase(Phase):
+    """
+    Phase 2: Attack Execution
+
+    Executes attack patterns against discovered injection points.
+    """
+
+    name = "Attack Execution"
+
+    async def execute(self, context: PipelineContext) -> PhaseResult:
+        """
+        Execute attack phase.
+
+        Args:
+            context: Pipeline context
+
+        Returns:
+            Phase result with test results
+        """
+        from pit.ui.progress import create_progress_bar
+
+        injection_points = context.injection_points
+
+        if not injection_points:
+            return PhaseResult(
+                status=PhaseStatus.FAILED,
+                error="No injection points available",
+            )
+
+        try:
+            # Load attack patterns
+            patterns = await self._load_patterns(context)
+
+            console.print(f"[cyan]Loaded {len(patterns)} attack pattern(s)[/cyan]")
+
+            # Execute attacks with progress bar
+            results = []
+            with create_progress_bar() as progress:
+                task = progress.add_task(
+                    "Running attacks",
+                    total=len(patterns),
+                )
+
+                for pattern in patterns:
+                    # Run each pattern against the first discovered injection
+                    # point, one at a time (concurrency stays inside the attack)
+                    result = await self._execute_attack(
+                        pattern, injection_points[0], context
+                    )
+                    results.append(result)
+                    progress.update(task, advance=1)
+                    # Respect rate limiting (guard against a zero rate limit)
+                    if context.config.attack.rate_limit > 0:
+                        await asyncio.sleep(1.0 / context.config.attack.rate_limit)
+
+            context.test_results = results
+
+            message = f"Completed {len(results)} attack(s)"
+
+            return PhaseResult(
+                status=PhaseStatus.COMPLETED,
+                message=message,
+                data={"test_results": results},
+            )
+
+        except Exception as e:
+            return PhaseResult(
+                status=PhaseStatus.FAILED,
+                error=str(e),
+            )
+
+    async def _load_patterns(self, context: PipelineContext) -> List[Any]:
+        """
+        Load attack patterns from registry.
+
+        Args:
+            context: Pipeline context
+
+        Returns:
+            List of attack patterns
+        """
+        from patterns.registry import registry
+
+        # Load patterns based on config
+        categories = context.config.attack.categories
+
+        all_patterns = []
+        for category in categories:
+            patterns = registry.list_by_category(category)
+            all_patterns.extend(patterns)
+
+        # Return pattern IDs for now
+        # TODO: Return actual pattern instances
+        return all_patterns[:10]  # Limit for demo
+
+    async def _execute_attack(
+        self, pattern: Any, injection_point: Any, context: PipelineContext
+    ) -> Any:
+        """
+        Execute a single attack pattern.
+
+        Args:
+            pattern: Attack pattern
+            injection_point: Target injection point
+            context: Pipeline context
+
+        Returns:
+            Test result
+        """
+        from core.models import TestResult, TestStatus
+
+        # Simulate attack execution
+        await asyncio.sleep(0.1)
+
+        # Mock result
+        return TestResult(
+            pattern_id=str(pattern),
+            injection_point_id=injection_point.id,
+            status=TestStatus.SUCCESS,
+            payload="test_payload",
+            response=None,
+        )
+
+
+class VerificationPhase(Phase):
+    """
+    Phase 3: Verification
+
+    Analyzes attack responses to verify success.
+    """
+
+    name = "Verification"
+
+    async def execute(self, context: PipelineContext) -> PhaseResult:
+        """
+        Execute verification phase.
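+
+        Runs response analysis over every raw test result and attaches a
+        per-result confidence score (see ``_verify_results``).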
+ + Args: + context: Pipeline context + + Returns: + Phase result with verified results + """ + from pit.ui.progress import create_spinner + + test_results = context.test_results + + if not test_results: + return PhaseResult( + status=PhaseStatus.FAILED, + error="No test results to verify", + ) + + console.print("[cyan]Analyzing responses...[/cyan]") + + try: + with create_spinner(): + # WAIT for verification to complete + verified = await self._verify_results(test_results, context) + + context.verified_results = verified + + successful = sum(1 for r in verified if r.get("status") == "success") + message = f"Verified {len(verified)} result(s), {successful} successful" + + return PhaseResult( + status=PhaseStatus.COMPLETED, + message=message, + data={"verified_results": verified}, + ) + + except Exception as e: + return PhaseResult( + status=PhaseStatus.FAILED, + error=str(e), + ) + + async def _verify_results( + self, test_results: List[Any], context: PipelineContext + ) -> List[Dict[str, Any]]: + """ + Verify test results. + + Args: + test_results: List of test results + context: Pipeline context + + Returns: + List of verified results with confidence scores + """ + # Simulate verification + await asyncio.sleep(1) + + # Mock verified results + verified = [] + for result in test_results: + verified.append({ + "pattern_id": result.pattern_id, + "status": "success" if "test" in result.pattern_id else "failed", + "severity": "medium", + "confidence": 0.85, + }) + + return verified + + +class ReportingPhase(Phase): + """ + Phase 4: Reporting + + Generates and saves the final report. + """ + + name = "Report Generation" + + async def execute(self, context: PipelineContext) -> PhaseResult: + """ + Execute reporting phase. + + Args: + context: Pipeline context + + Returns: + Phase result with report path + """ + verified_results = context.verified_results + + if not verified_results: + return PhaseResult( + status=PhaseStatus.FAILED, + error="No results to report", + ) + + try: + # Generate report + report = await self._generate_report(verified_results, context) + context.report = report + + # Save to file + report_path = await self._save_report(report, context) + context.report_path = report_path + + # Display summary + self._display_summary(report, report_path) + + message = f"Report saved to {report_path}" + + return PhaseResult( + status=PhaseStatus.COMPLETED, + message=message, + data={"report_path": str(report_path)}, + ) + + except Exception as e: + return PhaseResult( + status=PhaseStatus.FAILED, + error=str(e), + ) + + async def _generate_report( + self, verified_results: List[Dict[str, Any]], context: PipelineContext + ) -> Dict[str, Any]: + """ + Generate report data structure. + + Args: + verified_results: Verified results + context: Pipeline context + + Returns: + Report dictionary + """ + successful = [r for r in verified_results if r.get("status") == "success"] + total = len(verified_results) + success_rate = len(successful) / total if total > 0 else 0.0 + + report = { + "metadata": { + "version": "2.0.0", + "timestamp": datetime.now().isoformat(), + "target": context.target_url, + "duration_seconds": sum(context.phase_durations.values()), + }, + "summary": { + "total_tests": total, + "successful_attacks": len(successful), + "success_rate": success_rate, + }, + "results": verified_results, + } + + return report + + async def _save_report( + self, report: Dict[str, Any], context: PipelineContext + ) -> Path: + """ + Save report to file. 
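+
+        Writes JSON to ``config.reporting.output`` when set; otherwise the
+        filename is auto-generated as ``pit_report_<timestamp>.json`` in the
+        current working directory.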
+
+        Args:
+            report: Report data
+            context: Pipeline context
+
+        Returns:
+            Path to saved report
+        """
+        import json
+
+        output_path = context.config.reporting.output
+
+        if not output_path:
+            # Auto-generate filename
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_path = Path(f"pit_report_{timestamp}.json")
+        else:
+            # Coerce to Path in case the config stores a plain string
+            output_path = Path(output_path)
+
+        with open(output_path, "w") as f:
+            json.dump(report, f, indent=2)
+
+        return output_path
+
+    def _display_summary(
+        self, report: Dict[str, Any], report_path: Path
+    ) -> None:
+        """
+        Display summary to console.
+
+        Args:
+            report: Report data
+            report_path: Path to saved report
+        """
+        from pit.ui.tables import create_summary_panel
+
+        summary = report["summary"]
+
+        panel = create_summary_panel(
+            total_tests=summary["total_tests"],
+            successful_attacks=summary["successful_attacks"],
+            success_rate=summary["success_rate"],
+            vulnerabilities_by_severity={},
+            report_path=str(report_path),
+        )
+
+        console.print()
+        console.print(panel)
diff --git a/tools/prompt_injection_tester/pit/orchestrator/pipeline.py b/tools/prompt_injection_tester/pit/orchestrator/pipeline.py
new file mode 100644
index 0000000..95f7a84
--- /dev/null
+++ b/tools/prompt_injection_tester/pit/orchestrator/pipeline.py
@@ -0,0 +1,157 @@
+"""Sequential pipeline executor for running phases one-by-one."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from rich.console import Console
+from rich.panel import Panel
+
+from pit.config import Config
+from pit.orchestrator.phases import Phase, PhaseResult, PhaseStatus
+
+console = Console()
+
+
+@dataclass
+class PipelineContext:
+    """
+    Shared state across pipeline phases.
+
+    This context is passed through each phase sequentially,
+    allowing phases to read input from previous phases.
+    """
+
+    target_url: str
+    config: Config
+
+    # Phase 1 output (Discovery)
+    injection_points: List[Any] = field(default_factory=list)
+
+    # Phase 2 output (Attack)
+    test_results: List[Any] = field(default_factory=list)
+
+    # Phase 3 output (Verification)
+    verified_results: List[Any] = field(default_factory=list)
+
+    # Phase 4 output (Reporting)
+    report: Optional[Dict[str, Any]] = None
+    report_path: Optional[Path] = None
+
+    # Metadata
+    start_time: datetime = field(default_factory=datetime.now)
+    phase_durations: Dict[str, float] = field(default_factory=dict)
+
+    # Interrupt handling
+    interrupted: bool = False
+
+
+class Pipeline:
+    """
+    Sequential phase executor.
+
+    Runs phases one-by-one, waiting for each to complete before
+    starting the next. This rules out cross-phase concurrency errors.
+    """
+
+    def __init__(self, phases: List[Phase]):
+        """
+        Initialize pipeline with phases.
+
+        Args:
+            phases: List of phases to execute sequentially
+        """
+        self.phases = phases
+
+    async def run(self, context: PipelineContext) -> PipelineContext:
+        """
+        Run all phases sequentially.
+
+        Each phase MUST complete before the next begins.
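+
+        Example (illustrative; assumes a populated ``Config`` instance):
+            pipeline = await create_default_pipeline()
+            ctx = PipelineContext(target_url="http://localhost:8000", config=config)
+            ctx = await pipeline.run(ctx)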
+
+        Args:
+            context: Pipeline context (shared state)
+
+        Returns:
+            Updated context with results; if a phase fails, the pipeline
+            stops early and the context holds only the completed phases
+
+        Raises:
+            KeyboardInterrupt: If the user interrupts the run
+        """
+        total_phases = len(self.phases)
+
+        try:
+            for i, phase in enumerate(self.phases, start=1):
+                self._print_phase_header(i, total_phases, phase.name)
+
+                # CRITICAL: Wait for phase to complete before continuing
+                phase_start = datetime.now()
+                result = await phase.execute(context)
+                phase_duration = (datetime.now() - phase_start).total_seconds()
+
+                context.phase_durations[phase.name] = phase_duration
+
+                if result.status == PhaseStatus.FAILED:
+                    self._handle_phase_failure(phase, result)
+                    break
+
+                self._print_phase_success(phase.name, result)
+
+        except KeyboardInterrupt:
+            console.print("\n[yellow]⚠ Scan Interrupted by User[/yellow]")
+            context.interrupted = True
+            raise
+
+        except Exception as e:
+            console.print(f"\n[red]✗ Pipeline Error: {e.__class__.__name__}[/red]")
+            raise
+
+        return context
+
+    def _print_phase_header(self, phase_num: int, total: int, name: str) -> None:
+        """Print phase header."""
+        header = f"[{phase_num}/{total}] {name}"
+        console.print()
+        console.print(Panel(header, border_style="cyan", expand=False))
+
+    def _print_phase_success(self, name: str, result: PhaseResult) -> None:
+        """Print phase success message."""
+        console.print(f"[green]✓ {name} Complete[/green]")
+        if result.message:
+            console.print(f"  └─ {result.message}")
+
+    def _handle_phase_failure(self, phase: Phase, result: PhaseResult) -> None:
+        """Handle phase failure."""
+        console.print(f"[red]✗ {phase.name} Failed[/red]")
+        if result.error:
+            console.print(f"  └─ {result.error}")
+
+
+async def create_default_pipeline() -> Pipeline:
+    """
+    Create the default 4-phase pipeline.
+
+    Returns:
+        Pipeline with Discovery, Attack, Verification, and Reporting phases
+    """
+    from pit.orchestrator.phases import (
+        AttackPhase,
+        DiscoveryPhase,
+        ReportingPhase,
+        VerificationPhase,
+    )
+
+    return Pipeline(
+        phases=[
+            DiscoveryPhase(),
+            AttackPhase(),
+            VerificationPhase(),
+            ReportingPhase(),
+        ]
+    )
diff --git a/tools/prompt_injection_tester/pit/ui/__init__.py b/tools/prompt_injection_tester/pit/ui/__init__.py
new file mode 100644
index 0000000..2178f68
--- /dev/null
+++ b/tools/prompt_injection_tester/pit/ui/__init__.py
@@ -0,0 +1,24 @@
+"""
+Rich UI components for the CLI interface.
+"""
+
+from pit.ui.console import console
+from pit.ui.progress import create_progress_bar, create_spinner
+from pit.ui.display import (
+    print_banner,
+    print_success,
+    print_error,
+    print_warning,
+    print_info,
+)
+
+__all__ = [
+    "console",
+    "create_progress_bar",
+    "create_spinner",
+    "print_banner",
+    "print_success",
+    "print_error",
+    "print_warning",
+    "print_info",
+]
diff --git a/tools/prompt_injection_tester/pit/ui/spinner.py b/tools/prompt_injection_tester/pit/ui/spinner.py
new file mode 100644
index 0000000..a2c35f1
--- /dev/null
+++ b/tools/prompt_injection_tester/pit/ui/spinner.py
@@ -0,0 +1,34 @@
+"""Spinner animations for async operations."""
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import Iterator
+
+from rich.console import Console
+from rich.spinner import Spinner
+from rich.live import Live
+
+console = Console()
+
+
+@contextmanager
+def show_spinner(message: str, spinner_type: str = "dots") -> Iterator[None]:
+    """
+    Show a spinner during an operation.
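+
+    Renders a transient Rich ``Live`` spinner, so the spinner line is
+    cleared once the wrapped operation finishes.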
+ + Args: + message: Message to display next to spinner + spinner_type: Type of spinner animation + + Yields: + None + + Example: + with show_spinner("Discovering injection points..."): + await discover() + """ + spinner = Spinner(spinner_type, text=message, style="cyan") + + with Live(spinner, console=console, transient=True): + yield diff --git a/tools/prompt_injection_tester/pit/ui/styles.py b/tools/prompt_injection_tester/pit/ui/styles.py new file mode 100644 index 0000000..2410dfa --- /dev/null +++ b/tools/prompt_injection_tester/pit/ui/styles.py @@ -0,0 +1,114 @@ +"""Color schemes and styling.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.models import Severity, TestStatus + + +# Color scheme +STYLES = { + "primary": "cyan", + "success": "green", + "warning": "yellow", + "error": "red", + "info": "blue", + "muted": "dim", +} + +# Severity colors +SEVERITY_COLORS = { + "critical": "red", + "high": "red", + "medium": "yellow", + "low": "blue", + "info": "green", +} + +SEVERITY_ICONS = { + "critical": "🔴", + "high": "🔴", + "medium": "🟠", + "low": "🟡", + "info": "🟢", +} + +# Status symbols +STATUS_SYMBOLS = { + "success": "✓", + "failed": "✗", + "error": "⚠", + "pending": "⋯", + "running": "⠋", +} + + +def format_severity(severity: str | Severity) -> str: + """ + Format severity with color and icon. + + Args: + severity: Severity level + + Returns: + Formatted severity string with Rich markup + """ + if hasattr(severity, "value"): + severity = severity.value + + severity_lower = severity.lower() + color = SEVERITY_COLORS.get(severity_lower, "white") + icon = SEVERITY_ICONS.get(severity_lower, "⚪") + + return f"[{color}]{icon} {severity.upper()}[/{color}]" + + +def format_status(status: str | TestStatus, success: bool = True) -> str: + """ + Format test status with symbol and color. + + Args: + status: Status value + success: Whether the test succeeded + + Returns: + Formatted status string with Rich markup + """ + if hasattr(status, "value"): + status = status.value + + status_lower = status.lower() + symbol = STATUS_SYMBOLS.get(status_lower, "⋯") + + if success: + color = "green" + elif status_lower == "error": + color = "red" + else: + color = "yellow" + + return f"[{color}]{symbol} {status.title()}[/{color}]" + + +def format_confidence(confidence: float) -> str: + """ + Format confidence score with color. 
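+
+    Thresholds: confidence >= 0.9 renders green, >= 0.7 yellow, and anything
+    lower red.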
+ + Args: + confidence: Confidence score (0.0-1.0) + + Returns: + Formatted confidence string with Rich markup + """ + percentage = int(confidence * 100) + + if confidence >= 0.9: + color = "green" + elif confidence >= 0.7: + color = "yellow" + else: + color = "red" + + return f"[{color}]{percentage}%[/{color}]" diff --git a/tools/prompt_injection_tester/pit/utils/__init__.py b/tools/prompt_injection_tester/pit/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/prompt_injection_tester/pyproject.toml b/tools/prompt_injection_tester/pyproject.toml index 0148b1c..12c0c78 100644 --- a/tools/prompt_injection_tester/pyproject.toml +++ b/tools/prompt_injection_tester/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "prompt-injection-tester" -version = "1.0.0" -description = "Automated prompt injection testing framework for LLMs" +version = "2.0.0" +description = "Modern CLI for automated prompt injection testing of LLMs" readme = "README.md" requires-python = ">=3.10" license = {text = "CC BY-SA 4.0"} @@ -18,6 +18,7 @@ keywords = [ "prompt-injection", "red-team", "testing", + "cli", ] classifiers = [ "Development Status :: 4 - Beta", @@ -31,9 +32,12 @@ classifiers = [ ] dependencies = [ "aiohttp>=3.9.0", + "httpx>=0.24.0", "pyyaml>=6.0", - "typer>=0.7.0", + "typer>=0.9.0", "rich>=13.0.0", + "pydantic>=2.0.0", + "jinja2>=3.1.0", ] [project.optional-dependencies] @@ -48,7 +52,7 @@ dev = [ [project.scripts] prompt-injection-tester = "prompt_injection_tester.cli:main" -pit = "pit.app:app" +pit = "pit.app:cli_main" [project.urls] "Homepage" = "https://github.com/example/ai-llm-red-team-handbook" diff --git a/tools/prompt_injection_tester/tests/__init__.py b/tools/prompt_injection_tester/tests/__init__.py new file mode 100644 index 0000000..670267b --- /dev/null +++ b/tools/prompt_injection_tester/tests/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +"""Test suite for the prompt injection tester.""" diff --git a/tools/prompt_injection_tester/tests/test_patterns/__init__.py b/tools/prompt_injection_tester/tests/test_patterns/__init__.py new file mode 100644 index 0000000..6ae4b33 --- /dev/null +++ b/tools/prompt_injection_tester/tests/test_patterns/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +"""Tests for attack patterns.""" diff --git a/tools/prompt_injection_tester/utils/__init__.py b/tools/prompt_injection_tester/utils/__init__.py new file mode 100644 index 0000000..5e914b7 --- /dev/null +++ b/tools/prompt_injection_tester/utils/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +"""Utility modules for the prompt injection tester.""" + +from .encoding import encode_payload, decode_payload, translate_payload +from .http_client import AsyncHTTPClient, LLMClient, HTTPResponse, RateLimiter + +__all__ = [ + "encode_payload", + "decode_payload", + "translate_payload", + "AsyncHTTPClient", + "LLMClient", + "HTTPResponse", + "RateLimiter", +]