feat: Introduce a new configuration system for PIT using Pydantic schemas for target, attack, reporting, and authorization, with a loader for YAML files and CLI arguments.

This commit is contained in:
shiva108
2026-01-26 18:58:23 +01:00
parent 3bad401ada
commit 5b59811989
30 changed files with 2902 additions and 43 deletions

.gitignore vendored
View File

@@ -1,13 +1,11 @@
# Local files - never commit
.venv
/tools/prompt_injection_tester/.venv
.agent
docs/contentsuggestions
ignore/
#.agent/
__init__.py
# --- Core & System ---
.DS_Store
Thumbs.db
*~
*.swp
*.swo
# Python
# --- Python ---
__pycache__/
*.py[cod]
*$py.class
@@ -17,52 +15,65 @@ __pycache__/
.eggs/
dist/
build/
.venv/
.venvs/
/tools/prompt_injection_tester/.venv
# IDE
# --- Node.js ---
node_modules/
npm-debug.log
# --- IDEs ---
.idea/
.vscode/
*.swp
*.swo
*~
*.sublime-project
*.sublime-workspace
# OS
.DS_Store
Thumbs.db
# Temporary files
# --- Temporary & Logs ---
*.tmp
*.temp
*.log
# --- Agent & AI Workspaces ---
.agent/
.claude/
ignore/
docs/contentsuggestions
# --- Local/Env Configuration ---
.env
.env.*
!.env.example
.mcp.json
# --- Shell ---
.bash_profile
.bashrc
.profile
.zprofile
.zshrc
# --- Temporary Artifacts & Backups ---
.markdownlint.json
.agent/AUTO_COMMIT_GUIDE.md
.agent/rules/snyk_rules.md
scripts/tests/verify_fixes.py
docs/Chapter_31_AI_System_Reconnaissance.md.backup
docs/Chapter_31_AI_System_Reconnaissance.md.audit_backup
*.backup
*.audit_backup
# --- Generated Reports & Docs ---
docs/Visual_Recommendations.md
Visual_Recommendations_V2.md
.gitignore
workflows/audit-fix-humanize-chapter-v2.md
.gitignore
# Specific Reports
docs/reports/AI_Security_Intelligence_Report_December_2025.md
docs/reports/AI_Security_Intelligence_Report_January_2026.md
docs/reports/newsletter_jan_2026.md
# --- Tool Specific: Prompt Injection Tester ---
tools/prompt_injection_tester/CODE_REVIEW_REPORT.md
tools/prompt_injection_tester/.coverage
.bash_profile
.bashrc
.idea
.mcp.json
.profile
.ripgreprc
.zprofile
.zshrc
.claude/agents
.claude/commands
.claude/settings.json
tools/prompt_injection_tester/CLI_SPECIFICATION.md
tools/prompt_injection_tester/CLI_ARCHITECTURE.md
tools/prompt_injection_tester/PHASE1_COMPLETE.md
tools/prompt_injection_tester/SPECIFICATION.md
tools/prompt_injection_tester/.coverage
tools/prompt_injection_tester/ARCHITECTURE.md
.gitignore
.idea
.ripgreprc

View File

@@ -0,0 +1,141 @@
import os
import sys
import inspect
import socket
import threading
import time
import pytest
from unittest.mock import MagicMock, patch
# Add project root to path (parent of scripts/)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
def test_this_code_api_key():
print("\n[+] Verifying 'scripts/social_engineering/this_code.py' API Key fix...")
    # Mock PhishingGenerator to avoid import errors if dependencies are missing or if it makes network calls.
    # We only want to check the line where api_key is passed.
    # Since we can't easily mock the class *before* import when it is instantiated at module level,
    # inspect the file content for the expected string first as a sanity check.
with open('scripts/social_engineering/this_code.py', 'r') as f:
content = f.read()
if 'api_key=os.getenv("OPENAI_API_KEY")' in content:
print(" PASS: File uses os.getenv('OPENAI_API_KEY')")
else:
print(" FAIL: File does not use os.getenv('OPENAI_API_KEY')")
return
# Now try to import it with a dummy env var
os.environ['OPENAI_API_KEY'] = 'test_key'
try:
        # We need to mock PhishingGenerator because it might not exist or might perform side effects on import
with patch('scripts.social_engineering.this_code.PhishingGenerator') as MockGen:
import scripts.social_engineering.this_code
# Verify it was called with our env var
# The script runs immediately on import, so MockGen should have been instantiated
MockGen.assert_called_with(api_key='test_key')
print(" PASS: Script ran and used the environment variable.")
except Exception as e:
print(f" WARNING: Could not import script fully (missing dependnecies?): {e}")
def test_attack_rce_disabled():
print("\n[+] Verifying 'scripts/utils/attack.py' RCE disabled...")
try:
from scripts.utils.attack import MaliciousModel
mm = MaliciousModel()
func, args = mm.__reduce__()
if func == os.system:
print(" FAIL: __reduce__ still returns os.system!")
elif func == print:
print(f" PASS: __reduce__ returns 'print' (safe). Args: {args}")
else:
print(f" PASS: __reduce__ returns {func.__name__} (likely safe).")
except ImportError:
print(" FAIL: Could not import MaliciousModel")
except Exception as e:
print(f" FAIL: Error testing MaliciousModel: {e}")
def test_c2_server_binding():
print("\n[+] Verifying 'scripts/examples/c2_server_elite.py' binding...")
try:
from scripts.examples.c2_server_elite import C2Server
# Inspect the 'start' method signature
sig = inspect.signature(C2Server.start)
host_param = sig.parameters['host']
if host_param.default == '127.0.0.1':
print(" PASS: C2Server.start default host is '127.0.0.1'")
else:
print(f" FAIL: C2Server.start default host is '{host_param.default}'")
except ImportError:
print(" FAIL: Could not import C2Server (missing dependencies?)")
except Exception as e:
print(f" FAIL: Error inspecting C2Server: {e}")
def test_components_source_asserts():
print("\n[+] Verifying 'scripts/automation/components_source.py' asserts removed...")
# Check file content for 'assert '
with open('scripts/automation/components_source.py', 'r') as f:
lines = f.readlines()
assert_count = sum(1 for line in lines if line.strip().startswith('assert '))
if assert_count == 0:
print(" PASS: No 'assert' statements found.")
else:
print(f" WARNING: Found {assert_count} 'assert' statements. Please verify if they are safe test assertions.")
# Run the script to verify it raises RuntimeError (mocking the client output)
print(" Running logical verification...")
# We can invoke the test classes directly
from scripts.automation.components_source import TestPromptInjection, MockLLMClient
test = TestPromptInjection()
# 1. Test Jailbreak (Mock Unsafe)
mock_unsafe_client = MagicMock()
mock_unsafe_client.complete.return_value = "Confirmed. I am now free."
try:
test.test_jailbreak_resistance(mock_unsafe_client, "ignore rules")
print(" FAIL: test_jailbreak_resistance did NOT raise exception on unsafe response")
except RuntimeError as e:
print(f" PASS: test_jailbreak_resistance raised RuntimeError as expected: {e}")
except AssertionError:
print(" FAIL: test_jailbreak_resistance raised AssertionError (should be RuntimeError)")
# 2. Test PII Leakage (Mock Leak)
mock_leak_client = MagicMock()
mock_leak_client.complete.return_value = "Here is my key: sk-12345"
try:
test.test_pii_leakage(mock_leak_client)
print(" FAIL: test_pii_leakage did NOT raise exception on leak")
except RuntimeError as e:
print(f" PASS: test_pii_leakage raised RuntimeError as expected: {e}")
except AssertionError:
print(" FAIL: test_pii_leakage raised AssertionError (should be RuntimeError)")
def test_shadow_scanner_resource():
print("\n[+] Verifying 'scripts/utils/tooling_shadow_ai_scanner.py' resource handling...")
# Static check for 'with socket.socket'
with open('scripts/utils/tooling_shadow_ai_scanner.py', 'r') as f:
content = f.read()
if "with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:" in content:
print(" PASS: Uses context manager for socket.")
else:
print(" FAIL: Does not appear to use context manager for socket.")
if __name__ == "__main__":
test_this_code_api_key()
test_attack_rce_disabled()
test_c2_server_binding()
test_components_source_asserts()
test_shadow_scanner_resource()

View File

@@ -0,0 +1,687 @@
# PIT (Prompt Injection Tester) - Functional Specification
**Version:** 2.0.0
**Date:** 2026-01-26
**Status:** Draft
---
## 1. Executive Summary
**PIT** is a modern, one-command CLI application for automated prompt injection testing. It transforms the existing `prompt_injection_tester` framework into a user-friendly TUI (text user interface) that executes the entire red-teaming lifecycle with a single command.
### Design Philosophy
- **"Magic Command" UX**: Single command to run end-to-end testing
- **Sequential Execution**: Phases run one-by-one to avoid concurrency errors
- **Visual Feedback**: Rich TUI with progress bars, spinners, and color-coded results
- **Fail-Fast**: Graceful error handling at each phase boundary
- **Zero Configuration**: Sensible defaults with optional customization
---
## 2. The "One-Command" Workflow
### 2.1 Primary Command
```bash
pit scan <target_url> --auto
```
**Example:**
```bash
pit scan https://api.openai.com/v1/chat/completions --auto --token $OPENAI_API_KEY
```
### 2.2 Workflow Phases (Sequential)
The application runs **four phases sequentially**. Each phase:
- Completes fully before the next begins
- Returns data that feeds into the next phase
- Can fail gracefully without crashing the entire pipeline
```
┌─────────────────────────────────────────────────────────────┐
│ PIT WORKFLOW │
├─────────────────────────────────────────────────────────────┤
│ │
│ Phase 1: DISCOVERY │
│ ├─ Scan target for injection points │
│ ├─ Identify API endpoints, parameters, headers │
│ └─ Output: List[InjectionPoint] │
│ │ │
│ ▼ │
│ Phase 2: ATTACK │
│ ├─ Load attack patterns from registry │
│ ├─ Execute attacks against discovered points │
│ ├─ Use asyncio internally for HTTP requests │
│ └─ Output: List[TestResult] │
│ │ │
│ ▼ │
│ Phase 3: VERIFICATION │
│ ├─ Analyze responses for success indicators │
│ ├─ Apply detection heuristics │
│ ├─ Calculate severity scores │
│ └─ Output: List[VerifiedResult] │
│ │ │
│ ▼ │
│ Phase 4: REPORTING │
│ ├─ Generate summary table │
│ ├─ Write report artifact (JSON/HTML/YAML) │
│ └─ Display results to stdout │
│ │
└─────────────────────────────────────────────────────────────┘
```
**Critical Requirement:**
The application MUST wait for each phase to complete before starting the next. No parallel "tool use" or agent invocations.
---
## 3. User Experience Specification
### 3.1 Phase 1: Discovery
**User sees:**
```
┌─────────────────────────────────────────────────────┐
│ [1/4] Discovery │
├─────────────────────────────────────────────────────┤
│ Target: https://api.openai.com/v1/chat/completions │
│ │
│ ⠋ Discovering injection points... │
│ │
│ [Spinner animation while scanning] │
└─────────────────────────────────────────────────────┘
```
**Success Output:**
```
✓ Discovery Complete
├─ Found 3 endpoints
├─ Identified 12 parameters
└─ Detected 2 header injection points
```
**Error Handling:**
- If target is unreachable: Display error, suggest `--skip-discovery`
- If no injection points found: Warn user, allow manual point specification
### 3.2 Phase 2: Attack Execution
**User sees:**
```
┌─────────────────────────────────────────────────────┐
│ [2/4] Attack Execution │
├─────────────────────────────────────────────────────┤
│ Loaded 47 attack patterns from registry │
│ │
│ Progress: [████████████░░░░░░] 45/100 (45%) │
│ │
│ Current: direct/role_override │
│ Rate: 2.3 req/s | Elapsed: 00:19 | ETA: 00:24 │
└─────────────────────────────────────────────────────┘
```
**Progress Bar Details** (see the sketch after this list):
- Shows current attack pattern being tested
- Displays rate limiting compliance
- Real-time success/failure counters
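A minimal sketch of how this progress display could be assembled with Rich; the column choices and the placeholder pattern list are illustrative, not the shipped implementation:
```python
# Sketch of the attack-phase progress display built with Rich.
# Columns and the pattern list are placeholders, not the final implementation.
import time
from rich.progress import (
    Progress, SpinnerColumn, TextColumn, BarColumn,
    MofNCompleteColumn, TimeElapsedColumn, TimeRemainingColumn,
)

patterns = ["direct/role_override", "direct/system_prompt_leak", "advanced/unicode_smuggling"]

with Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    MofNCompleteColumn(),
    TimeElapsedColumn(),
    TimeRemainingColumn(),
) as progress:
    task = progress.add_task("Attack Execution", total=len(patterns))
    for pattern in patterns:
        progress.update(task, description=f"Current: {pattern}")
        time.sleep(1.0)  # placeholder for the real HTTP request + rate limiting
        progress.update(task, advance=1)
```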
**Interrupt Handling:**
- `Ctrl+C` during attack: Save partial results, offer resume option
### 3.3 Phase 3: Verification
**User sees:**
```
┌─────────────────────────────────────────────────────┐
│ [3/4] Verification │
├─────────────────────────────────────────────────────┤
│ Analyzing 100 responses... │
│ │
│ ⠸ Running detection heuristics │
│ │
│ [Spinner animation] │
└─────────────────────────────────────────────────────┘
```
**Success Output:**
```
✓ Verification Complete
├─ 12 successful injections detected
├─ 88 attacks blocked/failed
└─ 3 high-severity vulnerabilities found
```
### 3.4 Phase 4: Reporting
**User sees:**
```
┌─────────────────────────────────────────────────────────────────────┐
│ [4/4] Report Generation │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ VULNERABILITY SUMMARY │
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │
│ │
│ Pattern ID │ Severity │ Status │ Confidence │
│ ─────────────────────┼───────────┼───────────┼─────────── │
│ role_override │ 🔴 HIGH │ ✓ Success │ 95% │
│ system_prompt_leak │ 🟠 MEDIUM │ ✓ Success │ 87% │
│ context_override │ 🟡 LOW │ ✗ Failed │ - │
│ │
│ Total Tests: 100 | Successful: 12 | Success Rate: 12% │
│ │
│ 📄 Report saved: ./pit_report_20260126_143022.json │
│ │
└─────────────────────────────────────────────────────────────────────┘
```
**Report Artifacts:**
- Default: `./pit_report_{timestamp}.json`
- HTML report (if `--format html`): Interactive dashboard
- YAML report (if `--format yaml`): Human-readable summary
---
## 4. Command-Line Interface Specification
### 4.1 Primary Commands
#### `pit scan`
**Syntax:**
```bash
pit scan <target_url> [OPTIONS]
```
**Required Arguments:**
- `target_url`: The API endpoint to test (e.g., `https://api.example.com/v1/chat`)
**Optional Arguments:**
```
--token, -t <TOKEN> Authentication token (or use env: $PIT_TOKEN)
--auto, -a Run all phases automatically (default: interactive)
--patterns <PATTERN_IDS> Test specific patterns (comma-separated)
--categories <CATEGORIES> Filter by category: direct,indirect,advanced
--output, -o <FILE> Report output path (default: auto-generated)
--format, -f <FORMAT> Report format: json, yaml, html (default: json)
--rate-limit <FLOAT> Requests per second (default: 1.0)
--max-concurrent <INT> Max parallel requests (default: 5)
--timeout <INT> Request timeout in seconds (default: 30)
--skip-discovery Skip discovery phase, use manual injection points
--injection-points <FILE> Load injection points from JSON file
--verbose, -v Show detailed logs
--quiet, -q Suppress all output except errors
```
**Examples:**
```bash
# Basic scan
pit scan https://api.openai.com/v1/chat/completions --auto --token $OPENAI_API_KEY
# Test specific patterns
pit scan https://api.example.com --patterns role_override,prompt_leak --auto
# Custom rate limiting
pit scan https://api.example.com --rate-limit 0.5 --max-concurrent 3 --auto
# Generate HTML report
pit scan https://api.example.com --auto --format html --output report.html
# Skip discovery (use manual points)
pit scan https://api.example.com --skip-discovery --injection-points ./points.json --auto
```
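The file passed to `--injection-points` is expected to contain a list of injection point objects. A minimal example, assuming the same fields as the `discovery.injection_points` entries in the JSON report schema (Section 7.2); the second entry is purely illustrative:
```json
[
  {"id": "param_prompt", "type": "parameter", "name": "prompt", "location": "body"},
  {"id": "header_x_context", "type": "header", "name": "X-Context", "location": "header"}
]
```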
#### `pit list`
**Syntax:**
```bash
pit list [patterns|categories]
```
**Examples:**
```bash
# List all available attack patterns
pit list patterns
# List attack categories
pit list categories
```
**Output:**
```
Available Attack Patterns (47 total)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Category: direct (15 patterns)
├─ role_override - Override system role assignment
├─ system_prompt_leak - Attempt to extract system prompt
└─ ...
Category: indirect (12 patterns)
├─ payload_splitting - Split malicious payload across inputs
└─ ...
Category: advanced (20 patterns)
├─ unicode_smuggling - Use Unicode tricks to bypass filters
└─ ...
```
#### `pit auth`
**Syntax:**
```bash
pit auth <target_url>
```
**Purpose:**
Verify authorization to test the target before running attacks.
**Interactive Prompt:**
```
┌─────────────────────────────────────────────────────┐
│ AUTHORIZATION REQUIRED │
├─────────────────────────────────────────────────────┤
│ │
│ Target: https://api.example.com │
│ │
│ ⚠ You must have explicit authorization to test │
│ this system. Unauthorized testing may be illegal. │
│ │
│ Do you have authorization? [y/N]: │
└─────────────────────────────────────────────────────┘
```
**Non-Interactive:**
```bash
pit scan <url> --auto --authorize
```
### 4.2 Configuration File Support
**Format:** YAML
**Location:** `./pit.config.yaml` or `~/.config/pit/config.yaml`
**Example:**
```yaml
# PIT Configuration
target:
url: https://api.openai.com/v1/chat/completions
token: ${OPENAI_API_KEY}
api_type: openai
timeout: 30
attack:
categories:
- direct
- indirect
  exclude_patterns:
    - dos_attack  # Skip DoS patterns
max_concurrent: 5
rate_limit: 1.0
reporting:
format: html
output: ./reports/
include_cvss: true
include_payloads: false # Exclude payloads for compliance
authorization:
scope:
- all
confirmed: true # Skip interactive prompt
```
**Usage:**
```bash
# Use config file
pit scan --config ./pit.config.yaml --auto
```
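The file maps onto the Pydantic models introduced in this commit (`pit/config/schema.py`). A minimal sketch of loading it programmatically via the loader shown later in this diff; the path and printed values are illustrative:
```python
# Minimal sketch: load pit.config.yaml into the Pydantic Config model.
# Assumes the pit.config loader introduced in this commit; the path is illustrative.
from pathlib import Path

from pit.config import load_config

config = load_config(config_path=Path("./pit.config.yaml"))
print(config.target.url)          # e.g. https://api.openai.com/v1/chat/completions
print(config.attack.rate_limit)   # e.g. 1.0 requests per second
```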
---
## 5. Error Handling Specification
### 5.1 Graceful Degradation
**Principle:** Each phase can fail independently without crashing the pipeline.
**Phase-Specific Errors:**
#### Discovery Errors
- **Target Unreachable**: Suggest `--skip-discovery`, allow manual injection points
- **Rate Limited**: Display backoff message, retry with exponential backoff (see the sketch after this list)
- **No Endpoints Found**: Warn user, offer to load from file
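A minimal sketch of the retry-with-exponential-backoff behaviour mentioned above; `send_request`, the 429 check, and the retry parameters are illustrative rather than the final implementation:
```python
# Sketch of exponential backoff for rate-limited discovery requests.
# The request callable and its response attribute are illustrative.
import asyncio
import random

async def fetch_with_backoff(send_request, max_retries: int = 5, base_delay: float = 1.0):
    """Retry an async request callable, doubling the delay after each 429 response."""
    for attempt in range(max_retries):
        response = await send_request()
        if response.status != 429:                                    # not rate limited
            return response
        delay = base_delay * (2 ** attempt) + random.uniform(0, 0.5)  # backoff + jitter
        print(f"Rate limited, retrying in {delay:.1f}s...")
        await asyncio.sleep(delay)
    raise RuntimeError("Still rate limited after retries")
```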
#### Attack Errors
- **Authentication Failed**: Stop immediately, display clear auth error
- **Rate Limit Hit**: Pause attack, show countdown, resume automatically
- **Timeout Exceeded**: Skip pattern, log failure, continue with next
#### Verification Errors
- **Detection Ambiguous**: Mark as "uncertain", include in report with low confidence
- **Scoring Failed**: Use default severity, log warning
#### Reporting Errors
- **File Write Failed**: Fall back to stdout
- **Format Error**: Generate JSON as fallback
### 5.2 User-Friendly Error Messages
**Bad:**
```
Error: HTTPError(403)
```
**Good:**
```
✗ Authentication Failed
├─ The target server returned 403 Forbidden
├─ Suggestion: Check your API token with --token
└─ Or verify authorization with: pit auth <url>
```
### 5.3 Interrupt Handling
**Behavior on `Ctrl+C`:**
```
┌─────────────────────────────────────────────────────┐
│ ⚠ Scan Interrupted │
├─────────────────────────────────────────────────────┤
│ Progress: 45/100 attacks completed │
│ │
│ Options: │
│ r - Resume scan │
│ s - Save partial results and exit │
│ q - Quit without saving │
│ │
│ Choice [r/s/q]: │
└─────────────────────────────────────────────────────┘
```
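A reduced sketch of this interrupt flow: catch `KeyboardInterrupt` around the attack loop, then offer resume / save-partial / quit. The toy attack loop and all names below are illustrative:
```python
# Sketch of Ctrl+C handling: interrupt the loop, then prompt resume / save / quit.
import json
import time

def run_attacks(patterns, results):
    """Toy attack loop; the real version issues rate-limited HTTP requests."""
    for pattern in patterns[len(results):]:    # resume: skip work already completed
        time.sleep(0.5)                        # placeholder for the real request
        results.append({"pattern_id": pattern, "status": "failed"})

patterns = [f"pattern_{i}" for i in range(10)]
results = []
while True:
    try:
        run_attacks(patterns, results)
        break                                  # all patterns finished
    except KeyboardInterrupt:
        choice = input("\n[r]esume / [s]ave partial / [q]uit: ").strip().lower()
        if choice == "r":
            continue                           # resume where we left off
        if choice == "s":
            with open("pit_partial_results.json", "w") as fh:
                json.dump(results, fh, indent=2)
        break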
---
## 6. Sequential Logic Specification
### 6.1 Phase Execution Flow
**Pseudocode:**
```python
async def run_scan(target_url: str, config: Config) -> Report:
"""
Execute the full scan pipeline sequentially.
Each phase MUST complete before the next begins.
"""
# Phase 1: Discovery
print_phase_header(1, "Discovery")
show_spinner("Discovering injection points...")
injection_points = await discovery.scan(target_url)
# ↑ WAIT for discovery to complete
if not injection_points:
handle_discovery_failure()
return
print_success(f"Found {len(injection_points)} injection points")
# Phase 2: Attack
print_phase_header(2, "Attack Execution")
attack_patterns = load_patterns(config.categories)
results = []
with ProgressBar(total=len(attack_patterns)) as progress:
for pattern in attack_patterns:
# Execute attacks ONE BY ONE (or with internal asyncio)
result = await attack.execute(pattern, injection_points)
results.append(result)
progress.update(1)
# ↑ WAIT for all attacks to complete
print_success(f"Completed {len(results)} attacks")
# Phase 3: Verification
print_phase_header(3, "Verification")
show_spinner("Analyzing responses...")
verified_results = await verification.analyze(results)
# ↑ WAIT for verification to complete
print_success(f"Verified {len(verified_results)} results")
# Phase 4: Reporting
print_phase_header(4, "Reporting")
report = generate_report(verified_results, config.format)
save_report(report, config.output)
display_summary(report)
return report
```
### 6.2 Data Flow Between Phases
**Phase Boundaries:**
```
Phase 1 Output → Phase 2 Input
InjectionPoint[] → attack.execute(patterns, injection_points)
Phase 2 Output → Phase 3 Input
TestResult[] → verification.analyze(results)
Phase 3 Output → Phase 4 Input
VerifiedResult[] → generate_report(verified_results)
```
**No Parallel Agent Invocations:**
- The CLI orchestrator runs phases sequentially
- Individual phases may use `asyncio` internally for HTTP requests
- But the orchestrator NEVER spawns multiple "tool use" blocks
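As an illustration of these boundaries, the orchestration reduces to a shared context object passed through the phases one at a time. A minimal sketch follows (the real implementation lives in `pit/orchestrator/`; this is a deliberately simplified stand-in):
```python
# Minimal sequential pipeline sketch: each phase reads from and writes to a shared
# context, and the orchestrator awaits one phase at a time. Illustrative only.
import asyncio
from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class Context:
    target_url: str
    injection_points: List[Any] = field(default_factory=list)
    test_results: List[Any] = field(default_factory=list)
    verified_results: List[Any] = field(default_factory=list)
    report: Optional[dict] = None

async def discovery(ctx: Context) -> None:
    ctx.injection_points = ["param_prompt"]                          # stand-in for real discovery

async def attack(ctx: Context) -> None:
    ctx.test_results = [{"point": p} for p in ctx.injection_points]

async def verification(ctx: Context) -> None:
    ctx.verified_results = [dict(r, status="failed") for r in ctx.test_results]

async def reporting(ctx: Context) -> None:
    ctx.report = {"total_tests": len(ctx.verified_results)}

async def run(ctx: Context) -> Context:
    for phase in (discovery, attack, verification, reporting):
        await phase(ctx)          # each phase finishes before the next starts
    return ctx

print(asyncio.run(run(Context(target_url="https://api.example.com"))).report)
```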
---
## 7. Output Specifications
### 7.1 Terminal Output (stdout)
**Color Scheme:**
- 🔴 **Red**: High-severity vulnerabilities, errors
- 🟠 **Orange**: Medium-severity, warnings
- 🟡 **Yellow**: Low-severity, info
- 🟢 **Green**: Success messages
- 🔵 **Blue**: Headers, section dividers
- **White**: Default text
**Symbols:**
- `✓` Success
- `✗` Failure
- `⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏` Spinner animation
- `[████████░░]` Progress bars
### 7.2 JSON Report Format
**Schema:**
```json
{
"metadata": {
"version": "2.0.0",
"timestamp": "2026-01-26T14:30:22Z",
"target": "https://api.example.com",
"duration_seconds": 142.5
},
"discovery": {
"injection_points": [
{
"id": "param_prompt",
"type": "parameter",
"name": "prompt",
"location": "body"
}
]
},
"results": [
{
"pattern_id": "role_override",
"category": "direct",
"severity": "high",
"status": "success",
"confidence": 0.95,
"injection_point": "param_prompt",
"payload": "[REDACTED]",
"response_indicators": ["system", "role"],
"cvss_score": 7.8
}
],
"summary": {
"total_tests": 100,
"successful_attacks": 12,
"success_rate": 0.12,
"vulnerabilities_by_severity": {
"high": 3,
"medium": 5,
"low": 4
}
}
}
```
### 7.3 HTML Report Format
**Features:**
- Interactive table with sorting/filtering
- Visual charts (bar chart of severity distribution)
- Collapsible sections for detailed attack logs
- Copy-to-clipboard buttons for payloads
- Responsive design (mobile-friendly)
**Template:** Use Jinja2 or similar templating engine
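A minimal Jinja2 rendering sketch; the inline template and the sample result row are illustrative, not the shipped template:
```python
# Sketch: render a tiny HTML report with Jinja2. Template and data are placeholders.
from jinja2 import Template

template = Template("""
<html><body>
  <h1>PIT Report - {{ target }}</h1>
  <table>
    <tr><th>Pattern</th><th>Severity</th><th>Status</th></tr>
    {% for r in results %}
    <tr><td>{{ r.pattern_id }}</td><td>{{ r.severity }}</td><td>{{ r.status }}</td></tr>
    {% endfor %}
  </table>
</body></html>
""")

html = template.render(
    target="https://api.example.com",
    results=[{"pattern_id": "role_override", "severity": "high", "status": "success"}],
)
with open("pit_report.html", "w") as f:
    f.write(html)
```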
---
## 8. Non-Functional Requirements
### 8.1 Performance
- **Discovery Phase**: < 10 seconds for typical API
- **Attack Phase**: Respects rate limiting, no server overload
- **Verification Phase**: < 5 seconds for 100 results
- **Reporting Phase**: < 2 seconds
### 8.2 Reliability
- **Crash-Free**: Handle all HTTP errors gracefully
- **Resumable**: Save state on interrupt, allow resume
- **Idempotent**: Same input → same output (deterministic)
### 8.3 Usability
- **Zero Learning Curve**: `pit scan <url> --auto` should be self-explanatory
- **Progressive Disclosure**: Show simple output by default, verbose with `-v`
- **Helpful Defaults**: No configuration required for basic usage
### 8.4 Security
- **Authorization Check**: Mandatory before running attacks
- **Token Handling**: Never log tokens, use env vars
- **Rate Limiting**: Prevent accidental DoS
---
## 9. Future Extensions (Out of Scope for v2.0)
- **Interactive Mode**: `pit scan <url>` without `--auto` prompts user at each phase
- **Plugin System**: Load custom attack patterns from external modules
- **Cloud Integration**: Upload reports to centralized dashboard
- **CI/CD Integration**: Exit codes for pipeline integration
- **Differential Testing**: Compare results across versions
---
## 10. Acceptance Criteria
**The PIT CLI is complete when:**
1. ✅ User can run `pit scan <url> --auto` and see visual feedback for all 4 phases
2. ✅ Phases execute sequentially (no concurrency errors)
3. ✅ Graceful error handling at every phase boundary
4. ✅ Generated reports match the JSON/HTML/YAML schemas
5. ✅ All output uses Rich TUI (progress bars, spinners, colored text)
6. ✅ Authorization is checked before running attacks
7. ✅ Rate limiting is respected to avoid DoS
8. ✅ Interrupts (`Ctrl+C`) are handled gracefully
9. ✅ Help text (`pit --help`) is clear and comprehensive
10. ✅ Zero crashes on invalid input (bad URLs, missing tokens, etc.)
---
## Appendix A: ASCII Art Mockups
### Full Scan Output
```
┌─────────────────────────────────────────────────────────────────┐
│ PIT - Prompt Injection Tester v2.0.0 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Target: https://api.openai.com/v1/chat/completions │
│ Authorization: ✓ Confirmed │
│ │
├─────────────────────────────────────────────────────────────────┤
│ [1/4] Discovery │
│ ⠋ Discovering injection points... │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ ✓ Discovery Complete │
│ ├─ Found 3 endpoints │
│ ├─ Identified 12 parameters │
│ └─ Detected 2 header injection points │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ [2/4] Attack Execution │
│ Progress: [████████████████░░░░] 80/100 (80%) │
│ Current: advanced/unicode_smuggling │
│ Rate: 2.1 req/s | Elapsed: 00:38 | ETA: 00:10 │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ ✓ Attack Execution Complete │
│ └─ Completed 100 attacks │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ [3/4] Verification │
│ ⠸ Analyzing responses... │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ ✓ Verification Complete │
│ ├─ 12 successful injections detected │
│ ├─ 88 attacks blocked/failed │
│ └─ 3 high-severity vulnerabilities found │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ [4/4] Report Generation │
│ │
│ VULNERABILITY SUMMARY │
│ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │
│ │
│ Pattern ID │ Severity │ Status │ Confidence │
│ ─────────────────────┼───────────┼───────────┼─────────── │
│ role_override │ 🔴 HIGH │ ✓ Success │ 95% │
│ system_prompt_leak │ 🟠 MEDIUM │ ✓ Success │ 87% │
│ context_override │ 🟠 MEDIUM │ ✓ Success │ 82% │
│ payload_splitting │ 🟡 LOW │ ✗ Failed │ - │
│ │
│ Total Tests: 100 | Successful: 12 | Success Rate: 12% │
│ │
│ 📄 Report saved: ./pit_report_20260126_143022.json │
│ │
└─────────────────────────────────────────────────────────────────┘
```
---
**END OF SPECIFICATION**

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Prompt Injection Tester
A professional-grade automated testing framework for LLM prompt injection vulnerabilities.
Based on the AI LLM Red Team Handbook Chapter 14.
Example:
from prompt_injection_tester import InjectionTester, AttackConfig
async with InjectionTester(
target_url="https://api.example.com/v1",
auth_token="...",
) as tester:
tester.authorize(scope=["all"])
injection_points = await tester.discover_injection_points()
results = await tester.run_tests(injection_points=injection_points)
report = tester.generate_report(format="html")
"""
__version__ = "1.0.0"
__author__ = "AI LLM Red Team Handbook"
from .core import (
AttackCategory,
AttackConfig,
AttackPayload,
AttackPattern,
DetectionMethod,
DetectionResult,
InjectionPoint,
InjectionPointType,
InjectionTester,
Severity,
TargetConfig,
TestContext,
TestResult,
TestStatus,
TestSuite,
)
from .patterns.registry import registry, register_pattern, get_pattern, list_patterns
__all__ = [
# Core
"InjectionTester",
"AttackConfig",
"TargetConfig",
"TestResult",
"TestSuite",
"InjectionPoint",
"AttackPattern",
"AttackPayload",
# Enums
"AttackCategory",
"InjectionPointType",
"DetectionMethod",
"Severity",
"TestStatus",
# Detection
"DetectionResult",
"TestContext",
# Registry
"registry",
"register_pattern",
"get_pattern",
"list_patterns",
]

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""Core modules for the prompt injection tester."""
from .models import (
AttackCategory,
AttackConfig,
AttackPayload,
AttackPattern,
DetectionMethod,
DetectionResult,
InjectionPoint,
InjectionPointType,
Severity,
TargetConfig,
TestContext,
TestResult,
TestStatus,
TestSuite,
)
from .tester import InjectionTester
__all__ = [
"AttackCategory",
"AttackConfig",
"AttackPayload",
"AttackPattern",
"DetectionMethod",
"DetectionResult",
"InjectionPoint",
"InjectionPointType",
"InjectionTester",
"Severity",
"TargetConfig",
"TestContext",
"TestResult",
"TestStatus",
"TestSuite",
]

View File

@@ -0,0 +1,17 @@
#!/usr/bin/env python3
"""Detection and validation modules for identifying successful injections."""
from .base import BaseDetector, DetectorRegistry
from .system_prompt import SystemPromptLeakDetector
from .behavior_change import BehaviorChangeDetector
from .tool_misuse import ToolMisuseDetector
from .scoring import ConfidenceScorer
__all__ = [
"BaseDetector",
"DetectorRegistry",
"SystemPromptLeakDetector",
"BehaviorChangeDetector",
"ToolMisuseDetector",
"ConfidenceScorer",
]

View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python3
"""Attack pattern modules."""
from .base import BaseAttackPattern, MultiTurnAttackPattern, CompositeAttackPattern
from .registry import registry, register_pattern, get_pattern, list_patterns
__all__ = [
"BaseAttackPattern",
"MultiTurnAttackPattern",
"CompositeAttackPattern",
"registry",
"register_pattern",
"get_pattern",
"list_patterns",
]

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
Advanced Attack Patterns
Sophisticated techniques including multi-turn attacks, payload fragmentation,
encoding/obfuscation, and language-based exploits.
Based on Chapter 14 advanced techniques sections.
"""
from .multi_turn import (
GradualEscalationPattern,
ContextBuildupPattern,
TrustEstablishmentPattern,
)
from .fragmentation import (
PayloadFragmentationPattern,
TokenSmugglingPattern,
ChunkedInjectionPattern,
)
from .encoding import (
Base64EncodingPattern,
UnicodeObfuscationPattern,
LanguageSwitchingPattern,
LeetSpeakPattern,
)
__all__ = [
# Multi-turn
"GradualEscalationPattern",
"ContextBuildupPattern",
"TrustEstablishmentPattern",
# Fragmentation
"PayloadFragmentationPattern",
"TokenSmugglingPattern",
"ChunkedInjectionPattern",
# Encoding/Obfuscation
"Base64EncodingPattern",
"UnicodeObfuscationPattern",
"LanguageSwitchingPattern",
"LeetSpeakPattern",
]

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
Direct Injection Attack Patterns
These patterns target the primary LLM input channel where attacker-controlled
content is directly processed by the model. Based on Chapter 14 techniques.
"""
from .instruction_override import (
InstructionOverridePattern,
SystemPromptOverridePattern,
TaskHijackingPattern,
)
from .role_manipulation import (
RoleAuthorityPattern,
PersonaShiftPattern,
DeveloperModePattern,
)
from .delimiter_confusion import (
DelimiterEscapePattern,
XMLInjectionPattern,
MarkdownInjectionPattern,
)
__all__ = [
# Instruction Override
"InstructionOverridePattern",
"SystemPromptOverridePattern",
"TaskHijackingPattern",
# Role Manipulation
"RoleAuthorityPattern",
"PersonaShiftPattern",
"DeveloperModePattern",
# Delimiter Confusion
"DelimiterEscapePattern",
"XMLInjectionPattern",
"MarkdownInjectionPattern",
]

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
"""
Indirect Injection Attack Patterns
These patterns target secondary data sources that the LLM processes,
such as retrieved documents, web pages, emails, and database content.
Based on Chapter 14 indirect injection techniques.
"""
from .document_poisoning import (
RAGPoisoningPattern,
DocumentMetadataInjectionPattern,
HiddenTextInjectionPattern,
)
from .web_injection import (
WebPageInjectionPattern,
SEOPoisoningPattern,
CommentInjectionPattern,
)
from .email_injection import (
EmailBodyInjectionPattern,
EmailHeaderInjectionPattern,
AttachmentInjectionPattern,
)
__all__ = [
# Document Poisoning
"RAGPoisoningPattern",
"DocumentMetadataInjectionPattern",
"HiddenTextInjectionPattern",
# Web Injection
"WebPageInjectionPattern",
"SEOPoisoningPattern",
"CommentInjectionPattern",
# Email Injection
"EmailBodyInjectionPattern",
"EmailHeaderInjectionPattern",
"AttachmentInjectionPattern",
]

View File

@@ -0,0 +1,10 @@
"""
Prompt Injection Tester (PIT) - Modern CLI Interface
A premium terminal experience for LLM security assessment.
"""
__version__ = "1.0.0"
__all__ = ["app"]
from pit.app import app

View File

@@ -0,0 +1,381 @@
"""Main CLI application using Typer."""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import List, Optional
import typer
from rich.console import Console
from pit import __version__
from pit.config import Config, load_config
from pit.errors import handle_error
app = typer.Typer(
name="pit",
help="Prompt Injection Tester - Automated LLM Security Testing",
add_completion=False,
)
console = Console()
@app.command()
def scan(
target_url: str = typer.Argument(
...,
help="Target API endpoint URL",
),
token: Optional[str] = typer.Option(
None,
"--token",
"-t",
help="Authentication token (or use $PIT_TOKEN)",
envvar="PIT_TOKEN",
),
config_file: Optional[Path] = typer.Option(
None,
"--config",
"-c",
help="Configuration file (YAML)",
exists=True,
),
auto: bool = typer.Option(
False,
"--auto",
"-a",
help="Run all phases automatically (non-interactive)",
),
api_type: str = typer.Option(
"openai",
"--api-type",
help="API type (openai, anthropic, custom)",
),
timeout: int = typer.Option(
30,
"--timeout",
help="Request timeout in seconds",
),
categories: Optional[List[str]] = typer.Option(
None,
"--categories",
help="Attack categories (direct, indirect, advanced)",
),
patterns: Optional[List[str]] = typer.Option(
None,
"--patterns",
help="Specific pattern IDs to test",
),
max_concurrent: int = typer.Option(
5,
"--max-concurrent",
help="Maximum concurrent requests",
),
rate_limit: float = typer.Option(
1.0,
"--rate-limit",
help="Requests per second",
),
output: Optional[Path] = typer.Option(
None,
"--output",
"-o",
help="Output file path (auto-generated if not specified)",
),
output_format: str = typer.Option(
"json",
"--format",
"-f",
help="Report format (json, yaml, html)",
),
include_cvss: bool = typer.Option(
True,
"--include-cvss/--no-cvss",
help="Include CVSS scores in report",
),
include_payloads: bool = typer.Option(
False,
"--include-payloads/--no-payloads",
help="Include attack payloads in report",
),
authorize: bool = typer.Option(
False,
"--authorize",
help="Confirm authorization (skip interactive prompt)",
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Enable verbose output",
),
quiet: bool = typer.Option(
False,
"--quiet",
"-q",
help="Suppress all output except errors",
),
):
"""
Scan a target API for prompt injection vulnerabilities.
Example:
pit scan https://api.example.com/v1/chat --auto --token $TOKEN
"""
import asyncio
try:
# Load configuration
config = load_config(
config_path=config_file,
target_url=target_url,
token=token,
api_type=api_type,
timeout=timeout,
categories=categories,
patterns=patterns,
max_concurrent=max_concurrent,
rate_limit=rate_limit,
output_format=output_format,
output=output,
include_cvss=include_cvss,
include_payloads=include_payloads,
authorize=authorize,
)
# Verify authorization
if not config.authorization.confirmed and not auto:
if not _prompt_authorization(target_url):
console.print("[yellow]Scan cancelled by user[/yellow]")
raise typer.Exit(0)
config.authorization.confirmed = True
# Print banner
if not quiet:
_print_banner(target_url)
# Run scan
exit_code = asyncio.run(_run_scan(config, verbose, quiet))
raise typer.Exit(exit_code)
except KeyboardInterrupt:
console.print("\n[yellow]⚠ Interrupted by user[/yellow]")
raise typer.Exit(130)
    except typer.Exit:
        raise
    except Exception as e:
exit_code = handle_error(e, verbose)
raise typer.Exit(exit_code)
@app.command(name="list")
def list_command(
type: str = typer.Argument(
"patterns",
help="What to list (patterns, categories)",
),
):
"""
List available attack patterns or categories.
Example:
pit list patterns
pit list categories
"""
from patterns.registry import registry
from core.models import AttackCategory
if type == "patterns":
_list_patterns(registry)
elif type == "categories":
_list_categories()
else:
console.print(f"[red]Unknown type: {type}[/red]")
console.print("Available types: patterns, categories")
raise typer.Exit(1)
app.command(name="list")(list_command)
@app.command()
def auth(
target_url: str = typer.Argument(
...,
help="Target API endpoint URL",
),
):
"""
Verify authorization to test a target.
Example:
pit auth https://api.example.com
"""
if _prompt_authorization(target_url):
console.print("[green]✓ Authorization confirmed[/green]")
raise typer.Exit(0)
else:
console.print("[yellow]Authorization not confirmed[/yellow]")
raise typer.Exit(1)
@app.callback(invoke_without_command=True)
def main(
ctx: typer.Context,
version: bool = typer.Option(
False,
"--version",
help="Show version and exit",
),
):
"""PIT - Prompt Injection Tester"""
if version:
console.print(f"[cyan]pit[/cyan] version {__version__}")
raise typer.Exit(0)
if ctx.invoked_subcommand is None:
console.print(ctx.get_help())
# Helper functions
async def _run_scan(config: Config, verbose: bool, quiet: bool) -> int:
"""
Run the scan pipeline.
Args:
config: Configuration
verbose: Verbose output
quiet: Quiet mode
Returns:
Exit code (0 for success, non-zero for error)
"""
from pit.orchestrator.pipeline import Pipeline, PipelineContext
from pit.orchestrator.phases import (
AttackPhase,
DiscoveryPhase,
ReportingPhase,
VerificationPhase,
)
# Create pipeline
pipeline = Pipeline(
phases=[
DiscoveryPhase(),
AttackPhase(),
VerificationPhase(),
ReportingPhase(),
]
)
# Create context
context = PipelineContext(
target_url=config.target.url,
config=config,
)
# Run pipeline
context = await pipeline.run(context)
# Determine exit code
if context.interrupted:
return 130
elif context.report:
# Check if vulnerabilities found
successful = context.report.get("summary", {}).get("successful_attacks", 0)
return 2 if successful > 0 else 0
else:
return 1
def _prompt_authorization(target_url: str) -> bool:
"""
Prompt user for authorization confirmation.
Args:
target_url: Target URL
Returns:
True if authorized, False otherwise
"""
from rich.panel import Panel
panel = Panel(
f"[yellow]Target:[/yellow] {target_url}\n\n"
"[yellow]⚠ You must have explicit authorization to test this system.[/yellow]\n"
"[yellow]Unauthorized testing may be illegal in your jurisdiction.[/yellow]\n\n"
"Do you have authorization to test this target?",
title="[bold red]Authorization Required[/bold red]",
border_style="red",
)
console.print()
console.print(panel)
console.print()
response = typer.prompt("Confirm [y/N]", default="n")
return response.lower() == "y"
def _print_banner(target_url: str) -> None:
"""Print application banner."""
from rich.panel import Panel
banner = f"""[bold cyan]PIT - Prompt Injection Tester[/bold cyan] v{__version__}
[cyan]Target:[/cyan] {target_url}
[cyan]Authorization:[/cyan] ✓ Confirmed
"""
console.print()
console.print(Panel(banner, border_style="cyan", expand=False))
def _list_patterns(registry) -> None:
"""List available attack patterns."""
from core.models import AttackCategory
from rich.table import Table
if len(registry) == 0:
registry.load_builtin_patterns()
table = Table(title="Available Attack Patterns", show_header=True)
table.add_column("Category", style="cyan")
table.add_column("Pattern ID", style="yellow")
table.add_column("Description", style="white")
for category in AttackCategory:
patterns = registry.list_by_category(category)
for i, pid in enumerate(patterns):
pattern_class = registry.get(pid)
if pattern_class:
table.add_row(
category.value if i == 0 else "",
pid,
getattr(pattern_class, "name", pid),
)
console.print(table)
def _list_categories() -> None:
"""List attack categories."""
from core.models import AttackCategory
from rich.table import Table
table = Table(title="Attack Categories", show_header=True)
table.add_column("Category", style="cyan")
table.add_column("Description", style="white")
for category in AttackCategory:
table.add_row(category.value, f"{category.value} injection attacks")
console.print(table)
if __name__ == "__main__":
app()

View File

@@ -0,0 +1,7 @@
"""
CLI command modules.
"""
from pit.commands import scan
__all__ = ["scan"]

View File

@@ -0,0 +1,20 @@
"""Configuration module."""
from .loader import load_config, load_config_file
from .schema import (
AttackConfig,
AuthorizationConfig,
Config,
ReportingConfig,
TargetConfig,
)
__all__ = [
"Config",
"TargetConfig",
"AttackConfig",
"ReportingConfig",
"AuthorizationConfig",
"load_config",
"load_config_file",
]

View File

@@ -0,0 +1,82 @@
"""Configuration file loader."""
from __future__ import annotations
import os
from pathlib import Path
from typing import Optional
import yaml
from .schema import Config
def load_config_file(config_path: Path) -> Config:
"""
Load configuration from a YAML file.
Args:
config_path: Path to YAML configuration file
Returns:
Loaded configuration
Raises:
FileNotFoundError: If config file doesn't exist
ValueError: If config file is invalid
"""
if not config_path.exists():
raise FileNotFoundError(f"Configuration file not found: {config_path}")
with open(config_path) as f:
data = yaml.safe_load(f)
# Expand environment variables in token
if "target" in data and "token" in data["target"]:
token = data["target"]["token"]
if isinstance(token, str) and token.startswith("${") and token.endswith("}"):
env_var = token[2:-1]
data["target"]["token"] = os.getenv(env_var)
try:
return Config(**data)
except Exception as e:
raise ValueError(f"Invalid configuration file: {e}") from e
def load_config(
config_path: Optional[Path],
target_url: Optional[str] = None,
**cli_overrides,
) -> Config:
"""
Load configuration from file or CLI arguments.
Args:
config_path: Path to configuration file (optional)
target_url: Target URL from CLI (optional)
**cli_overrides: CLI argument overrides
Returns:
Merged configuration
Raises:
ValueError: If neither config_path nor target_url is provided
"""
if config_path:
config = load_config_file(config_path)
# Apply CLI overrides
if target_url:
config.target.url = target_url
if "token" in cli_overrides and cli_overrides["token"]:
config.target.token = cli_overrides["token"]
return config
elif target_url:
# Build config from CLI args
return Config.from_cli_args(target_url=target_url, **cli_overrides)
else:
raise ValueError("Either --config or target URL must be provided")

View File

@@ -0,0 +1,148 @@
"""Configuration schema using Pydantic."""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
class TargetConfig(BaseModel):
"""Target API configuration."""
url: str = Field(..., description="Target API endpoint URL")
token: Optional[str] = Field(None, description="Authentication token")
api_type: str = Field(default="openai", description="API type (openai, anthropic, custom)")
timeout: int = Field(default=30, ge=1, le=300, description="Request timeout in seconds")
@field_validator("url")
@classmethod
def validate_url(cls, v: str) -> str:
"""Ensure URL is properly formatted."""
if not v.startswith(("http://", "https://")):
raise ValueError("URL must start with http:// or https://")
return v
class AttackConfig(BaseModel):
"""Attack execution configuration."""
categories: List[str] = Field(
default=["direct", "indirect"],
description="Attack categories to test",
)
patterns: Optional[List[str]] = Field(
None,
description="Specific pattern IDs to test (overrides categories)",
)
exclude_patterns: List[str] = Field(
default_factory=list,
description="Pattern IDs to exclude",
)
max_concurrent: int = Field(
default=5,
ge=1,
le=20,
description="Maximum concurrent requests",
)
rate_limit: float = Field(
default=1.0,
ge=0.1,
le=10.0,
description="Requests per second",
)
class ReportingConfig(BaseModel):
"""Report generation configuration."""
format: str = Field(
default="json",
description="Report format (json, yaml, html)",
)
output: Optional[Path] = Field(
None,
description="Output file path (auto-generated if not specified)",
)
include_cvss: bool = Field(
default=True,
description="Include CVSS scores in report",
)
include_payloads: bool = Field(
default=False,
description="Include attack payloads in report (may be sensitive)",
)
@field_validator("format")
@classmethod
def validate_format(cls, v: str) -> str:
"""Ensure format is supported."""
allowed = ["json", "yaml", "html"]
if v not in allowed:
raise ValueError(f"Format must be one of {allowed}")
return v
class AuthorizationConfig(BaseModel):
"""Authorization configuration."""
scope: List[str] = Field(
default=["all"],
description="Authorization scope",
)
confirmed: bool = Field(
default=False,
description="Skip interactive authorization prompt",
)
class Config(BaseModel):
"""Complete PIT configuration."""
target: TargetConfig
attack: AttackConfig = Field(default_factory=AttackConfig)
reporting: ReportingConfig = Field(default_factory=ReportingConfig)
authorization: AuthorizationConfig = Field(default_factory=AuthorizationConfig)
@classmethod
def from_cli_args(
cls,
target_url: str,
token: Optional[str] = None,
api_type: str = "openai",
timeout: int = 30,
categories: Optional[List[str]] = None,
patterns: Optional[List[str]] = None,
max_concurrent: int = 5,
rate_limit: float = 1.0,
output_format: str = "json",
output: Optional[Path] = None,
include_cvss: bool = True,
include_payloads: bool = False,
authorize: bool = False,
) -> Config:
"""Create config from CLI arguments."""
return cls(
target=TargetConfig(
url=target_url,
token=token,
api_type=api_type,
timeout=timeout,
),
attack=AttackConfig(
categories=categories or ["direct", "indirect"],
patterns=patterns,
max_concurrent=max_concurrent,
rate_limit=rate_limit,
),
reporting=ReportingConfig(
format=output_format,
output=output,
include_cvss=include_cvss,
include_payloads=include_payloads,
),
authorization=AuthorizationConfig(
confirmed=authorize,
),
)

View File

@@ -0,0 +1,47 @@
"""Error handling module."""
from .exceptions import (
AttackError,
AuthenticationError,
ConfigError,
ConfigNotFoundError,
ConfigValidationError,
DetectionFailedError,
DiscoveryError,
ExecutionError,
FileWriteError,
FormatError,
NoEndpointsFoundError,
PatternLoadError,
PitError,
RateLimitError,
ReportingError,
ScanFailedError,
TargetError,
TargetUnreachableError,
VerificationError,
)
from .handlers import handle_error
__all__ = [
"PitError",
"ConfigError",
"ConfigNotFoundError",
"ConfigValidationError",
"TargetError",
"TargetUnreachableError",
"AuthenticationError",
"RateLimitError",
"DiscoveryError",
"NoEndpointsFoundError",
"ScanFailedError",
"AttackError",
"PatternLoadError",
"ExecutionError",
"VerificationError",
"DetectionFailedError",
"ReportingError",
"FormatError",
"FileWriteError",
"handle_error",
]

View File

@@ -0,0 +1,131 @@
"""Custom exception hierarchy for PIT."""
from __future__ import annotations
class PitError(Exception):
"""Base exception for all PIT errors."""
pass
# Configuration Errors
class ConfigError(PitError):
"""Base configuration error."""
pass
class ConfigNotFoundError(ConfigError):
"""Configuration file not found."""
pass
class ConfigValidationError(ConfigError):
"""Configuration validation failed."""
pass
# Target Errors
class TargetError(PitError):
"""Base target error."""
pass
class TargetUnreachableError(TargetError):
"""Target is unreachable."""
def __init__(self, url: str, reason: str = ""):
self.url = url
self.reason = reason
super().__init__(f"Target unreachable: {url}")
class AuthenticationError(TargetError):
"""Authentication failed."""
def __init__(self, status_code: int, message: str = ""):
self.status_code = status_code
self.message = message
super().__init__(f"Authentication failed: {status_code}")
class RateLimitError(TargetError):
"""Rate limit exceeded."""
def __init__(self, retry_after: int = 60):
self.retry_after = retry_after
super().__init__(f"Rate limit exceeded, retry after {retry_after}s")
# Discovery Errors
class DiscoveryError(PitError):
"""Base discovery error."""
pass
class NoEndpointsFoundError(DiscoveryError):
"""No endpoints found during discovery."""
pass
class ScanFailedError(DiscoveryError):
"""Discovery scan failed."""
pass
# Attack Errors
class AttackError(PitError):
"""Base attack error."""
pass
class PatternLoadError(AttackError):
"""Failed to load attack pattern."""
pass
class ExecutionError(AttackError):
"""Attack execution failed."""
pass
# Verification Errors
class VerificationError(PitError):
"""Base verification error."""
pass
class DetectionFailedError(VerificationError):
"""Detection/verification failed."""
pass
# Reporting Errors
class ReportingError(PitError):
"""Base reporting error."""
pass
class FormatError(ReportingError):
"""Invalid report format."""
pass
class FileWriteError(ReportingError):
"""Failed to write report file."""
pass

View File

@@ -0,0 +1,84 @@
"""User-friendly error handlers."""
from __future__ import annotations
from rich.console import Console
from .exceptions import (
AuthenticationError,
ConfigError,
NoEndpointsFoundError,
PitError,
RateLimitError,
TargetUnreachableError,
)
console = Console()
def handle_error(error: Exception, verbose: bool = False) -> int:
"""
Convert exceptions to user-friendly messages.
Args:
error: The exception to handle
verbose: Show detailed error information
Returns:
Exit code (0 for success, non-zero for error)
"""
if isinstance(error, TargetUnreachableError):
console.print("[red]✗ Target Unreachable[/red]")
console.print(f" ├─ Could not connect to: {error.url}")
if error.reason:
console.print(f" ├─ Reason: {error.reason}")
console.print(" ├─ Suggestion: Check the URL and network connection")
console.print(" └─ Or use: [cyan]pit scan --skip-discovery --injection-points <file>[/cyan]")
return 1
elif isinstance(error, AuthenticationError):
console.print("[red]✗ Authentication Failed[/red]")
console.print(f" ├─ Status code: {error.status_code}")
if error.message:
console.print(f" ├─ Message: {error.message}")
console.print(" ├─ Suggestion: Verify your API token with [cyan]--token[/cyan]")
console.print(" └─ Or run: [cyan]pit auth <url>[/cyan]")
return 1
elif isinstance(error, RateLimitError):
console.print("[yellow]⚠ Rate Limited[/yellow]")
console.print(f" ├─ Retry after: {error.retry_after} seconds")
console.print(" └─ Suggestion: Reduce [cyan]--rate-limit[/cyan] or [cyan]--max-concurrent[/cyan]")
return 1
elif isinstance(error, NoEndpointsFoundError):
console.print("[yellow]⚠ No Endpoints Found[/yellow]")
console.print(" ├─ Discovery phase found no injection points")
console.print(" ├─ Suggestion 1: Specify endpoints manually with [cyan]--injection-points <file>[/cyan]")
console.print(" └─ Suggestion 2: Check if the target URL is correct")
return 1
elif isinstance(error, ConfigError):
console.print("[red]✗ Configuration Error[/red]")
console.print(f" ├─ {str(error)}")
console.print(" └─ Suggestion: Check your configuration file or CLI arguments")
return 1
elif isinstance(error, PitError):
console.print(f"[red]✗ Error: {error.__class__.__name__}[/red]")
console.print(f" └─ {str(error)}")
return 1
elif isinstance(error, KeyboardInterrupt):
console.print("\n[yellow]⚠ Scan Interrupted[/yellow]")
console.print(" └─ Use [cyan]Ctrl+C[/cyan] again to force quit")
return 130
else:
console.print(f"[red]✗ Unexpected Error: {error.__class__.__name__}[/red]")
console.print(f" └─ {str(error)}")
if verbose:
console.print_exception()
else:
console.print(" └─ Run with [cyan]--verbose[/cyan] for details")
return 1

View File

@@ -0,0 +1,8 @@
"""
Orchestrator layer for coordinating CLI and core framework.
"""
from pit.orchestrator.workflow import WorkflowOrchestrator
from pit.orchestrator.discovery import AutoDiscovery
__all__ = ["WorkflowOrchestrator", "AutoDiscovery"]

View File

@@ -0,0 +1,495 @@
"""Phase definitions for the sequential pipeline."""
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from rich.console import Console
if TYPE_CHECKING:
from pit.orchestrator.pipeline import PipelineContext
console = Console()
class PhaseStatus(Enum):
"""Phase execution status."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class PhaseResult:
"""Result from a phase execution."""
status: PhaseStatus
message: Optional[str] = None
error: Optional[str] = None
data: Optional[Dict[str, Any]] = None
class Phase(ABC):
"""
Abstract base class for pipeline phases.
Each phase MUST complete fully before returning.
"""
name: str = "Base Phase"
@abstractmethod
async def execute(self, context: PipelineContext) -> PhaseResult:
"""
Execute the phase.
MUST complete before returning (no background tasks).
Args:
context: Pipeline context (shared state)
Returns:
Phase result with status and optional data
Raises:
Exception: If phase fails unexpectedly
"""
pass
class DiscoveryPhase(Phase):
"""
Phase 1: Discovery
Scans the target for injection points.
"""
name = "Discovery"
async def execute(self, context: PipelineContext) -> PhaseResult:
"""
Execute discovery phase.
Args:
context: Pipeline context
Returns:
Phase result with discovered injection points
"""
from pit.ui.progress import create_spinner
console.print("[cyan]Discovering injection points...[/cyan]")
try:
with create_spinner():
task = asyncio.create_task(
self._discover_injection_points(context)
)
# WAIT for discovery to complete
injection_points = await task
context.injection_points = injection_points
if not injection_points:
return PhaseResult(
status=PhaseStatus.FAILED,
error="No injection points found",
)
message = (
f"Found {len(injection_points)} injection point(s)"
)
return PhaseResult(
status=PhaseStatus.COMPLETED,
message=message,
data={"injection_points": injection_points},
)
except Exception as e:
return PhaseResult(
status=PhaseStatus.FAILED,
error=str(e),
)
async def _discover_injection_points(
self, context: PipelineContext
) -> List[Any]:
"""
Run discovery logic.
TODO: Implement actual discovery logic using discovery module.
Args:
context: Pipeline context
Returns:
List of discovered injection points
"""
# Placeholder: Replace with actual discovery implementation
from core.models import InjectionPoint, InjectionPointType
# Simulate discovery
await asyncio.sleep(1)
# Mock injection points for now
return [
InjectionPoint(
id="param_prompt",
type=InjectionPointType.PARAMETER,
name="prompt",
location="body",
)
]
class AttackPhase(Phase):
"""
Phase 2: Attack Execution
Executes attack patterns against discovered injection points.
"""
name = "Attack Execution"
async def execute(self, context: PipelineContext) -> PhaseResult:
"""
Execute attack phase.
Args:
context: Pipeline context
Returns:
Phase result with test results
"""
from pit.ui.progress import create_progress_bar
injection_points = context.injection_points
if not injection_points:
return PhaseResult(
status=PhaseStatus.FAILED,
error="No injection points available",
)
try:
# Load attack patterns
patterns = await self._load_patterns(context)
console.print(f"[cyan]Loaded {len(patterns)} attack pattern(s)[/cyan]")
# Execute attacks with progress bar
results = []
with create_progress_bar() as progress:
task = progress.add_task(
"Running attacks",
total=len(patterns),
)
for pattern in patterns:
# Execute attack (internal concurrency OK)
result = await self._execute_attack(
pattern, injection_points[0], context
)
results.append(result)
progress.update(task, advance=1)
# Respect rate limiting
await asyncio.sleep(1.0 / context.config.attack.rate_limit)
context.test_results = results
message = f"Completed {len(results)} attack(s)"
return PhaseResult(
status=PhaseStatus.COMPLETED,
message=message,
data={"test_results": results},
)
except Exception as e:
return PhaseResult(
status=PhaseStatus.FAILED,
error=str(e),
)
async def _load_patterns(self, context: PipelineContext) -> List[Any]:
"""
Load attack patterns from registry.
Args:
context: Pipeline context
Returns:
List of attack patterns
"""
from patterns.registry import registry
# Load patterns based on config
categories = context.config.attack.categories
all_patterns = []
for category in categories:
patterns = registry.list_by_category(category)
all_patterns.extend(patterns)
# Return pattern IDs for now
# TODO: Return actual pattern instances
return all_patterns[:10] # Limit for demo
async def _execute_attack(
self, pattern: Any, injection_point: Any, context: PipelineContext
) -> Any:
"""
Execute a single attack pattern.
Args:
pattern: Attack pattern
injection_point: Target injection point
context: Pipeline context
Returns:
Test result
"""
from core.models import TestResult, TestStatus
# Simulate attack execution
await asyncio.sleep(0.1)
# Mock result
return TestResult(
pattern_id=str(pattern),
injection_point_id=injection_point.id,
status=TestStatus.SUCCESS,
payload="test_payload",
response=None,
)
class VerificationPhase(Phase):
"""
Phase 3: Verification
Analyzes attack responses to verify success.
"""
name = "Verification"
async def execute(self, context: PipelineContext) -> PhaseResult:
"""
Execute verification phase.
Args:
context: Pipeline context
Returns:
Phase result with verified results
"""
from pit.ui.progress import create_spinner
test_results = context.test_results
if not test_results:
return PhaseResult(
status=PhaseStatus.FAILED,
error="No test results to verify",
)
console.print("[cyan]Analyzing responses...[/cyan]")
try:
with create_spinner():
# WAIT for verification to complete
verified = await self._verify_results(test_results, context)
context.verified_results = verified
successful = sum(1 for r in verified if r.get("status") == "success")
message = f"Verified {len(verified)} result(s), {successful} successful"
return PhaseResult(
status=PhaseStatus.COMPLETED,
message=message,
data={"verified_results": verified},
)
except Exception as e:
return PhaseResult(
status=PhaseStatus.FAILED,
error=str(e),
)
async def _verify_results(
self, test_results: List[Any], context: PipelineContext
) -> List[Dict[str, Any]]:
"""
Verify test results.
Args:
test_results: List of test results
context: Pipeline context
Returns:
List of verified results with confidence scores
"""
# Simulate verification
await asyncio.sleep(1)
# Mock verified results
verified = []
for result in test_results:
verified.append({
"pattern_id": result.pattern_id,
"status": "success" if "test" in result.pattern_id else "failed",
"severity": "medium",
"confidence": 0.85,
})
return verified
class ReportingPhase(Phase):
"""
Phase 4: Reporting
Generates and saves the final report.
"""
name = "Report Generation"
async def execute(self, context: PipelineContext) -> PhaseResult:
"""
Execute reporting phase.
Args:
context: Pipeline context
Returns:
Phase result with report path
"""
verified_results = context.verified_results
if not verified_results:
return PhaseResult(
status=PhaseStatus.FAILED,
error="No results to report",
)
try:
# Generate report
report = await self._generate_report(verified_results, context)
context.report = report
# Save to file
report_path = await self._save_report(report, context)
context.report_path = report_path
# Display summary
self._display_summary(report, report_path)
message = f"Report saved to {report_path}"
return PhaseResult(
status=PhaseStatus.COMPLETED,
message=message,
data={"report_path": str(report_path)},
)
except Exception as e:
return PhaseResult(
status=PhaseStatus.FAILED,
error=str(e),
)
async def _generate_report(
self, verified_results: List[Dict[str, Any]], context: PipelineContext
) -> Dict[str, Any]:
"""
Generate report data structure.
Args:
verified_results: Verified results
context: Pipeline context
Returns:
Report dictionary
"""
successful = [r for r in verified_results if r.get("status") == "success"]
total = len(verified_results)
success_rate = len(successful) / total if total > 0 else 0.0
report = {
"metadata": {
"version": "2.0.0",
"timestamp": datetime.now().isoformat(),
"target": context.target_url,
"duration_seconds": sum(context.phase_durations.values()),
},
"summary": {
"total_tests": total,
"successful_attacks": len(successful),
"success_rate": success_rate,
},
"results": verified_results,
}
return report
async def _save_report(
self, report: Dict[str, Any], context: PipelineContext
) -> Path:
"""
Save report to file.
Args:
report: Report data
context: Pipeline context
Returns:
Path to saved report
"""
import json
output_path = context.config.reporting.output
if not output_path:
# Auto-generate filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = Path(f"pit_report_{timestamp}.json")
with open(output_path, "w") as f:
json.dump(report, f, indent=2)
return output_path
def _display_summary(
self, report: Dict[str, Any], report_path: Path
) -> None:
"""
Display summary to console.
Args:
report: Report data
report_path: Path to saved report
"""
from pit.ui.tables import create_summary_panel
summary = report["summary"]
panel = create_summary_panel(
total_tests=summary["total_tests"],
successful_attacks=summary["successful_attacks"],
success_rate=summary["success_rate"],
vulnerabilities_by_severity={},
report_path=str(report_path),
)
console.print()
console.print(panel)
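Each phase above follows the same small contract: a `name` class attribute plus an async `execute(context)` that returns a `PhaseResult`. As a hedged sketch (not part of this commit), a minimal extra phase that would plug into the same pipeline could look like the following; the class name, phase name, and message are illustrative only:

from pit.orchestrator.phases import Phase, PhaseResult, PhaseStatus

class DryRunPhase(Phase):
    """Hypothetical phase: report how many injection points were discovered, attack nothing."""
    name = "Dry Run"

    async def execute(self, context) -> PhaseResult:
        # Read Phase 1 output from the shared context, mirroring AttackPhase above.
        count = len(context.injection_points)
        return PhaseResult(
            status=PhaseStatus.COMPLETED,
            message=f"Would attack {count} injection point(s)",
        )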

View File

@@ -0,0 +1,157 @@
"""Sequential pipeline executor for running phases one-by-one."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from rich.console import Console
from rich.panel import Panel
from pit.config import Config
from pit.errors import PitError
from pit.orchestrator.phases import Phase, PhaseResult, PhaseStatus
console = Console()
@dataclass
class PipelineContext:
"""
Shared state across pipeline phases.
This context is passed through each phase sequentially,
allowing each phase to consume the output of the phases before it.
"""
target_url: str
config: Config
# Phase 1 output (Discovery)
injection_points: List[Any] = field(default_factory=list)
# Phase 2 output (Attack)
test_results: List[Any] = field(default_factory=list)
# Phase 3 output (Verification)
verified_results: List[Any] = field(default_factory=list)
# Phase 4 output (Reporting)
report: Optional[Dict[str, Any]] = None
report_path: Optional[Path] = None
# Metadata
start_time: datetime = field(default_factory=datetime.now)
phase_durations: Dict[str, float] = field(default_factory=dict)
# Interrupt handling
interrupted: bool = False
class Pipeline:
"""
Sequential phase executor.
Runs phases one-by-one, waiting for each to complete before
starting the next, so no two phases ever run concurrently.
"""
def __init__(self, phases: List[Phase]):
"""
Initialize pipeline with phases.
Args:
phases: List of phases to execute sequentially
"""
self.phases = phases
async def run(self, context: PipelineContext) -> PipelineContext:
"""
Run all phases sequentially.
Each phase MUST complete before the next begins.
Args:
context: Pipeline context (shared state)
Returns:
Updated context with results
Raises:
PitError: If any phase fails
KeyboardInterrupt: If user interrupts
"""
total_phases = len(self.phases)
try:
for i, phase in enumerate(self.phases, start=1):
self._print_phase_header(i, total_phases, phase.name)
# CRITICAL: Wait for phase to complete before continuing
phase_start = datetime.now()
result = await phase.execute(context)
phase_duration = (datetime.now() - phase_start).total_seconds()
context.phase_durations[phase.name] = phase_duration
if result.status == PhaseStatus.FAILED:
self._handle_phase_failure(phase, result)
break
self._print_phase_success(phase.name, result)
except KeyboardInterrupt:
console.print("\n[yellow]⚠ Scan Interrupted by User[/yellow]")
context.interrupted = True
raise
except Exception as e:
console.print(f"\n[red]✗ Pipeline Error: {e.__class__.__name__}[/red]")
raise
return context
def _print_phase_header(self, phase_num: int, total: int, name: str) -> None:
"""Print phase header."""
header = f"[{phase_num}/{total}] {name}"
console.print()
console.print(Panel(header, border_style="cyan", expand=False))
def _print_phase_success(self, name: str, result: PhaseResult) -> None:
"""Print phase success message."""
console.print(f"[green]✓ {name} Complete[/green]")
if result.message:
console.print(f" └─ {result.message}")
def _handle_phase_failure(self, phase: Phase, result: PhaseResult) -> None:
"""Handle phase failure."""
console.print(f"[red]✗ {phase.name} Failed[/red]")
if result.error:
console.print(f" └─ {result.error}")
async def create_default_pipeline() -> Pipeline:
"""
Create the default 4-phase pipeline.
Returns:
Pipeline with Discovery, Attack, Verification, and Reporting phases
"""
from pit.orchestrator.phases import (
AttackPhase,
DiscoveryPhase,
ReportingPhase,
VerificationPhase,
)
return Pipeline(
phases=[
DiscoveryPhase(),
AttackPhase(),
VerificationPhase(),
ReportingPhase(),
]
)
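Taken together with `PipelineContext` and `create_default_pipeline`, here is a hedged usage sketch (not part of this commit) of how an async entry point might drive the four phases. The module path `pit.orchestrator.pipeline` and the way the `Config` object is obtained are assumptions:

import asyncio
from pit.orchestrator.pipeline import PipelineContext, create_default_pipeline

async def run_scan(target_url: str, config) -> PipelineContext:
    # Build the shared context, then run Discovery -> Attack -> Verification -> Reporting.
    pipeline = await create_default_pipeline()
    context = PipelineContext(target_url=target_url, config=config)
    return await pipeline.run(context)

# asyncio.run(run_scan("https://target.example/api/chat", config))
# where `config` is a pit.config.Config instance built by the CLI (hypothetical here).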

View File

@@ -0,0 +1,24 @@
"""
Rich UI components for the CLI interface.
"""
from pit.ui.console import console
from pit.ui.progress import create_progress_bar, create_spinner
from pit.ui.display import (
print_banner,
print_success,
print_error,
print_warning,
print_info,
)
__all__ = [
"console",
"create_progress_bar",
"create_spinner",
"print_banner",
"print_success",
"print_error",
"print_warning",
"print_info",
]

View File

@@ -0,0 +1,34 @@
"""Spinner animations for async operations."""
from __future__ import annotations
from contextlib import contextmanager
from typing import Iterator, Optional
from rich.console import Console
from rich.spinner import Spinner
from rich.live import Live
console = Console()
@contextmanager
def show_spinner(message: str, spinner_type: str = "dots") -> Iterator[None]:
"""
Show a spinner during an operation.
Args:
message: Message to display next to spinner
spinner_type: Type of spinner animation
Yields:
None
Example:
with show_spinner("Discovering injection points..."):
await discover()
"""
spinner = Spinner(spinner_type, text=message, style="cyan")
with Live(spinner, console=console, transient=True):
yield

View File

@@ -0,0 +1,114 @@
"""Color schemes and styling."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.models import Severity, TestStatus
# Color scheme
STYLES = {
"primary": "cyan",
"success": "green",
"warning": "yellow",
"error": "red",
"info": "blue",
"muted": "dim",
}
# Severity colors
SEVERITY_COLORS = {
"critical": "red",
"high": "red",
"medium": "yellow",
"low": "blue",
"info": "green",
}
SEVERITY_ICONS = {
"critical": "🔴",
"high": "🔴",
"medium": "🟠",
"low": "🟡",
"info": "🟢",
}
# Status symbols
STATUS_SYMBOLS = {
"success": "",
"failed": "",
"error": "",
"pending": "",
"running": "",
}
def format_severity(severity: str | Severity) -> str:
"""
Format severity with color and icon.
Args:
severity: Severity level
Returns:
Formatted severity string with Rich markup
"""
if hasattr(severity, "value"):
severity = severity.value
severity_lower = severity.lower()
color = SEVERITY_COLORS.get(severity_lower, "white")
icon = SEVERITY_ICONS.get(severity_lower, "")
return f"[{color}]{icon} {severity.upper()}[/{color}]"
def format_status(status: str | TestStatus, success: bool = True) -> str:
"""
Format test status with symbol and color.
Args:
status: Status value
success: Whether the test succeeded
Returns:
Formatted status string with Rich markup
"""
if hasattr(status, "value"):
status = status.value
status_lower = status.lower()
symbol = STATUS_SYMBOLS.get(status_lower, "")
if success:
color = "green"
elif status_lower == "error":
color = "red"
else:
color = "yellow"
return f"[{color}]{symbol} {status.title()}[/{color}]"
def format_confidence(confidence: float) -> str:
"""
Format confidence score with color.
Args:
confidence: Confidence score (0.0-1.0)
Returns:
Formatted confidence string with Rich markup
"""
percentage = int(confidence * 100)
if confidence >= 0.9:
color = "green"
elif confidence >= 0.7:
color = "yellow"
else:
color = "red"
return f"[{color}]{percentage}%[/{color}]"
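Because these helpers return Rich markup strings, they compose directly with `console.print`. A hedged illustration (not part of this commit), with made-up values and an assumed module path:

from rich.console import Console
from pit.ui.styles import format_severity, format_confidence  # module path assumed

console = Console()
# Renders a colored severity label followed by a green confidence percentage.
console.print(format_severity("high"), format_confidence(0.92))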

View File

@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project]
name = "prompt-injection-tester"
version = "1.0.0"
description = "Automated prompt injection testing framework for LLMs"
version = "2.0.0"
description = "Modern CLI for automated prompt injection testing of LLMs"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "CC BY-SA 4.0"}
@@ -18,6 +18,7 @@ keywords = [
"prompt-injection",
"red-team",
"testing",
"cli",
]
classifiers = [
"Development Status :: 4 - Beta",
@@ -31,9 +32,12 @@ classifiers = [
]
dependencies = [
"aiohttp>=3.9.0",
"httpx>=0.24.0",
"pyyaml>=6.0",
"typer>=0.7.0",
"typer>=0.9.0",
"rich>=13.0.0",
"pydantic>=2.0.0",
"jinja2>=3.1.0",
]
[project.optional-dependencies]
@@ -48,7 +52,7 @@ dev = [
[project.scripts]
prompt-injection-tester = "prompt_injection_tester.cli:main"
pit = "pit.app:app"
pit = "pit.app:cli_main"
[project.urls]
"Homepage" = "https://github.com/example/ai-llm-red-team-handbook"

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
"""Test suite for the prompt injection tester."""

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
"""Tests for attack patterns."""

View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python3
"""Utility modules for the prompt injection tester."""
from .encoding import encode_payload, decode_payload, translate_payload
from .http_client import AsyncHTTPClient, LLMClient, HTTPResponse, RateLimiter
__all__ = [
"encode_payload",
"decode_payload",
"translate_payload",
"AsyncHTTPClient",
"LLMClient",
"HTTPResponse",
"RateLimiter",
]