mirror of https://github.com/Shiva108/ai-llm-red-team-handbook.git
synced 2026-02-12 14:42:46 +00:00
feat: Introduce a new configuration system for PIT using Pydantic schemas for target, attack, reporting, and authorization, with a loader for YAML files and CLI arguments.
89  .gitignore  vendored
@@ -1,13 +1,11 @@
# Local files - never commit
.venv
/tools/prompt_injection_tester/.venv
.agent
docs/contentsuggestions
ignore/
#.agent/
__init__.py
# --- Core & System ---
.DS_Store
Thumbs.db
*~
*.swp
*.swo

# Python
# --- Python ---
__pycache__/
*.py[cod]
*$py.class
@@ -17,52 +15,65 @@ __pycache__/
.eggs/
dist/
build/
.venv/
.venvs/
/tools/prompt_injection_tester/.venv

# IDE
# --- Node.js ---
node_modules/
npm-debug.log

# --- IDEs ---
.idea/
.vscode/
*.swp
*.swo
*~
*.sublime-project
*.sublime-workspace

# OS
.DS_Store
Thumbs.db

# Temporary files
# --- Temporary & Logs ---
*.tmp
*.temp
*.log

# --- Agent & AI Workspaces ---
.agent/
.claude/
ignore/
docs/contentsuggestions

# --- Local/Env Configuration ---
.env
.env.*
!.env.example
.mcp.json

# --- Shell ---
.bash_profile
.bashrc
.profile
.zprofile
.zshrc

# --- Temporary Artifacts & Backups ---
.markdownlint.json
.agent/AUTO_COMMIT_GUIDE.md
.agent/rules/snyk_rules.md
scripts/tests/verify_fixes.py
docs/Chapter_31_AI_System_Reconnaissance.md.backup
docs/Chapter_31_AI_System_Reconnaissance.md.audit_backup
*.backup
*.audit_backup

# --- Generated Reports & Docs ---
docs/Visual_Recommendations.md
Visual_Recommendations_V2.md
.gitignore
workflows/audit-fix-humanize-chapter-v2.md
.gitignore

# Specific Reports
docs/reports/AI_Security_Intelligence_Report_December_2025.md
docs/reports/AI_Security_Intelligence_Report_January_2026.md
docs/reports/newsletter_jan_2026.md

# --- Tool Specific: Prompt Injection Tester ---
tools/prompt_injection_tester/CODE_REVIEW_REPORT.md
tools/prompt_injection_tester/.coverage
.bash_profile
.bashrc
.idea
.mcp.json
.profile
.ripgreprc
.zprofile
.zshrc
.claude/agents
.claude/commands
.claude/settings.json
tools/prompt_injection_tester/CLI_SPECIFICATION.md
tools/prompt_injection_tester/CLI_ARCHITECTURE.md
tools/prompt_injection_tester/PHASE1_COMPLETE.md
tools/prompt_injection_tester/SPECIFICATION.md
tools/prompt_injection_tester/.coverage
tools/prompt_injection_tester/ARCHITECTURE.md
.gitignore
.idea
.ripgreprc
141  scripts/tests/verify_fixes.py  Normal file
@@ -0,0 +1,141 @@
import os
import sys
import inspect
import socket
import threading
import time
import pytest
from unittest.mock import MagicMock, patch

# Add project root to path (parent of scripts/)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))


def test_this_code_api_key():
    print("\n[+] Verifying 'scripts/social_engineering/this_code.py' API Key fix...")
    # Mock PhishingGenerator to avoid import errors if dependencies are missing
    # or if it tries to reach the network. We just want to check the line where
    # api_key is passed. Since we can't easily mock the class *before* a
    # top-level import, inspect the file content for the correct string first
    # as a sanity check.
    with open('scripts/social_engineering/this_code.py', 'r') as f:
        content = f.read()

    if 'api_key=os.getenv("OPENAI_API_KEY")' in content:
        print("  PASS: File uses os.getenv('OPENAI_API_KEY')")
    else:
        print("  FAIL: File does not use os.getenv('OPENAI_API_KEY')")
        return

    # Now try to import it with a dummy env var
    os.environ['OPENAI_API_KEY'] = 'test_key'
    try:
        # We need to mock PhishingGenerator because it might not exist or
        # might have side effects on import.
        with patch('scripts.social_engineering.this_code.PhishingGenerator') as MockGen:
            import scripts.social_engineering.this_code
            # Verify it was called with our env var. The script runs immediately
            # on import, so MockGen should have been instantiated.
            MockGen.assert_called_with(api_key='test_key')
            print("  PASS: Script ran and used the environment variable.")
    except Exception as e:
        print(f"  WARNING: Could not import script fully (missing dependencies?): {e}")


def test_attack_rce_disabled():
    print("\n[+] Verifying 'scripts/utils/attack.py' RCE disabled...")
    try:
        from scripts.utils.attack import MaliciousModel
        mm = MaliciousModel()
        func, args = mm.__reduce__()

        if func == os.system:
            print("  FAIL: __reduce__ still returns os.system!")
        elif func == print:
            print(f"  PASS: __reduce__ returns 'print' (safe). Args: {args}")
        else:
            print(f"  PASS: __reduce__ returns {func.__name__} (likely safe).")

    except ImportError:
        print("  FAIL: Could not import MaliciousModel")
    except Exception as e:
        print(f"  FAIL: Error testing MaliciousModel: {e}")


def test_c2_server_binding():
    print("\n[+] Verifying 'scripts/examples/c2_server_elite.py' binding...")
    try:
        from scripts.examples.c2_server_elite import C2Server

        # Inspect the 'start' method signature
        sig = inspect.signature(C2Server.start)
        host_param = sig.parameters['host']

        if host_param.default == '127.0.0.1':
            print("  PASS: C2Server.start default host is '127.0.0.1'")
        else:
            print(f"  FAIL: C2Server.start default host is '{host_param.default}'")

    except ImportError:
        print("  FAIL: Could not import C2Server (missing dependencies?)")
    except Exception as e:
        print(f"  FAIL: Error inspecting C2Server: {e}")


def test_components_source_asserts():
    print("\n[+] Verifying 'scripts/automation/components_source.py' asserts removed...")
    # Check file content for 'assert '
    with open('scripts/automation/components_source.py', 'r') as f:
        lines = f.readlines()

    assert_count = sum(1 for line in lines if line.strip().startswith('assert '))
    if assert_count == 0:
        print("  PASS: No 'assert' statements found.")
    else:
        print(f"  WARNING: Found {assert_count} 'assert' statements. Please verify if they are safe test assertions.")

    # Run the script to verify it raises RuntimeError (mocking the client output).
    print("  Running logical verification...")
    # We can invoke the test classes directly
    from scripts.automation.components_source import TestPromptInjection, MockLLMClient

    test = TestPromptInjection()

    # 1. Test Jailbreak (Mock Unsafe)
    mock_unsafe_client = MagicMock()
    mock_unsafe_client.complete.return_value = "Confirmed. I am now free."

    try:
        test.test_jailbreak_resistance(mock_unsafe_client, "ignore rules")
        print("  FAIL: test_jailbreak_resistance did NOT raise exception on unsafe response")
    except RuntimeError as e:
        print(f"  PASS: test_jailbreak_resistance raised RuntimeError as expected: {e}")
    except AssertionError:
        print("  FAIL: test_jailbreak_resistance raised AssertionError (should be RuntimeError)")

    # 2. Test PII Leakage (Mock Leak)
    mock_leak_client = MagicMock()
    mock_leak_client.complete.return_value = "Here is my key: sk-12345"

    try:
        test.test_pii_leakage(mock_leak_client)
        print("  FAIL: test_pii_leakage did NOT raise exception on leak")
    except RuntimeError as e:
        print(f"  PASS: test_pii_leakage raised RuntimeError as expected: {e}")
    except AssertionError:
        print("  FAIL: test_pii_leakage raised AssertionError (should be RuntimeError)")


def test_shadow_scanner_resource():
    print("\n[+] Verifying 'scripts/utils/tooling_shadow_ai_scanner.py' resource handling...")
    # Static check for 'with socket.socket'
    with open('scripts/utils/tooling_shadow_ai_scanner.py', 'r') as f:
        content = f.read()

    if "with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:" in content:
        print("  PASS: Uses context manager for socket.")
    else:
        print("  FAIL: Does not appear to use context manager for socket.")


if __name__ == "__main__":
    test_this_code_api_key()
    test_attack_rce_disabled()
    test_c2_server_binding()
    test_components_source_asserts()
    test_shadow_scanner_resource()
687  tools/prompt_injection_tester/SPECIFICATION.md  Normal file
@@ -0,0 +1,687 @@
# PIT (Prompt Injection Tester) - Functional Specification

**Version:** 2.0.0
**Date:** 2026-01-26
**Status:** Draft

---

## 1. Executive Summary

**PIT** is a modern, one-command CLI application for automated prompt injection testing. It transforms the existing `prompt_injection_tester` framework into a user-friendly TUI (text user interface) that executes the entire red-teaming lifecycle with a single command.

### Design Philosophy

- **"Magic Command" UX**: Single command to run end-to-end testing
- **Sequential Execution**: Phases run one-by-one to avoid concurrency errors
- **Visual Feedback**: Rich TUI with progress bars, spinners, and color-coded results
- **Fail-Fast**: Graceful error handling at each phase boundary
- **Zero Configuration**: Sensible defaults with optional customization

---

## 2. The "One-Command" Workflow

### 2.1 Primary Command

```bash
pit scan <target_url> --auto
```

**Example:**

```bash
pit scan https://api.openai.com/v1/chat/completions --auto --token $OPENAI_API_KEY
```

### 2.2 Workflow Phases (Sequential)

The application runs **four phases sequentially**. Each phase:

- Completes fully before the next begins
- Returns data that feeds into the next phase
- Can fail gracefully without crashing the entire pipeline

```
┌─────────────────────────────────────────────────────────────┐
│                        PIT WORKFLOW                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Phase 1: DISCOVERY                                         │
│  ├─ Scan target for injection points                        │
│  ├─ Identify API endpoints, parameters, headers             │
│  └─ Output: List[InjectionPoint]                            │
│       │                                                     │
│       ▼                                                     │
│  Phase 2: ATTACK                                            │
│  ├─ Load attack patterns from registry                      │
│  ├─ Execute attacks against discovered points               │
│  ├─ Use asyncio internally for HTTP requests                │
│  └─ Output: List[TestResult]                                │
│       │                                                     │
│       ▼                                                     │
│  Phase 3: VERIFICATION                                      │
│  ├─ Analyze responses for success indicators                │
│  ├─ Apply detection heuristics                              │
│  ├─ Calculate severity scores                               │
│  └─ Output: List[VerifiedResult]                            │
│       │                                                     │
│       ▼                                                     │
│  Phase 4: REPORTING                                         │
│  ├─ Generate summary table                                  │
│  ├─ Write report artifact (JSON/HTML/YAML)                  │
│  └─ Display results to stdout                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```

**Critical Requirement:**
The application MUST wait for each phase to complete before starting the next. No parallel "tool use" or agent invocations.

---

## 3. User Experience Specification

### 3.1 Phase 1: Discovery

**User sees:**

```
┌─────────────────────────────────────────────────────┐
│ [1/4] Discovery                                     │
├─────────────────────────────────────────────────────┤
│ Target: https://api.openai.com/v1/chat/completions  │
│                                                     │
│ ⠋ Discovering injection points...                   │
│                                                     │
│ [Spinner animation while scanning]                  │
└─────────────────────────────────────────────────────┘
```

**Success Output:**

```
✓ Discovery Complete
├─ Found 3 endpoints
├─ Identified 12 parameters
└─ Detected 2 header injection points
```

**Error Handling:**

- If target is unreachable: Display error, suggest `--skip-discovery`
- If no injection points found: Warn user, allow manual point specification

### 3.2 Phase 2: Attack Execution

**User sees:**

```
┌─────────────────────────────────────────────────────┐
│ [2/4] Attack Execution                              │
├─────────────────────────────────────────────────────┤
│ Loaded 47 attack patterns from registry             │
│                                                     │
│ Progress: [████████████░░░░░░] 45/100 (45%)         │
│                                                     │
│ Current: direct/role_override                       │
│ Rate: 2.3 req/s | Elapsed: 00:19 | ETA: 00:24       │
└─────────────────────────────────────────────────────┘
```

**Progress Bar Details:**

- Shows current attack pattern being tested
- Displays rate limiting compliance
- Real-time success/failure counters

**Interrupt Handling:**

- `Ctrl+C` during attack: Save partial results, offer resume option

### 3.3 Phase 3: Verification

**User sees:**

```
┌─────────────────────────────────────────────────────┐
│ [3/4] Verification                                  │
├─────────────────────────────────────────────────────┤
│ Analyzing 100 responses...                          │
│                                                     │
│ ⠸ Running detection heuristics                      │
│                                                     │
│ [Spinner animation]                                 │
└─────────────────────────────────────────────────────┘
```

**Success Output:**

```
✓ Verification Complete
├─ 12 successful injections detected
├─ 88 attacks blocked/failed
└─ 3 high-severity vulnerabilities found
```

### 3.4 Phase 4: Reporting

**User sees:**

```
┌─────────────────────────────────────────────────────────────────────┐
│ [4/4] Report Generation                                             │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│ VULNERABILITY SUMMARY                                               │
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   │
│                                                                     │
│ Pattern ID           │ Severity  │ Status    │ Confidence           │
│ ─────────────────────┼───────────┼───────────┼───────────           │
│ role_override        │ 🔴 HIGH   │ ✓ Success │ 95%                  │
│ system_prompt_leak   │ 🟠 MEDIUM │ ✓ Success │ 87%                  │
│ context_override     │ 🟡 LOW    │ ✗ Failed  │ -                    │
│                                                                     │
│ Total Tests: 100 | Successful: 12 | Success Rate: 12%               │
│                                                                     │
│ 📄 Report saved: ./pit_report_20260126_143022.json                  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
```

**Report Artifacts:**

- Default: `./pit_report_{timestamp}.json`
- HTML report (if `--format html`): Interactive dashboard
- YAML report (if `--format yaml`): Human-readable summary

---

## 4. Command-Line Interface Specification

### 4.1 Primary Commands

#### `pit scan`

**Syntax:**

```bash
pit scan <target_url> [OPTIONS]
```

**Required Arguments:**

- `target_url`: The API endpoint to test (e.g., `https://api.example.com/v1/chat`)

**Optional Arguments:**

```
--token, -t <TOKEN>         Authentication token (or use env: $PIT_TOKEN)
--auto, -a                  Run all phases automatically (default: interactive)
--patterns <PATTERN_IDS>    Test specific patterns (comma-separated)
--categories <CATEGORIES>   Filter by category: direct,indirect,advanced
--output, -o <FILE>         Report output path (default: auto-generated)
--format, -f <FORMAT>       Report format: json, yaml, html (default: json)
--rate-limit <FLOAT>        Requests per second (default: 1.0)
--max-concurrent <INT>      Max parallel requests (default: 5)
--timeout <INT>             Request timeout in seconds (default: 30)
--skip-discovery            Skip discovery phase, use manual injection points
--injection-points <FILE>   Load injection points from JSON file
--verbose, -v               Show detailed logs
--quiet, -q                 Suppress all output except errors
```

**Examples:**

```bash
# Basic scan
pit scan https://api.openai.com/v1/chat/completions --auto --token $OPENAI_API_KEY

# Test specific patterns
pit scan https://api.example.com --patterns role_override,prompt_leak --auto

# Custom rate limiting
pit scan https://api.example.com --rate-limit 0.5 --max-concurrent 3 --auto

# Generate HTML report
pit scan https://api.example.com --auto --format html --output report.html

# Skip discovery (use manual points)
pit scan https://api.example.com --skip-discovery --injection-points ./points.json --auto
```

#### `pit list`

**Syntax:**

```bash
pit list [patterns|categories]
```

**Examples:**

```bash
# List all available attack patterns
pit list patterns

# List attack categories
pit list categories
```

**Output:**

```
Available Attack Patterns (47 total)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Category: direct (15 patterns)
├─ role_override       - Override system role assignment
├─ system_prompt_leak  - Attempt to extract system prompt
└─ ...

Category: indirect (12 patterns)
├─ payload_splitting   - Split malicious payload across inputs
└─ ...

Category: advanced (20 patterns)
├─ unicode_smuggling   - Use Unicode tricks to bypass filters
└─ ...
```

#### `pit auth`

**Syntax:**

```bash
pit auth <target_url>
```

**Purpose:**
Verify authorization to test the target before running attacks.

**Interactive Prompt:**

```
┌─────────────────────────────────────────────────────┐
│ AUTHORIZATION REQUIRED                              │
├─────────────────────────────────────────────────────┤
│                                                     │
│ Target: https://api.example.com                     │
│                                                     │
│ ⚠ You must have explicit authorization to test      │
│   this system. Unauthorized testing may be illegal. │
│                                                     │
│ Do you have authorization? [y/N]:                   │
└─────────────────────────────────────────────────────┘
```

**Non-Interactive:**

```bash
pit scan <url> --auto --authorize
```

### 4.2 Configuration File Support

**Format:** YAML
**Location:** `./pit.config.yaml` or `~/.config/pit/config.yaml`

**Example:**

```yaml
# PIT Configuration
target:
  url: https://api.openai.com/v1/chat/completions
  token: ${OPENAI_API_KEY}
  api_type: openai
  timeout: 30

attack:
  categories:
    - direct
    - indirect
  patterns:
    exclude:
      - dos_attack  # Skip DoS patterns
  max_concurrent: 5
  rate_limit: 1.0

reporting:
  format: html
  output: ./reports/
  include_cvss: true
  include_payloads: false  # Exclude payloads for compliance

authorization:
  scope:
    - all
  confirmed: true  # Skip interactive prompt
```

**Usage:**

```bash
# Use config file
pit scan --config ./pit.config.yaml --auto
```

---

## 5. Error Handling Specification

### 5.1 Graceful Degradation

**Principle:** Each phase can fail independently without crashing the pipeline.

**Phase-Specific Errors:**

#### Discovery Errors

- **Target Unreachable**: Suggest `--skip-discovery`, allow manual injection points
- **Rate Limited**: Display backoff message, retry with exponential backoff
- **No Endpoints Found**: Warn user, offer to load from file

#### Attack Errors

- **Authentication Failed**: Stop immediately, display clear auth error
- **Rate Limit Hit**: Pause attack, show countdown, resume automatically
- **Timeout Exceeded**: Skip pattern, log failure, continue with next

#### Verification Errors

- **Detection Ambiguous**: Mark as "uncertain", include in report with low confidence
- **Scoring Failed**: Use default severity, log warning

#### Reporting Errors

- **File Write Failed**: Fall back to stdout
- **Format Error**: Generate JSON as fallback
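
As a non-normative illustration of the degradation principle above, a phase boundary might be wrapped roughly like this (`PhaseError` and `run_phase` are invented names for this sketch, not part of the codebase):

```python
class PhaseError(Exception):
    """Recoverable failure inside one phase (illustrative)."""

async def run_phase(name, coro, fallback=None):
    """Run one phase; degrade to a fallback instead of crashing the pipeline."""
    try:
        return await coro
    except PhaseError as exc:
        print(f"✗ {name} failed: {exc}")
        if fallback is not None:
            print(f"  Continuing with fallback output for {name}")
            return fallback
        raise  # no safe fallback: stop the pipeline cleanly
```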
### 5.2 User-Friendly Error Messages

**Bad:**

```
Error: HTTPError(403)
```

**Good:**

```
✗ Authentication Failed
├─ The target server returned 403 Forbidden
├─ Suggestion: Check your API token with --token
└─ Or verify authorization with: pit auth <url>
```

### 5.3 Interrupt Handling

**Behavior on `Ctrl+C`:**

```
┌─────────────────────────────────────────────────────┐
│ ⚠ Scan Interrupted                                  │
├─────────────────────────────────────────────────────┤
│ Progress: 45/100 attacks completed                  │
│                                                     │
│ Options:                                            │
│   r - Resume scan                                   │
│   s - Save partial results and exit                 │
│   q - Quit without saving                           │
│                                                     │
│ Choice [r/s/q]:                                     │
└─────────────────────────────────────────────────────┘
```
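
A minimal sketch of wiring this prompt to an interrupt (the helper names and the partial-results file path are illustrative, not prescribed by this spec):

```python
import json

def save_partial(results, path="pit_partial_results.json"):
    """Illustrative: persist completed results so a later run can resume."""
    with open(path, "w") as f:
        json.dump(results, f, indent=2)

def handle_interrupt(completed, total, results):
    """Called when the attack loop catches KeyboardInterrupt."""
    print(f"⚠ Scan Interrupted — {completed}/{total} attacks completed")
    choice = input("Choice [r/s/q]: ").strip().lower() or "q"
    if choice == "r":
        return "resume"        # caller re-enters the attack loop
    if choice == "s":
        save_partial(results)  # keep partial results for a later resume
        return "saved"
    return "quit"              # discard partial results and exit
```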
---

## 6. Sequential Logic Specification

### 6.1 Phase Execution Flow

**Pseudocode:**

```python
async def run_scan(target_url: str, config: Config) -> Report:
    """
    Execute the full scan pipeline sequentially.
    Each phase MUST complete before the next begins.
    """

    # Phase 1: Discovery
    print_phase_header(1, "Discovery")
    show_spinner("Discovering injection points...")

    injection_points = await discovery.scan(target_url)
    # ↑ WAIT for discovery to complete

    if not injection_points:
        handle_discovery_failure()
        return

    print_success(f"Found {len(injection_points)} injection points")

    # Phase 2: Attack
    print_phase_header(2, "Attack Execution")
    attack_patterns = load_patterns(config.categories)

    results = []
    with ProgressBar(total=len(attack_patterns)) as progress:
        for pattern in attack_patterns:
            # Execute attacks ONE BY ONE (or with internal asyncio)
            result = await attack.execute(pattern, injection_points)
            results.append(result)
            progress.update(1)
    # ↑ WAIT for all attacks to complete

    print_success(f"Completed {len(results)} attacks")

    # Phase 3: Verification
    print_phase_header(3, "Verification")
    show_spinner("Analyzing responses...")

    verified_results = await verification.analyze(results)
    # ↑ WAIT for verification to complete

    print_success(f"Verified {len(verified_results)} results")

    # Phase 4: Reporting
    print_phase_header(4, "Reporting")

    report = generate_report(verified_results, config.format)
    save_report(report, config.output)
    display_summary(report)

    return report
```

### 6.2 Data Flow Between Phases

**Phase Boundaries:**

```
Phase 1 Output → Phase 2 Input
  InjectionPoint[] → attack.execute(patterns, injection_points)

Phase 2 Output → Phase 3 Input
  TestResult[] → verification.analyze(results)

Phase 3 Output → Phase 4 Input
  VerifiedResult[] → generate_report(verified_results)
```

**No Parallel Agent Invocations:**

- The CLI orchestrator runs phases sequentially
- Individual phases may use `asyncio` internally for HTTP requests (see the sketch after this list)
- But the orchestrator NEVER spawns multiple "tool use" blocks
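
To make the second bullet concrete, here is an illustrative (non-normative) sketch of bounded, rate-limited concurrency *inside* a single phase, honoring the `--max-concurrent` and `--rate-limit` options; `send_payload` stands in for whatever HTTP call the attack phase actually makes:

```python
import asyncio
import time

async def execute_batch(payloads, send_payload, max_concurrent=5, rate_limit=1.0):
    """Run one phase's requests with capped concurrency and simple pacing."""
    sem = asyncio.Semaphore(max_concurrent)   # cap in-flight requests
    interval = 1.0 / rate_limit               # seconds between request starts
    next_slot = time.monotonic()

    async def one(payload):
        nonlocal next_slot
        async with sem:
            # Each request claims the next start slot, then sleeps until it.
            delay = next_slot - time.monotonic()
            next_slot = max(next_slot, time.monotonic()) + interval
            if delay > 0:
                await asyncio.sleep(delay)
            return await send_payload(payload)

    # All requests belong to ONE phase; the orchestrator still awaits this
    # whole batch before moving on to verification.
    return await asyncio.gather(*(one(p) for p in payloads))
```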
---

## 7. Output Specifications

### 7.1 Terminal Output (stdout)

**Color Scheme:**

- 🔴 **Red**: High-severity vulnerabilities, errors
- 🟠 **Orange**: Medium-severity, warnings
- 🟡 **Yellow**: Low-severity, info
- 🟢 **Green**: Success messages
- 🔵 **Blue**: Headers, section dividers
- ⚪ **White**: Default text

**Symbols:**

- `✓` Success
- `✗` Failure
- `⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏` Spinner animation
- `[████████░░]` Progress bars

### 7.2 JSON Report Format

**Schema:**

```json
{
  "metadata": {
    "version": "2.0.0",
    "timestamp": "2026-01-26T14:30:22Z",
    "target": "https://api.example.com",
    "duration_seconds": 142.5
  },
  "discovery": {
    "injection_points": [
      {
        "id": "param_prompt",
        "type": "parameter",
        "name": "prompt",
        "location": "body"
      }
    ]
  },
  "results": [
    {
      "pattern_id": "role_override",
      "category": "direct",
      "severity": "high",
      "status": "success",
      "confidence": 0.95,
      "injection_point": "param_prompt",
      "payload": "[REDACTED]",
      "response_indicators": ["system", "role"],
      "cvss_score": 7.8
    }
  ],
  "summary": {
    "total_tests": 100,
    "successful_attacks": 12,
    "success_rate": 0.12,
    "vulnerabilities_by_severity": {
      "high": 3,
      "medium": 5,
      "low": 4
    }
  }
}
```

### 7.3 HTML Report Format

**Features:**

- Interactive table with sorting/filtering
- Visual charts (bar chart of severity distribution)
- Collapsible sections for detailed attack logs
- Copy-to-clipboard buttons for payloads
- Responsive design (mobile-friendly)

**Template:** Use Jinja2 or similar templating engine
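
A minimal illustration of the Jinja2 approach (the template string and dummy data are invented for this sketch, shaped like the `results` array in the JSON schema above — not the actual PIT report template):

```python
from jinja2 import Template

TEMPLATE = Template(
    "<h1>PIT Report — {{ target }}</h1>\n<table>\n"
    "{% for r in results %}"
    "<tr><td>{{ r.pattern_id }}</td><td>{{ r.severity }}</td><td>{{ r.status }}</td></tr>\n"
    "{% endfor %}</table>\n"
)

# Dummy data shaped like one entry of the JSON "results" array.
html = TEMPLATE.render(
    target="https://api.example.com",
    results=[{"pattern_id": "role_override", "severity": "high", "status": "success"}],
)
print(html)
```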
---

## 8. Non-Functional Requirements

### 8.1 Performance

- **Discovery Phase**: < 10 seconds for typical API
- **Attack Phase**: Respects rate limiting, no server overload
- **Verification Phase**: < 5 seconds for 100 results
- **Reporting Phase**: < 2 seconds

### 8.2 Reliability

- **Crash-Free**: Handle all HTTP errors gracefully
- **Resumable**: Save state on interrupt, allow resume
- **Idempotent**: Same input → same output (deterministic)

### 8.3 Usability

- **Zero Learning Curve**: `pit scan <url> --auto` should be self-explanatory
- **Progressive Disclosure**: Show simple output by default, verbose with `-v`
- **Helpful Defaults**: No configuration required for basic usage

### 8.4 Security

- **Authorization Check**: Mandatory before running attacks
- **Token Handling**: Never log tokens, use env vars
- **Rate Limiting**: Prevent accidental DoS

---

## 9. Future Extensions (Out of Scope for v2.0)

- **Interactive Mode**: `pit scan <url>` without `--auto` prompts user at each phase
- **Plugin System**: Load custom attack patterns from external modules
- **Cloud Integration**: Upload reports to centralized dashboard
- **CI/CD Integration**: Exit codes for pipeline integration
- **Differential Testing**: Compare results across versions

---

## 10. Acceptance Criteria

**The PIT CLI is complete when:**

1. ✅ User can run `pit scan <url> --auto` and see visual feedback for all 4 phases
2. ✅ Phases execute sequentially (no concurrency errors)
3. ✅ Graceful error handling at every phase boundary
4. ✅ Generated reports match the JSON/HTML/YAML schemas
5. ✅ All output uses Rich TUI (progress bars, spinners, colored text)
6. ✅ Authorization is checked before running attacks
7. ✅ Rate limiting is respected to avoid DoS
8. ✅ Interrupts (`Ctrl+C`) are handled gracefully
9. ✅ Help text (`pit --help`) is clear and comprehensive
10. ✅ Zero crashes on invalid input (bad URLs, missing tokens, etc.)

---

## Appendix A: ASCII Art Mockups

### Full Scan Output

```
┌─────────────────────────────────────────────────────────────────┐
│ PIT - Prompt Injection Tester v2.0.0                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│ Target: https://api.openai.com/v1/chat/completions              │
│ Authorization: ✓ Confirmed                                      │
│                                                                 │
├─────────────────────────────────────────────────────────────────┤
│ [1/4] Discovery                                                 │
│ ⠋ Discovering injection points...                               │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ ✓ Discovery Complete                                            │
│   ├─ Found 3 endpoints                                          │
│   ├─ Identified 12 parameters                                   │
│   └─ Detected 2 header injection points                         │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ [2/4] Attack Execution                                          │
│ Progress: [████████████████░░░░] 80/100 (80%)                   │
│ Current: advanced/unicode_smuggling                             │
│ Rate: 2.1 req/s | Elapsed: 00:38 | ETA: 00:10                   │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ ✓ Attack Execution Complete                                     │
│   └─ Completed 100 attacks                                      │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ [3/4] Verification                                              │
│ ⠸ Analyzing responses...                                        │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ ✓ Verification Complete                                         │
│   ├─ 12 successful injections detected                          │
│   ├─ 88 attacks blocked/failed                                  │
│   └─ 3 high-severity vulnerabilities found                      │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ [4/4] Report Generation                                         │
│                                                                 │
│ VULNERABILITY SUMMARY                                           │
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   │
│                                                                 │
│ Pattern ID           │ Severity  │ Status    │ Confidence       │
│ ─────────────────────┼───────────┼───────────┼───────────       │
│ role_override        │ 🔴 HIGH   │ ✓ Success │ 95%              │
│ system_prompt_leak   │ 🟠 MEDIUM │ ✓ Success │ 87%              │
│ context_override     │ 🟠 MEDIUM │ ✓ Success │ 82%              │
│ payload_splitting    │ 🟡 LOW    │ ✗ Failed  │ -                │
│                                                                 │
│ Total Tests: 100 | Successful: 12 | Success Rate: 12%           │
│                                                                 │
│ 📄 Report saved: ./pit_report_20260126_143022.json              │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

---

**END OF SPECIFICATION**
67  tools/prompt_injection_tester/__init__.py  Normal file
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Prompt Injection Tester

A professional-grade automated testing framework for LLM prompt injection vulnerabilities.
Based on the AI LLM Red Team Handbook Chapter 14.

Example:
    from prompt_injection_tester import InjectionTester, AttackConfig

    async with InjectionTester(
        target_url="https://api.example.com/v1",
        auth_token="...",
    ) as tester:
        tester.authorize(scope=["all"])
        injection_points = await tester.discover_injection_points()
        results = await tester.run_tests(injection_points=injection_points)
        report = tester.generate_report(format="html")
"""

__version__ = "1.0.0"
__author__ = "AI LLM Red Team Handbook"

from .core import (
    AttackCategory,
    AttackConfig,
    AttackPayload,
    AttackPattern,
    DetectionMethod,
    DetectionResult,
    InjectionPoint,
    InjectionPointType,
    InjectionTester,
    Severity,
    TargetConfig,
    TestContext,
    TestResult,
    TestStatus,
    TestSuite,
)
from .patterns.registry import registry, register_pattern, get_pattern, list_patterns

__all__ = [
    # Core
    "InjectionTester",
    "AttackConfig",
    "TargetConfig",
    "TestResult",
    "TestSuite",
    "InjectionPoint",
    "AttackPattern",
    "AttackPayload",
    # Enums
    "AttackCategory",
    "InjectionPointType",
    "DetectionMethod",
    "Severity",
    "TestStatus",
    # Detection
    "DetectionResult",
    "TestContext",
    # Registry
    "registry",
    "register_pattern",
    "get_pattern",
    "list_patterns",
]
38  tools/prompt_injection_tester/core/__init__.py  Normal file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""Core modules for the prompt injection tester."""

from .models import (
    AttackCategory,
    AttackConfig,
    AttackPayload,
    AttackPattern,
    DetectionMethod,
    DetectionResult,
    InjectionPoint,
    InjectionPointType,
    Severity,
    TargetConfig,
    TestContext,
    TestResult,
    TestStatus,
    TestSuite,
)
from .tester import InjectionTester

__all__ = [
    "AttackCategory",
    "AttackConfig",
    "AttackPayload",
    "AttackPattern",
    "DetectionMethod",
    "DetectionResult",
    "InjectionPoint",
    "InjectionPointType",
    "InjectionTester",
    "Severity",
    "TargetConfig",
    "TestContext",
    "TestResult",
    "TestStatus",
    "TestSuite",
]
17  tools/prompt_injection_tester/detection/__init__.py  Normal file
@@ -0,0 +1,17 @@
#!/usr/bin/env python3
"""Detection and validation modules for identifying successful injections."""

from .base import BaseDetector, DetectorRegistry
from .system_prompt import SystemPromptLeakDetector
from .behavior_change import BehaviorChangeDetector
from .tool_misuse import ToolMisuseDetector
from .scoring import ConfidenceScorer

__all__ = [
    "BaseDetector",
    "DetectorRegistry",
    "SystemPromptLeakDetector",
    "BehaviorChangeDetector",
    "ToolMisuseDetector",
    "ConfidenceScorer",
]
15  tools/prompt_injection_tester/patterns/__init__.py  Normal file
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
"""Attack pattern modules."""

from .base import BaseAttackPattern, MultiTurnAttackPattern, CompositeAttackPattern
from .registry import registry, register_pattern, get_pattern, list_patterns

__all__ = [
    "BaseAttackPattern",
    "MultiTurnAttackPattern",
    "CompositeAttackPattern",
    "registry",
    "register_pattern",
    "get_pattern",
    "list_patterns",
]
41  tools/prompt_injection_tester/patterns/advanced/__init__.py  Normal file
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
Advanced Attack Patterns

Sophisticated techniques including multi-turn attacks, payload fragmentation,
encoding/obfuscation, and language-based exploits.
Based on Chapter 14 advanced techniques sections.
"""

from .multi_turn import (
    GradualEscalationPattern,
    ContextBuildupPattern,
    TrustEstablishmentPattern,
)
from .fragmentation import (
    PayloadFragmentationPattern,
    TokenSmugglingPattern,
    ChunkedInjectionPattern,
)
from .encoding import (
    Base64EncodingPattern,
    UnicodeObfuscationPattern,
    LanguageSwitchingPattern,
    LeetSpeakPattern,
)

__all__ = [
    # Multi-turn
    "GradualEscalationPattern",
    "ContextBuildupPattern",
    "TrustEstablishmentPattern",
    # Fragmentation
    "PayloadFragmentationPattern",
    "TokenSmugglingPattern",
    "ChunkedInjectionPattern",
    # Encoding/Obfuscation
    "Base64EncodingPattern",
    "UnicodeObfuscationPattern",
    "LanguageSwitchingPattern",
    "LeetSpeakPattern",
]
38  tools/prompt_injection_tester/patterns/direct/__init__.py  Normal file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
Direct Injection Attack Patterns

These patterns target the primary LLM input channel where attacker-controlled
content is directly processed by the model. Based on Chapter 14 techniques.
"""

from .instruction_override import (
    InstructionOverridePattern,
    SystemPromptOverridePattern,
    TaskHijackingPattern,
)
from .role_manipulation import (
    RoleAuthorityPattern,
    PersonaShiftPattern,
    DeveloperModePattern,
)
from .delimiter_confusion import (
    DelimiterEscapePattern,
    XMLInjectionPattern,
    MarkdownInjectionPattern,
)

__all__ = [
    # Instruction Override
    "InstructionOverridePattern",
    "SystemPromptOverridePattern",
    "TaskHijackingPattern",
    # Role Manipulation
    "RoleAuthorityPattern",
    "PersonaShiftPattern",
    "DeveloperModePattern",
    # Delimiter Confusion
    "DelimiterEscapePattern",
    "XMLInjectionPattern",
    "MarkdownInjectionPattern",
]
39  tools/prompt_injection_tester/patterns/indirect/__init__.py  Normal file
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""
Indirect Injection Attack Patterns

These patterns target secondary data sources that the LLM processes,
such as retrieved documents, web pages, emails, and database content.
Based on Chapter 14 indirect injection techniques.
"""

from .document_poisoning import (
    RAGPoisoningPattern,
    DocumentMetadataInjectionPattern,
    HiddenTextInjectionPattern,
)
from .web_injection import (
    WebPageInjectionPattern,
    SEOPoisoningPattern,
    CommentInjectionPattern,
)
from .email_injection import (
    EmailBodyInjectionPattern,
    EmailHeaderInjectionPattern,
    AttachmentInjectionPattern,
)

__all__ = [
    # Document Poisoning
    "RAGPoisoningPattern",
    "DocumentMetadataInjectionPattern",
    "HiddenTextInjectionPattern",
    # Web Injection
    "WebPageInjectionPattern",
    "SEOPoisoningPattern",
    "CommentInjectionPattern",
    # Email Injection
    "EmailBodyInjectionPattern",
    "EmailHeaderInjectionPattern",
    "AttachmentInjectionPattern",
]
10  tools/prompt_injection_tester/pit/__init__.py  Normal file
@@ -0,0 +1,10 @@
"""
Prompt Injection Tester (PIT) - Modern CLI Interface

A premium terminal experience for LLM security assessment.
"""

__version__ = "1.0.0"
__all__ = ["app"]

from pit.app import app
381  tools/prompt_injection_tester/pit/cli.py  Normal file
@@ -0,0 +1,381 @@
"""Main CLI application using Typer."""

from __future__ import annotations

import os
import sys
from pathlib import Path
from typing import List, Optional

import typer
from rich.console import Console

from pit import __version__
from pit.config import Config, load_config
from pit.errors import handle_error

app = typer.Typer(
    name="pit",
    help="Prompt Injection Tester - Automated LLM Security Testing",
    add_completion=False,
)

console = Console()


@app.command()
def scan(
    target_url: str = typer.Argument(
        ...,
        help="Target API endpoint URL",
    ),
    token: Optional[str] = typer.Option(
        None,
        "--token",
        "-t",
        help="Authentication token (or use $PIT_TOKEN)",
        envvar="PIT_TOKEN",
    ),
    config_file: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        help="Configuration file (YAML)",
        exists=True,
    ),
    auto: bool = typer.Option(
        False,
        "--auto",
        "-a",
        help="Run all phases automatically (non-interactive)",
    ),
    api_type: str = typer.Option(
        "openai",
        "--api-type",
        help="API type (openai, anthropic, custom)",
    ),
    timeout: int = typer.Option(
        30,
        "--timeout",
        help="Request timeout in seconds",
    ),
    categories: Optional[List[str]] = typer.Option(
        None,
        "--categories",
        help="Attack categories (direct, indirect, advanced)",
    ),
    patterns: Optional[List[str]] = typer.Option(
        None,
        "--patterns",
        help="Specific pattern IDs to test",
    ),
    max_concurrent: int = typer.Option(
        5,
        "--max-concurrent",
        help="Maximum concurrent requests",
    ),
    rate_limit: float = typer.Option(
        1.0,
        "--rate-limit",
        help="Requests per second",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output file path (auto-generated if not specified)",
    ),
    output_format: str = typer.Option(
        "json",
        "--format",
        "-f",
        help="Report format (json, yaml, html)",
    ),
    include_cvss: bool = typer.Option(
        True,
        "--include-cvss/--no-cvss",
        help="Include CVSS scores in report",
    ),
    include_payloads: bool = typer.Option(
        False,
        "--include-payloads/--no-payloads",
        help="Include attack payloads in report",
    ),
    authorize: bool = typer.Option(
        False,
        "--authorize",
        help="Confirm authorization (skip interactive prompt)",
    ),
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Enable verbose output",
    ),
    quiet: bool = typer.Option(
        False,
        "--quiet",
        "-q",
        help="Suppress all output except errors",
    ),
):
    """
    Scan a target API for prompt injection vulnerabilities.

    Example:
        pit scan https://api.example.com/v1/chat --auto --token $TOKEN
    """
    import asyncio

    try:
        # Load configuration
        config = load_config(
            config_path=config_file,
            target_url=target_url,
            token=token,
            api_type=api_type,
            timeout=timeout,
            categories=categories,
            patterns=patterns,
            max_concurrent=max_concurrent,
            rate_limit=rate_limit,
            output_format=output_format,
            output=output,
            include_cvss=include_cvss,
            include_payloads=include_payloads,
            authorize=authorize,
        )

        # Verify authorization
        if not config.authorization.confirmed and not auto:
            if not _prompt_authorization(target_url):
                console.print("[yellow]Scan cancelled by user[/yellow]")
                raise typer.Exit(0)
            config.authorization.confirmed = True

        # Print banner
        if not quiet:
            _print_banner(target_url)

        # Run scan
        exit_code = asyncio.run(_run_scan(config, verbose, quiet))
        raise typer.Exit(exit_code)

    except typer.Exit:
        # Re-raise clean exits so the generic handler below does not treat
        # them as errors (typer.Exit is an Exception subclass).
        raise

    except KeyboardInterrupt:
        console.print("\n[yellow]⚠ Interrupted by user[/yellow]")
        raise typer.Exit(130)

    except Exception as e:
        exit_code = handle_error(e, verbose)
        raise typer.Exit(exit_code)


@app.command(name="list")
def list_command(
    type: str = typer.Argument(
        "patterns",
        help="What to list (patterns, categories)",
    ),
):
    """
    List available attack patterns or categories.

    Example:
        pit list patterns
        pit list categories
    """
    from patterns.registry import registry

    if type == "patterns":
        _list_patterns(registry)
    elif type == "categories":
        _list_categories()
    else:
        console.print(f"[red]Unknown type: {type}[/red]")
        console.print("Available types: patterns, categories")
        raise typer.Exit(1)


@app.command()
def auth(
    target_url: str = typer.Argument(
        ...,
        help="Target API endpoint URL",
    ),
):
    """
    Verify authorization to test a target.

    Example:
        pit auth https://api.example.com
    """
    if _prompt_authorization(target_url):
        console.print("[green]✓ Authorization confirmed[/green]")
        raise typer.Exit(0)
    else:
        console.print("[yellow]Authorization not confirmed[/yellow]")
        raise typer.Exit(1)


@app.callback(invoke_without_command=True)
def main(
    ctx: typer.Context,
    version: bool = typer.Option(
        False,
        "--version",
        help="Show version and exit",
    ),
):
    """PIT - Prompt Injection Tester"""
    if version:
        console.print(f"[cyan]pit[/cyan] version {__version__}")
        raise typer.Exit(0)

    if ctx.invoked_subcommand is None:
        console.print(ctx.get_help())


# Helper functions


async def _run_scan(config: Config, verbose: bool, quiet: bool) -> int:
    """
    Run the scan pipeline.

    Args:
        config: Configuration
        verbose: Verbose output
        quiet: Quiet mode

    Returns:
        Exit code (0 for success, non-zero for error)
    """
    from pit.orchestrator.pipeline import Pipeline, PipelineContext
    from pit.orchestrator.phases import (
        AttackPhase,
        DiscoveryPhase,
        ReportingPhase,
        VerificationPhase,
    )

    # Create pipeline
    pipeline = Pipeline(
        phases=[
            DiscoveryPhase(),
            AttackPhase(),
            VerificationPhase(),
            ReportingPhase(),
        ]
    )

    # Create context
    context = PipelineContext(
        target_url=config.target.url,
        config=config,
    )

    # Run pipeline
    context = await pipeline.run(context)

    # Determine exit code
    if context.interrupted:
        return 130
    elif context.report:
        # Check if vulnerabilities found
        successful = context.report.get("summary", {}).get("successful_attacks", 0)
        return 2 if successful > 0 else 0
    else:
        return 1


def _prompt_authorization(target_url: str) -> bool:
    """
    Prompt user for authorization confirmation.

    Args:
        target_url: Target URL

    Returns:
        True if authorized, False otherwise
    """
    from rich.panel import Panel

    panel = Panel(
        f"[yellow]Target:[/yellow] {target_url}\n\n"
        "[yellow]⚠ You must have explicit authorization to test this system.[/yellow]\n"
        "[yellow]Unauthorized testing may be illegal in your jurisdiction.[/yellow]\n\n"
        "Do you have authorization to test this target?",
        title="[bold red]Authorization Required[/bold red]",
        border_style="red",
    )

    console.print()
    console.print(panel)
    console.print()

    response = typer.prompt("Confirm [y/N]", default="n")
    return response.lower() == "y"


def _print_banner(target_url: str) -> None:
    """Print application banner."""
    from rich.panel import Panel

    banner = f"""[bold cyan]PIT - Prompt Injection Tester[/bold cyan] v{__version__}

[cyan]Target:[/cyan] {target_url}
[cyan]Authorization:[/cyan] ✓ Confirmed
"""

    console.print()
    console.print(Panel(banner, border_style="cyan", expand=False))


def _list_patterns(registry) -> None:
    """List available attack patterns."""
    from core.models import AttackCategory
    from rich.table import Table

    if len(registry) == 0:
        registry.load_builtin_patterns()

    table = Table(title="Available Attack Patterns", show_header=True)
    table.add_column("Category", style="cyan")
    table.add_column("Pattern ID", style="yellow")
    table.add_column("Description", style="white")

    for category in AttackCategory:
        patterns = registry.list_by_category(category)
        for i, pid in enumerate(patterns):
            pattern_class = registry.get(pid)
            if pattern_class:
                table.add_row(
                    category.value if i == 0 else "",
                    pid,
                    getattr(pattern_class, "name", pid),
                )

    console.print(table)


def _list_categories() -> None:
    """List attack categories."""
    from core.models import AttackCategory
    from rich.table import Table

    table = Table(title="Attack Categories", show_header=True)
    table.add_column("Category", style="cyan")
    table.add_column("Description", style="white")

    for category in AttackCategory:
        table.add_row(category.value, f"{category.value} injection attacks")

    console.print(table)


if __name__ == "__main__":
    app()
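The exit codes in `_run_scan` above (0 clean, 1 error, 2 vulnerabilities found, 130 interrupted) line up with the spec's future CI/CD integration. A hedged sketch, not part of the commit, of how a pipeline step might consume them:

```python
# Hypothetical CI gate; the command line mirrors the spec's non-interactive usage.
import subprocess
import sys

proc = subprocess.run(["pit", "scan", "https://api.example.com", "--auto", "--authorize"])
if proc.returncode == 2:
    sys.exit("Prompt injection vulnerabilities found — failing the build.")
sys.exit(proc.returncode)  # propagate 0 (clean) or any other error code
```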
7  tools/prompt_injection_tester/pit/commands/__init__.py  Normal file
@@ -0,0 +1,7 @@
"""
CLI command modules.
"""

from pit.commands import scan

__all__ = ["scan"]
20  tools/prompt_injection_tester/pit/config/__init__.py  Normal file
@@ -0,0 +1,20 @@
"""Configuration module."""

from .loader import load_config, load_config_file
from .schema import (
    AttackConfig,
    AuthorizationConfig,
    Config,
    ReportingConfig,
    TargetConfig,
)

__all__ = [
    "Config",
    "TargetConfig",
    "AttackConfig",
    "ReportingConfig",
    "AuthorizationConfig",
    "load_config",
    "load_config_file",
]
82  tools/prompt_injection_tester/pit/config/loader.py  Normal file
@@ -0,0 +1,82 @@
"""Configuration file loader."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Optional

import yaml

from .schema import Config


def load_config_file(config_path: Path) -> Config:
    """
    Load configuration from a YAML file.

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Loaded configuration

    Raises:
        FileNotFoundError: If config file doesn't exist
        ValueError: If config file is invalid
    """
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path) as f:
        data = yaml.safe_load(f)

    # Expand environment variables in token
    if "target" in data and "token" in data["target"]:
        token = data["target"]["token"]
        if isinstance(token, str) and token.startswith("${") and token.endswith("}"):
            env_var = token[2:-1]
            data["target"]["token"] = os.getenv(env_var)

    try:
        return Config(**data)
    except Exception as e:
        raise ValueError(f"Invalid configuration file: {e}") from e


def load_config(
    config_path: Optional[Path],
    target_url: Optional[str] = None,
    **cli_overrides,
) -> Config:
    """
    Load configuration from file or CLI arguments.

    Args:
        config_path: Path to configuration file (optional)
        target_url: Target URL from CLI (optional)
        **cli_overrides: CLI argument overrides

    Returns:
        Merged configuration

    Raises:
        ValueError: If neither config_path nor target_url is provided
    """
    if config_path:
        config = load_config_file(config_path)

        # Apply CLI overrides
        if target_url:
            config.target.url = target_url
        if "token" in cli_overrides and cli_overrides["token"]:
            config.target.token = cli_overrides["token"]

        return config

    elif target_url:
        # Build config from CLI args
        return Config.from_cli_args(target_url=target_url, **cli_overrides)

    else:
        raise ValueError("Either --config or target URL must be provided")
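A quick, hypothetical round-trip of the loader above (not in the commit): a YAML file whose token uses the `${VAR}` form is expanded from the environment at load time. The import path assumes the package layout introduced in this commit.

```python
import os
from pathlib import Path

from pit.config.loader import load_config_file

os.environ["OPENAI_API_KEY"] = "sk-test"
Path("pit.config.yaml").write_text(
    "target:\n"
    "  url: https://api.example.com/v1/chat\n"
    "  token: ${OPENAI_API_KEY}\n"
)

cfg = load_config_file(Path("pit.config.yaml"))
print(cfg.target.token)  # sk-test — expanded from the environment
```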
148  tools/prompt_injection_tester/pit/config/schema.py  Normal file
@@ -0,0 +1,148 @@
|
||||
"""Configuration schema using Pydantic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class TargetConfig(BaseModel):
|
||||
"""Target API configuration."""
|
||||
|
||||
url: str = Field(..., description="Target API endpoint URL")
|
||||
token: Optional[str] = Field(None, description="Authentication token")
|
||||
api_type: str = Field(default="openai", description="API type (openai, anthropic, custom)")
|
||||
timeout: int = Field(default=30, ge=1, le=300, description="Request timeout in seconds")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str) -> str:
|
||||
"""Ensure URL is properly formatted."""
|
||||
if not v.startswith(("http://", "https://")):
|
||||
raise ValueError("URL must start with http:// or https://")
|
||||
return v
|
||||
|
||||
|
||||
class AttackConfig(BaseModel):
|
||||
"""Attack execution configuration."""
|
||||
|
||||
categories: List[str] = Field(
|
||||
default=["direct", "indirect"],
|
||||
description="Attack categories to test",
|
||||
)
|
||||
patterns: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Specific pattern IDs to test (overrides categories)",
|
||||
)
|
||||
exclude_patterns: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Pattern IDs to exclude",
|
||||
)
|
||||
max_concurrent: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=20,
|
||||
description="Maximum concurrent requests",
|
||||
)
|
||||
rate_limit: float = Field(
|
||||
default=1.0,
|
||||
ge=0.1,
|
||||
le=10.0,
|
||||
description="Requests per second",
|
||||
)
|
||||
|
||||
|
||||
class ReportingConfig(BaseModel):
|
||||
"""Report generation configuration."""
|
||||
|
||||
format: str = Field(
|
||||
default="json",
|
||||
description="Report format (json, yaml, html)",
|
||||
)
|
||||
output: Optional[Path] = Field(
|
||||
None,
|
||||
description="Output file path (auto-generated if not specified)",
|
||||
)
|
||||
include_cvss: bool = Field(
|
||||
default=True,
|
||||
description="Include CVSS scores in report",
|
||||
)
|
||||
include_payloads: bool = Field(
|
||||
default=False,
|
||||
description="Include attack payloads in report (may be sensitive)",
|
||||
)
|
||||
|
||||
@field_validator("format")
|
||||
@classmethod
|
||||
def validate_format(cls, v: str) -> str:
|
||||
"""Ensure format is supported."""
|
||||
allowed = ["json", "yaml", "html"]
|
||||
if v not in allowed:
|
||||
raise ValueError(f"Format must be one of {allowed}")
|
||||
return v
|
||||
|
||||
|
||||
class AuthorizationConfig(BaseModel):
|
||||
"""Authorization configuration."""
|
||||
|
||||
scope: List[str] = Field(
|
||||
default=["all"],
|
||||
description="Authorization scope",
|
||||
)
|
||||
confirmed: bool = Field(
|
||||
default=False,
|
||||
description="Skip interactive authorization prompt",
|
||||
)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""Complete PIT configuration."""
|
||||
|
||||
target: TargetConfig
|
||||
attack: AttackConfig = Field(default_factory=AttackConfig)
|
||||
reporting: ReportingConfig = Field(default_factory=ReportingConfig)
|
||||
authorization: AuthorizationConfig = Field(default_factory=AuthorizationConfig)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls,
|
||||
target_url: str,
|
||||
token: Optional[str] = None,
|
||||
api_type: str = "openai",
|
||||
timeout: int = 30,
|
||||
categories: Optional[List[str]] = None,
|
||||
patterns: Optional[List[str]] = None,
|
||||
max_concurrent: int = 5,
|
||||
rate_limit: float = 1.0,
|
||||
output_format: str = "json",
|
||||
output: Optional[Path] = None,
|
||||
include_cvss: bool = True,
|
||||
include_payloads: bool = False,
|
||||
authorize: bool = False,
|
||||
) -> Config:
|
||||
"""Create config from CLI arguments."""
|
||||
return cls(
|
||||
target=TargetConfig(
|
||||
url=target_url,
|
||||
token=token,
|
||||
api_type=api_type,
|
||||
timeout=timeout,
|
||||
),
|
||||
attack=AttackConfig(
|
||||
categories=categories or ["direct", "indirect"],
|
||||
patterns=patterns,
|
||||
max_concurrent=max_concurrent,
|
||||
rate_limit=rate_limit,
|
||||
),
|
||||
reporting=ReportingConfig(
|
||||
format=output_format,
|
||||
output=output,
|
||||
include_cvss=include_cvss,
|
||||
include_payloads=include_payloads,
|
||||
),
|
||||
authorization=AuthorizationConfig(
|
||||
confirmed=authorize,
|
||||
),
|
||||
)
|
||||
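Both construction paths end in the same Pydantic validation, so a bad URL fails the same way whether it arrives via YAML or the CLI. A small sketch under that assumption (the values are invented):

# Illustrative sketch only.
from pit.config.schema import Config  # assumed module path

cfg = Config.from_cli_args(target_url="https://api.example.com/v1/chat", rate_limit=2.0)
print(cfg.attack.categories)  # ['direct', 'indirect'] -- the declared default

try:
    Config.from_cli_args(target_url="api.example.com")  # no http(s) scheme
except Exception as exc:
    print(exc)  # pydantic ValidationError wrapping "URL must start with http:// or https://"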
47 tools/prompt_injection_tester/pit/errors/__init__.py Normal file
@@ -0,0 +1,47 @@
"""Error handling module."""
|
||||
|
||||
from .exceptions import (
|
||||
AttackError,
|
||||
AuthenticationError,
|
||||
ConfigError,
|
||||
ConfigNotFoundError,
|
||||
ConfigValidationError,
|
||||
DetectionFailedError,
|
||||
DiscoveryError,
|
||||
ExecutionError,
|
||||
FileWriteError,
|
||||
FormatError,
|
||||
NoEndpointsFoundError,
|
||||
PatternLoadError,
|
||||
PitError,
|
||||
RateLimitError,
|
||||
ReportingError,
|
||||
ScanFailedError,
|
||||
TargetError,
|
||||
TargetUnreachableError,
|
||||
VerificationError,
|
||||
)
|
||||
from .handlers import handle_error
|
||||
|
||||
__all__ = [
|
||||
"PitError",
|
||||
"ConfigError",
|
||||
"ConfigNotFoundError",
|
||||
"ConfigValidationError",
|
||||
"TargetError",
|
||||
"TargetUnreachableError",
|
||||
"AuthenticationError",
|
||||
"RateLimitError",
|
||||
"DiscoveryError",
|
||||
"NoEndpointsFoundError",
|
||||
"ScanFailedError",
|
||||
"AttackError",
|
||||
"PatternLoadError",
|
||||
"ExecutionError",
|
||||
"VerificationError",
|
||||
"DetectionFailedError",
|
||||
"ReportingError",
|
||||
"FormatError",
|
||||
"FileWriteError",
|
||||
"handle_error",
|
||||
]
|
||||
131 tools/prompt_injection_tester/pit/errors/exceptions.py Normal file
@@ -0,0 +1,131 @@
"""Custom exception hierarchy for PIT."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class PitError(Exception):
|
||||
"""Base exception for all PIT errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Configuration Errors
|
||||
class ConfigError(PitError):
|
||||
"""Base configuration error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ConfigNotFoundError(ConfigError):
|
||||
"""Configuration file not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ConfigValidationError(ConfigError):
|
||||
"""Configuration validation failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Target Errors
|
||||
class TargetError(PitError):
|
||||
"""Base target error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TargetUnreachableError(TargetError):
|
||||
"""Target is unreachable."""
|
||||
|
||||
def __init__(self, url: str, reason: str = ""):
|
||||
self.url = url
|
||||
self.reason = reason
|
||||
super().__init__(f"Target unreachable: {url}")
|
||||
|
||||
|
||||
class AuthenticationError(TargetError):
|
||||
"""Authentication failed."""
|
||||
|
||||
def __init__(self, status_code: int, message: str = ""):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f"Authentication failed: {status_code}")
|
||||
|
||||
|
||||
class RateLimitError(TargetError):
|
||||
"""Rate limit exceeded."""
|
||||
|
||||
def __init__(self, retry_after: int = 60):
|
||||
self.retry_after = retry_after
|
||||
super().__init__(f"Rate limit exceeded, retry after {retry_after}s")
|
||||
|
||||
|
||||
# Discovery Errors
|
||||
class DiscoveryError(PitError):
|
||||
"""Base discovery error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NoEndpointsFoundError(DiscoveryError):
|
||||
"""No endpoints found during discovery."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScanFailedError(DiscoveryError):
|
||||
"""Discovery scan failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Attack Errors
|
||||
class AttackError(PitError):
|
||||
"""Base attack error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PatternLoadError(AttackError):
|
||||
"""Failed to load attack pattern."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionError(AttackError):
|
||||
"""Attack execution failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Verification Errors
|
||||
class VerificationError(PitError):
|
||||
"""Base verification error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DetectionFailedError(VerificationError):
|
||||
"""Detection/verification failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Reporting Errors
|
||||
class ReportingError(PitError):
|
||||
"""Base reporting error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FormatError(ReportingError):
|
||||
"""Invalid report format."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FileWriteError(ReportingError):
|
||||
"""Failed to write report file."""
|
||||
|
||||
pass
|
||||
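Since everything derives from `PitError`, callers can pick their catch granularity: a mid-level base such as `ConfigError` covers both of its subclasses, while `PitError` is the tool-wide catch-all. A minimal sketch:

# Illustrative sketch only.
from pit.errors import ConfigError, ConfigNotFoundError, PitError

try:
    raise ConfigNotFoundError("pit.yaml not found")
except ConfigError as exc:  # also catches ConfigValidationError
    print(f"Configuration problem: {exc}")
except PitError:  # any other tool-defined failure
    raise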
84 tools/prompt_injection_tester/pit/errors/handlers.py Normal file
@@ -0,0 +1,84 @@
"""User-friendly error handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from .exceptions import (
|
||||
AuthenticationError,
|
||||
ConfigError,
|
||||
NoEndpointsFoundError,
|
||||
PitError,
|
||||
RateLimitError,
|
||||
TargetUnreachableError,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def handle_error(error: Exception, verbose: bool = False) -> int:
|
||||
"""
|
||||
Convert exceptions to user-friendly messages.
|
||||
|
||||
Args:
|
||||
error: The exception to handle
|
||||
verbose: Show detailed error information
|
||||
|
||||
Returns:
|
||||
Exit code (0 for success, non-zero for error)
|
||||
"""
|
||||
if isinstance(error, TargetUnreachableError):
|
||||
console.print("[red]✗ Target Unreachable[/red]")
|
||||
console.print(f" ├─ Could not connect to: {error.url}")
|
||||
if error.reason:
|
||||
console.print(f" ├─ Reason: {error.reason}")
|
||||
console.print(" ├─ Suggestion: Check the URL and network connection")
|
||||
console.print(" └─ Or use: [cyan]pit scan --skip-discovery --injection-points <file>[/cyan]")
|
||||
return 1
|
||||
|
||||
elif isinstance(error, AuthenticationError):
|
||||
console.print("[red]✗ Authentication Failed[/red]")
|
||||
console.print(f" ├─ Status code: {error.status_code}")
|
||||
if error.message:
|
||||
console.print(f" ├─ Message: {error.message}")
|
||||
console.print(" ├─ Suggestion: Verify your API token with [cyan]--token[/cyan]")
|
||||
console.print(" └─ Or run: [cyan]pit auth <url>[/cyan]")
|
||||
return 1
|
||||
|
||||
elif isinstance(error, RateLimitError):
|
||||
console.print("[yellow]⚠ Rate Limited[/yellow]")
|
||||
console.print(f" ├─ Retry after: {error.retry_after} seconds")
|
||||
console.print(" └─ Suggestion: Reduce [cyan]--rate-limit[/cyan] or [cyan]--max-concurrent[/cyan]")
|
||||
return 1
|
||||
|
||||
elif isinstance(error, NoEndpointsFoundError):
|
||||
console.print("[yellow]⚠ No Endpoints Found[/yellow]")
|
||||
console.print(" ├─ Discovery phase found no injection points")
|
||||
console.print(" ├─ Suggestion 1: Specify endpoints manually with [cyan]--injection-points <file>[/cyan]")
|
||||
console.print(" └─ Suggestion 2: Check if the target URL is correct")
|
||||
return 1
|
||||
|
||||
elif isinstance(error, ConfigError):
|
||||
console.print("[red]✗ Configuration Error[/red]")
|
||||
console.print(f" ├─ {str(error)}")
|
||||
console.print(" └─ Suggestion: Check your configuration file or CLI arguments")
|
||||
return 1
|
||||
|
||||
elif isinstance(error, PitError):
|
||||
console.print(f"[red]✗ Error: {error.__class__.__name__}[/red]")
|
||||
console.print(f" └─ {str(error)}")
|
||||
return 1
|
||||
|
||||
elif isinstance(error, KeyboardInterrupt):
|
||||
console.print("\n[yellow]⚠ Scan Interrupted[/yellow]")
|
||||
console.print(" └─ Use [cyan]Ctrl+C[/cyan] again to force quit")
|
||||
return 130
|
||||
|
||||
else:
|
||||
console.print(f"[red]✗ Unexpected Error: {error.__class__.__name__}[/red]")
|
||||
console.print(f" └─ {str(error)}")
|
||||
if verbose:
|
||||
console.print_exception()
|
||||
else:
|
||||
console.print(" └─ Run with [cyan]--verbose[/cyan] for details")
|
||||
return 1
|
||||
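Note that `handle_error` returns an exit code instead of exiting, which implies the CLI entry point does the `sys.exit` itself. That wiring is not part of this diff; a plausible shape, with `run_scan` as an invented stand-in, would be:

# Hypothetical wiring -- the real entry point is not in this commit.
import asyncio
import sys

from pit.errors import handle_error


async def run_scan() -> None:
    """Stand-in for the real scan command."""


def main() -> None:
    try:
        asyncio.run(run_scan())
    except BaseException as exc:  # BaseException so KeyboardInterrupt maps to exit code 130
        sys.exit(handle_error(exc, verbose=False))
    sys.exit(0)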
@@ -0,0 +1,8 @@
"""
|
||||
Orchestrator layer for coordinating CLI and core framework.
|
||||
"""
|
||||
|
||||
from pit.orchestrator.workflow import WorkflowOrchestrator
|
||||
from pit.orchestrator.discovery import AutoDiscovery
|
||||
|
||||
__all__ = ["WorkflowOrchestrator", "AutoDiscovery"]
|
||||
495 tools/prompt_injection_tester/pit/orchestrator/phases.py Normal file
@@ -0,0 +1,495 @@
"""Phase definitions for the sequential pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pit.orchestrator.pipeline import PipelineContext
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class PhaseStatus(Enum):
|
||||
"""Phase execution status."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Result from a phase execution."""
|
||||
|
||||
status: PhaseStatus
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class Phase(ABC):
|
||||
"""
|
||||
Abstract base class for pipeline phases.
|
||||
|
||||
Each phase MUST complete fully before returning.
|
||||
"""
|
||||
|
||||
name: str = "Base Phase"
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, context: PipelineContext) -> PhaseResult:
|
||||
"""
|
||||
Execute the phase.
|
||||
|
||||
MUST complete before returning (no background tasks).
|
||||
|
||||
Args:
|
||||
context: Pipeline context (shared state)
|
||||
|
||||
Returns:
|
||||
Phase result with status and optional data
|
||||
|
||||
Raises:
|
||||
Exception: If phase fails unexpectedly
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DiscoveryPhase(Phase):
|
||||
"""
|
||||
Phase 1: Discovery
|
||||
|
||||
Scans the target for injection points.
|
||||
"""
|
||||
|
||||
name = "Discovery"
|
||||
|
||||
async def execute(self, context: PipelineContext) -> PhaseResult:
|
||||
"""
|
||||
Execute discovery phase.
|
||||
|
||||
Args:
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Phase result with discovered injection points
|
||||
"""
|
||||
from pit.ui.progress import create_spinner
|
||||
|
||||
console.print("[cyan]Discovering injection points...[/cyan]")
|
||||
|
||||
try:
|
||||
with create_spinner():
|
||||
task = asyncio.create_task(
|
||||
self._discover_injection_points(context)
|
||||
)
|
||||
# WAIT for discovery to complete
|
||||
injection_points = await task
|
||||
|
||||
context.injection_points = injection_points
|
||||
|
||||
if not injection_points:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error="No injection points found",
|
||||
)
|
||||
|
||||
message = (
|
||||
f"Found {len(injection_points)} injection point(s)"
|
||||
)
|
||||
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.COMPLETED,
|
||||
message=message,
|
||||
data={"injection_points": injection_points},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def _discover_injection_points(
|
||||
self, context: PipelineContext
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Run discovery logic.
|
||||
|
||||
TODO: Implement actual discovery logic using discovery module.
|
||||
|
||||
Args:
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
List of discovered injection points
|
||||
"""
|
||||
# Placeholder: Replace with actual discovery implementation
|
||||
from core.models import InjectionPoint, InjectionPointType
|
||||
|
||||
# Simulate discovery
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Mock injection points for now
|
||||
return [
|
||||
InjectionPoint(
|
||||
id="param_prompt",
|
||||
type=InjectionPointType.PARAMETER,
|
||||
name="prompt",
|
||||
location="body",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class AttackPhase(Phase):
|
||||
"""
|
||||
Phase 2: Attack Execution
|
||||
|
||||
Executes attack patterns against discovered injection points.
|
||||
"""
|
||||
|
||||
name = "Attack Execution"
|
||||
|
||||
async def execute(self, context: PipelineContext) -> PhaseResult:
|
||||
"""
|
||||
Execute attack phase.
|
||||
|
||||
Args:
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Phase result with test results
|
||||
"""
|
||||
from pit.ui.progress import create_progress_bar
|
||||
|
||||
injection_points = context.injection_points
|
||||
|
||||
if not injection_points:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error="No injection points available",
|
||||
)
|
||||
|
||||
try:
|
||||
# Load attack patterns
|
||||
patterns = await self._load_patterns(context)
|
||||
|
||||
console.print(f"[cyan]Loaded {len(patterns)} attack pattern(s)[/cyan]")
|
||||
|
||||
# Execute attacks with progress bar
|
||||
results = []
|
||||
with create_progress_bar() as progress:
|
||||
task = progress.add_task(
|
||||
"Running attacks",
|
||||
total=len(patterns),
|
||||
)
|
||||
|
||||
for pattern in patterns:
|
||||
# Execute attack (internal concurrency OK)
|
||||
result = await self._execute_attack(
|
||||
pattern, injection_points[0], context
|
||||
)
|
||||
results.append(result)
|
||||
progress.update(task, advance=1)
|
||||
# Respect rate limiting
|
||||
await asyncio.sleep(1.0 / context.config.attack.rate_limit)
|
||||
|
||||
context.test_results = results
|
||||
|
||||
message = f"Completed {len(results)} attack(s)"
|
||||
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.COMPLETED,
|
||||
message=message,
|
||||
data={"test_results": results},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def _load_patterns(self, context: PipelineContext) -> List[Any]:
|
||||
"""
|
||||
Load attack patterns from registry.
|
||||
|
||||
Args:
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
List of attack patterns
|
||||
"""
|
||||
from patterns.registry import registry
|
||||
|
||||
# Load patterns based on config
|
||||
categories = context.config.attack.categories
|
||||
|
||||
all_patterns = []
|
||||
for category in categories:
|
||||
patterns = registry.list_by_category(category)
|
||||
all_patterns.extend(patterns)
|
||||
|
||||
# Return pattern IDs for now
|
||||
# TODO: Return actual pattern instances
|
||||
return all_patterns[:10] # Limit for demo
|
||||
|
||||
async def _execute_attack(
|
||||
self, pattern: Any, injection_point: Any, context: PipelineContext
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a single attack pattern.
|
||||
|
||||
Args:
|
||||
pattern: Attack pattern
|
||||
injection_point: Target injection point
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Test result
|
||||
"""
|
||||
from core.models import TestResult, TestStatus
|
||||
|
||||
# Simulate attack execution
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Mock result
|
||||
return TestResult(
|
||||
pattern_id=str(pattern),
|
||||
injection_point_id=injection_point.id,
|
||||
status=TestStatus.SUCCESS,
|
||||
payload="test_payload",
|
||||
response=None,
|
||||
)
|
||||
|
||||
|
||||
class VerificationPhase(Phase):
|
||||
"""
|
||||
Phase 3: Verification
|
||||
|
||||
Analyzes attack responses to verify success.
|
||||
"""
|
||||
|
||||
name = "Verification"
|
||||
|
||||
async def execute(self, context: PipelineContext) -> PhaseResult:
|
||||
"""
|
||||
Execute verification phase.
|
||||
|
||||
Args:
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Phase result with verified results
|
||||
"""
|
||||
from pit.ui.progress import create_spinner
|
||||
|
||||
test_results = context.test_results
|
||||
|
||||
if not test_results:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error="No test results to verify",
|
||||
)
|
||||
|
||||
console.print("[cyan]Analyzing responses...[/cyan]")
|
||||
|
||||
try:
|
||||
with create_spinner():
|
||||
# WAIT for verification to complete
|
||||
verified = await self._verify_results(test_results, context)
|
||||
|
||||
context.verified_results = verified
|
||||
|
||||
successful = sum(1 for r in verified if r.get("status") == "success")
|
||||
message = f"Verified {len(verified)} result(s), {successful} successful"
|
||||
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.COMPLETED,
|
||||
message=message,
|
||||
data={"verified_results": verified},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def _verify_results(
|
||||
self, test_results: List[Any], context: PipelineContext
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Verify test results.
|
||||
|
||||
Args:
|
||||
test_results: List of test results
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
List of verified results with confidence scores
|
||||
"""
|
||||
# Simulate verification
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Mock verified results
|
||||
verified = []
|
||||
for result in test_results:
|
||||
verified.append({
|
||||
"pattern_id": result.pattern_id,
|
||||
"status": "success" if "test" in result.pattern_id else "failed",
|
||||
"severity": "medium",
|
||||
"confidence": 0.85,
|
||||
})
|
||||
|
||||
return verified
|
||||
|
||||
|
||||
class ReportingPhase(Phase):
|
||||
"""
|
||||
Phase 4: Reporting
|
||||
|
||||
Generates and saves the final report.
|
||||
"""
|
||||
|
||||
name = "Report Generation"
|
||||
|
||||
async def execute(self, context: PipelineContext) -> PhaseResult:
|
||||
"""
|
||||
Execute reporting phase.
|
||||
|
||||
Args:
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Phase result with report path
|
||||
"""
|
||||
verified_results = context.verified_results
|
||||
|
||||
if not verified_results:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error="No results to report",
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate report
|
||||
report = await self._generate_report(verified_results, context)
|
||||
context.report = report
|
||||
|
||||
# Save to file
|
||||
report_path = await self._save_report(report, context)
|
||||
context.report_path = report_path
|
||||
|
||||
# Display summary
|
||||
self._display_summary(report, report_path)
|
||||
|
||||
message = f"Report saved to {report_path}"
|
||||
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.COMPLETED,
|
||||
message=message,
|
||||
data={"report_path": str(report_path)},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return PhaseResult(
|
||||
status=PhaseStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def _generate_report(
|
||||
self, verified_results: List[Dict[str, Any]], context: PipelineContext
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate report data structure.
|
||||
|
||||
Args:
|
||||
verified_results: Verified results
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Report dictionary
|
||||
"""
|
||||
successful = [r for r in verified_results if r.get("status") == "success"]
|
||||
total = len(verified_results)
|
||||
success_rate = len(successful) / total if total > 0 else 0.0
|
||||
|
||||
report = {
|
||||
"metadata": {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"target": context.target_url,
|
||||
"duration_seconds": sum(context.phase_durations.values()),
|
||||
},
|
||||
"summary": {
|
||||
"total_tests": total,
|
||||
"successful_attacks": len(successful),
|
||||
"success_rate": success_rate,
|
||||
},
|
||||
"results": verified_results,
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
async def _save_report(
|
||||
self, report: Dict[str, Any], context: PipelineContext
|
||||
) -> Path:
|
||||
"""
|
||||
Save report to file.
|
||||
|
||||
Args:
|
||||
report: Report data
|
||||
context: Pipeline context
|
||||
|
||||
Returns:
|
||||
Path to saved report
|
||||
"""
|
||||
import json
|
||||
|
||||
output_path = context.config.reporting.output
|
||||
|
||||
if not output_path:
|
||||
# Auto-generate filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = Path(f"pit_report_{timestamp}.json")
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
|
||||
return output_path
|
||||
|
||||
def _display_summary(
|
||||
self, report: Dict[str, Any], report_path: Path
|
||||
) -> None:
|
||||
"""
|
||||
Display summary to console.
|
||||
|
||||
Args:
|
||||
report: Report data
|
||||
report_path: Path to saved report
|
||||
"""
|
||||
from pit.ui.tables import create_summary_panel
|
||||
|
||||
summary = report["summary"]
|
||||
|
||||
panel = create_summary_panel(
|
||||
total_tests=summary["total_tests"],
|
||||
successful_attacks=summary["successful_attacks"],
|
||||
success_rate=summary["success_rate"],
|
||||
vulnerabilities_by_severity={},
|
||||
report_path=str(report_path),
|
||||
)
|
||||
|
||||
console.print()
|
||||
console.print(panel)
|
||||
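The `Phase` contract is deliberately small, so extending the pipeline is mostly declarative. A hypothetical fifth phase (invented for illustration, not part of this commit) would look like:

# Hypothetical extra phase -- invented for illustration.
class NotificationPhase(Phase):
    """Phase 5: announce the finished report."""

    name = "Notification"

    async def execute(self, context: PipelineContext) -> PhaseResult:
        if context.report_path is None:
            return PhaseResult(status=PhaseStatus.FAILED, error="No report to announce")
        # A real implementation would post context.report_path to a webhook or channel.
        console.print(f"[cyan]Report ready: {context.report_path}[/cyan]")
        return PhaseResult(status=PhaseStatus.COMPLETED, message="Notification sent")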
157 tools/prompt_injection_tester/pit/orchestrator/pipeline.py Normal file
@@ -0,0 +1,157 @@
"""Sequential pipeline executor for running phases one-by-one."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from pit.config import Config
|
||||
from pit.errors import PitError
|
||||
from pit.orchestrator.phases import Phase, PhaseResult, PhaseStatus
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineContext:
|
||||
"""
|
||||
Shared state across pipeline phases.
|
||||
|
||||
This context is passed through each phase sequentially,
|
||||
allowing phases to read input from previous phases.
|
||||
"""
|
||||
|
||||
target_url: str
|
||||
config: Config
|
||||
|
||||
# Phase 1 output (Discovery)
|
||||
injection_points: List[Any] = field(default_factory=list)
|
||||
|
||||
# Phase 2 output (Attack)
|
||||
test_results: List[Any] = field(default_factory=list)
|
||||
|
||||
# Phase 3 output (Verification)
|
||||
verified_results: List[Any] = field(default_factory=list)
|
||||
|
||||
# Phase 4 output (Reporting)
|
||||
report: Optional[Dict[str, Any]] = None
|
||||
report_path: Optional[Path] = None
|
||||
|
||||
# Metadata
|
||||
start_time: datetime = field(default_factory=datetime.now)
|
||||
phase_durations: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
# Interrupt handling
|
||||
interrupted: bool = False
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Sequential phase executor.
|
||||
|
||||
Runs phases one-by-one, waiting for each to complete before
|
||||
starting the next. This ensures no concurrency errors.
|
||||
"""
|
||||
|
||||
def __init__(self, phases: List[Phase]):
|
||||
"""
|
||||
Initialize pipeline with phases.
|
||||
|
||||
Args:
|
||||
phases: List of phases to execute sequentially
|
||||
"""
|
||||
self.phases = phases
|
||||
|
||||
async def run(self, context: PipelineContext) -> PipelineContext:
|
||||
"""
|
||||
Run all phases sequentially.
|
||||
|
||||
Each phase MUST complete before the next begins.
|
||||
|
||||
Args:
|
||||
context: Pipeline context (shared state)
|
||||
|
||||
Returns:
|
||||
Updated context with results
|
||||
|
||||
Raises:
|
||||
PitError: If any phase fails
|
||||
KeyboardInterrupt: If user interrupts
|
||||
"""
|
||||
total_phases = len(self.phases)
|
||||
|
||||
try:
|
||||
for i, phase in enumerate(self.phases, start=1):
|
||||
self._print_phase_header(i, total_phases, phase.name)
|
||||
|
||||
# CRITICAL: Wait for phase to complete before continuing
|
||||
phase_start = datetime.now()
|
||||
result = await phase.execute(context)
|
||||
phase_duration = (datetime.now() - phase_start).total_seconds()
|
||||
|
||||
context.phase_durations[phase.name] = phase_duration
|
||||
|
||||
if result.status == PhaseStatus.FAILED:
|
||||
self._handle_phase_failure(phase, result)
|
||||
break
|
||||
|
||||
self._print_phase_success(phase.name, result)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[yellow]⚠ Scan Interrupted by User[/yellow]")
|
||||
context.interrupted = True
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]✗ Pipeline Error: {e.__class__.__name__}[/red]")
|
||||
raise
|
||||
|
||||
return context
|
||||
|
||||
def _print_phase_header(self, phase_num: int, total: int, name: str) -> None:
|
||||
"""Print phase header."""
|
||||
header = f"[{phase_num}/{total}] {name}"
|
||||
console.print()
|
||||
console.print(Panel(header, border_style="cyan", expand=False))
|
||||
|
||||
def _print_phase_success(self, name: str, result: PhaseResult) -> None:
|
||||
"""Print phase success message."""
|
||||
console.print(f"[green]✓ {name} Complete[/green]")
|
||||
if result.message:
|
||||
console.print(f" └─ {result.message}")
|
||||
|
||||
def _handle_phase_failure(self, phase: Phase, result: PhaseResult) -> None:
|
||||
"""Handle phase failure."""
|
||||
console.print(f"[red]✗ {phase.name} Failed[/red]")
|
||||
if result.error:
|
||||
console.print(f" └─ {result.error}")
|
||||
|
||||
|
||||
async def create_default_pipeline() -> Pipeline:
|
||||
"""
|
||||
Create the default 4-phase pipeline.
|
||||
|
||||
Returns:
|
||||
Pipeline with Discovery, Attack, Verification, and Reporting phases
|
||||
"""
|
||||
from pit.orchestrator.phases import (
|
||||
AttackPhase,
|
||||
DiscoveryPhase,
|
||||
ReportingPhase,
|
||||
VerificationPhase,
|
||||
)
|
||||
|
||||
return Pipeline(
|
||||
phases=[
|
||||
DiscoveryPhase(),
|
||||
AttackPhase(),
|
||||
VerificationPhase(),
|
||||
ReportingPhase(),
|
||||
]
|
||||
)
|
||||
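End to end, the orchestrator presumably drives this along the following lines; the calling code is not part of the diff, so treat this as a sketch with assumed import paths and a made-up target:

# Illustrative sketch only -- module paths and target URL are assumptions.
import asyncio

from pit.config import Config
from pit.orchestrator.pipeline import PipelineContext, create_default_pipeline


async def run() -> None:
    config = Config.from_cli_args(target_url="https://api.example.com/v1/chat")
    context = PipelineContext(target_url=config.target.url, config=config)
    pipeline = await create_default_pipeline()  # the factory itself is async
    context = await pipeline.run(context)
    print(context.phase_durations)  # per-phase wall-clock seconds


asyncio.run(run())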
24 tools/prompt_injection_tester/pit/ui/__init__.py Normal file
@@ -0,0 +1,24 @@
"""
|
||||
Rich UI components for the CLI interface.
|
||||
"""
|
||||
|
||||
from pit.ui.console import console
|
||||
from pit.ui.progress import create_progress_bar, create_spinner
|
||||
from pit.ui.display import (
|
||||
print_banner,
|
||||
print_success,
|
||||
print_error,
|
||||
print_warning,
|
||||
print_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"console",
|
||||
"create_progress_bar",
|
||||
"create_spinner",
|
||||
"print_banner",
|
||||
"print_success",
|
||||
"print_error",
|
||||
"print_warning",
|
||||
"print_info",
|
||||
]
|
||||
34 tools/prompt_injection_tester/pit/ui/spinner.py Normal file
@@ -0,0 +1,34 @@
"""Spinner animations for async operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def show_spinner(message: str, spinner_type: str = "dots") -> Iterator[None]:
|
||||
"""
|
||||
Show a spinner during an operation.
|
||||
|
||||
Args:
|
||||
message: Message to display next to spinner
|
||||
spinner_type: Type of spinner animation
|
||||
|
||||
Yields:
|
||||
None
|
||||
|
||||
Example:
|
||||
with show_spinner("Discovering injection points..."):
|
||||
await discover()
|
||||
"""
|
||||
spinner = Spinner(spinner_type, text=message, style="cyan")
|
||||
|
||||
with Live(spinner, console=console, transient=True):
|
||||
yield
|
||||
114 tools/prompt_injection_tester/pit/ui/styles.py Normal file
@@ -0,0 +1,114 @@
"""Color schemes and styling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.models import Severity, TestStatus
|
||||
|
||||
|
||||
# Color scheme
|
||||
STYLES = {
|
||||
"primary": "cyan",
|
||||
"success": "green",
|
||||
"warning": "yellow",
|
||||
"error": "red",
|
||||
"info": "blue",
|
||||
"muted": "dim",
|
||||
}
|
||||
|
||||
# Severity colors
|
||||
SEVERITY_COLORS = {
|
||||
"critical": "red",
|
||||
"high": "red",
|
||||
"medium": "yellow",
|
||||
"low": "blue",
|
||||
"info": "green",
|
||||
}
|
||||
|
||||
SEVERITY_ICONS = {
|
||||
"critical": "🔴",
|
||||
"high": "🔴",
|
||||
"medium": "🟠",
|
||||
"low": "🟡",
|
||||
"info": "🟢",
|
||||
}
|
||||
|
||||
# Status symbols
|
||||
STATUS_SYMBOLS = {
|
||||
"success": "✓",
|
||||
"failed": "✗",
|
||||
"error": "⚠",
|
||||
"pending": "⋯",
|
||||
"running": "⠋",
|
||||
}
|
||||
|
||||
|
||||
def format_severity(severity: str | Severity) -> str:
|
||||
"""
|
||||
Format severity with color and icon.
|
||||
|
||||
Args:
|
||||
severity: Severity level
|
||||
|
||||
Returns:
|
||||
Formatted severity string with Rich markup
|
||||
"""
|
||||
if hasattr(severity, "value"):
|
||||
severity = severity.value
|
||||
|
||||
severity_lower = severity.lower()
|
||||
color = SEVERITY_COLORS.get(severity_lower, "white")
|
||||
icon = SEVERITY_ICONS.get(severity_lower, "⚪")
|
||||
|
||||
return f"[{color}]{icon} {severity.upper()}[/{color}]"
|
||||
|
||||
|
||||
def format_status(status: str | TestStatus, success: bool = True) -> str:
|
||||
"""
|
||||
Format test status with symbol and color.
|
||||
|
||||
Args:
|
||||
status: Status value
|
||||
success: Whether the test succeeded
|
||||
|
||||
Returns:
|
||||
Formatted status string with Rich markup
|
||||
"""
|
||||
if hasattr(status, "value"):
|
||||
status = status.value
|
||||
|
||||
status_lower = status.lower()
|
||||
symbol = STATUS_SYMBOLS.get(status_lower, "⋯")
|
||||
|
||||
if success:
|
||||
color = "green"
|
||||
elif status_lower == "error":
|
||||
color = "red"
|
||||
else:
|
||||
color = "yellow"
|
||||
|
||||
return f"[{color}]{symbol} {status.title()}[/{color}]"
|
||||
|
||||
|
||||
def format_confidence(confidence: float) -> str:
|
||||
"""
|
||||
Format confidence score with color.
|
||||
|
||||
Args:
|
||||
confidence: Confidence score (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
Formatted confidence string with Rich markup
|
||||
"""
|
||||
percentage = int(confidence * 100)
|
||||
|
||||
if confidence >= 0.9:
|
||||
color = "green"
|
||||
elif confidence >= 0.7:
|
||||
color = "yellow"
|
||||
else:
|
||||
color = "red"
|
||||
|
||||
return f"[{color}]{percentage}%[/{color}]"
|
||||
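These helpers return Rich markup strings rather than printing, so they slot directly into tables and panels. The expected outputs, per the mappings above:

# Illustrative sketch only -- assumed module path.
from pit.ui.styles import format_confidence, format_severity

print(format_severity("high"))   # [red]🔴 HIGH[/red]
print(format_confidence(0.85))   # [yellow]85%[/yellow]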
0 tools/prompt_injection_tester/pit/utils/__init__.py Normal file
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"

[project]
name = "prompt-injection-tester"
-version = "1.0.0"
-description = "Automated prompt injection testing framework for LLMs"
+version = "2.0.0"
+description = "Modern CLI for automated prompt injection testing of LLMs"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "CC BY-SA 4.0"}
@@ -18,6 +18,7 @@ keywords = [
    "prompt-injection",
    "red-team",
    "testing",
+    "cli",
]
classifiers = [
    "Development Status :: 4 - Beta",
@@ -31,9 +32,12 @@ classifiers = [
]
dependencies = [
    "aiohttp>=3.9.0",
+    "httpx>=0.24.0",
    "pyyaml>=6.0",
-    "typer>=0.7.0",
+    "typer>=0.9.0",
+    "rich>=13.0.0",
+    "pydantic>=2.0.0",
    "jinja2>=3.1.0",
]

[project.optional-dependencies]
@@ -48,7 +52,7 @@ dev = [

[project.scripts]
prompt-injection-tester = "prompt_injection_tester.cli:main"
-pit = "pit.app:app"
+pit = "pit.app:cli_main"

[project.urls]
"Homepage" = "https://github.com/example/ai-llm-red-team-handbook"
2 tools/prompt_injection_tester/tests/__init__.py Normal file
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
"""Test suite for the prompt injection tester."""
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
"""Tests for attack patterns."""
15 tools/prompt_injection_tester/utils/__init__.py Normal file
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
"""Utility modules for the prompt injection tester."""

from .encoding import encode_payload, decode_payload, translate_payload
from .http_client import AsyncHTTPClient, LLMClient, HTTPResponse, RateLimiter

__all__ = [
    "encode_payload",
    "decode_payload",
    "translate_payload",
    "AsyncHTTPClient",
    "LLMClient",
    "HTTPResponse",
    "RateLimiter",
]